1//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
10// both before and after the DAG is legalized.
11//
12// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13// primarily intended to handle simplification opportunities that are implicit
14// in the LLVM IR and exposed by the various codegen lowering phases.
15//
16//===----------------------------------------------------------------------===//
17
18#include "llvm/ADT/APFloat.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/IntervalMap.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallPtrSet.h"
27#include "llvm/ADT/SmallSet.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/ADT/Statistic.h"
30#include "llvm/Analysis/AliasAnalysis.h"
31#include "llvm/Analysis/MemoryLocation.h"
32#include "llvm/Analysis/TargetLibraryInfo.h"
33#include "llvm/Analysis/ValueTracking.h"
34#include "llvm/Analysis/VectorUtils.h"
35#include "llvm/CodeGen/ByteProvider.h"
36#include "llvm/CodeGen/DAGCombine.h"
37#include "llvm/CodeGen/ISDOpcodes.h"
38#include "llvm/CodeGen/MachineFunction.h"
39#include "llvm/CodeGen/MachineMemOperand.h"
40#include "llvm/CodeGen/SDPatternMatch.h"
41#include "llvm/CodeGen/SelectionDAG.h"
42#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
43#include "llvm/CodeGen/SelectionDAGNodes.h"
44#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
45#include "llvm/CodeGen/TargetLowering.h"
46#include "llvm/CodeGen/TargetRegisterInfo.h"
47#include "llvm/CodeGen/TargetSubtargetInfo.h"
48#include "llvm/CodeGen/ValueTypes.h"
49#include "llvm/CodeGenTypes/MachineValueType.h"
50#include "llvm/IR/Attributes.h"
51#include "llvm/IR/Constant.h"
52#include "llvm/IR/DataLayout.h"
53#include "llvm/IR/DerivedTypes.h"
54#include "llvm/IR/Function.h"
55#include "llvm/IR/Metadata.h"
56#include "llvm/Support/Casting.h"
57#include "llvm/Support/CodeGen.h"
58#include "llvm/Support/CommandLine.h"
59#include "llvm/Support/Compiler.h"
60#include "llvm/Support/Debug.h"
61#include "llvm/Support/DebugCounter.h"
62#include "llvm/Support/ErrorHandling.h"
63#include "llvm/Support/KnownBits.h"
64#include "llvm/Support/MathExtras.h"
65#include "llvm/Support/raw_ostream.h"
66#include "llvm/Target/TargetMachine.h"
67#include "llvm/Target/TargetOptions.h"
68#include <algorithm>
69#include <cassert>
70#include <cstdint>
71#include <functional>
72#include <iterator>
73#include <optional>
74#include <string>
75#include <tuple>
76#include <utility>
77#include <variant>
78
79#include "MatchContext.h"
80
81using namespace llvm;
82using namespace llvm::SDPatternMatch;
83
84#define DEBUG_TYPE "dagcombine"
85
STATISTIC(NodesCombined, "Number of dag nodes combined");
STATISTIC(PreIndexedNodes, "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed, "Number of load/op/store sequences narrowed");
STATISTIC(LdStFP2Int, "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
92STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
93
94DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
95 "Controls whether a DAG combine is performed for a node");
96
97static cl::opt<bool>
98CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
99 cl::desc("Enable DAG combiner's use of IR alias analysis"));
100
101static cl::opt<bool>
102UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(Val: true),
103 cl::desc("Enable DAG combiner's use of TBAA"));
104
105#ifndef NDEBUG
106static cl::opt<std::string>
107CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
108 cl::desc("Only use DAG-combiner alias analysis in this"
109 " function"));
110#endif
111
112/// Hidden option to stress test load slicing, i.e., when this option
113/// is enabled, load slicing bypasses most of its profitability guards.
114static cl::opt<bool>
115StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
116 cl::desc("Bypass the profitability model of load slicing"),
117 cl::init(Val: false));
118
119static cl::opt<bool>
120 MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(Val: true),
121 cl::desc("DAG combiner may split indexing from loads"));
122
123static cl::opt<bool>
124 EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(Val: true),
125 cl::desc("DAG combiner enable merging multiple stores "
126 "into a wider store"));
127
128static cl::opt<unsigned> TokenFactorInlineLimit(
129 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(Val: 2048),
130 cl::desc("Limit the number of operands to inline for Token Factors"));
131
132static cl::opt<unsigned> StoreMergeDependenceLimit(
133 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(Val: 10),
134 cl::desc("Limit the number of times for the same StoreNode and RootNode "
135 "to bail out in store merging dependence check"));
136
137static cl::opt<bool> EnableReduceLoadOpStoreWidth(
138 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(Val: true),
139 cl::desc("DAG combiner enable reducing the width of load/op/store "
140 "sequence"));
141static cl::opt<bool> ReduceLoadOpStoreWidthForceNarrowingProfitable(
142 "combiner-reduce-load-op-store-width-force-narrowing-profitable",
143 cl::Hidden, cl::init(Val: false),
144 cl::desc("DAG combiner force override the narrowing profitable check when "
145 "reducing the width of load/op/store sequences"));
146
147static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
148 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(Val: true),
149 cl::desc("DAG combiner enable load/<replace bytes>/store with "
150 "a narrower store"));
151
152static cl::opt<bool> DisableCombines("combiner-disabled", cl::Hidden,
153 cl::init(Val: false),
154 cl::desc("Disable the DAG combiner"));
155
156namespace {
157
158 class DAGCombiner {
159 SelectionDAG &DAG;
160 const TargetLowering &TLI;
161 const SelectionDAGTargetInfo *STI;
162 CombineLevel Level = BeforeLegalizeTypes;
163 CodeGenOptLevel OptLevel;
164 bool LegalDAG = false;
165 bool LegalOperations = false;
166 bool LegalTypes = false;
167 bool ForCodeSize;
168 bool DisableGenericCombines;
169
170 /// Worklist of all of the nodes that need to be simplified.
171 ///
172 /// This must behave as a stack -- new nodes to process are pushed onto the
173 /// back and when processing we pop off of the back.
174 ///
175 /// The worklist will not contain duplicates but may contain null entries
176 /// due to nodes being deleted from the underlying DAG. For fast lookup and
177 /// deduplication, the index of the node in this vector is stored in the
178 /// node in SDNode::CombinerWorklistIndex.
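    ///
    /// A negative index means the node is not currently in the worklist; the
    /// value -2 additionally marks a node that has already been taken off the
    /// worklist for combining (see getNextWorklistEntry and AddToWorklist).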
179 SmallVector<SDNode *, 64> Worklist;
180
    /// This records all nodes attempted to be added to the worklist since we
    /// last considered a new worklist entry. Because duplicate nodes are never
    /// added to the worklist, this is different from the tail of the worklist.
184 SmallSetVector<SDNode *, 32> PruningList;
185
    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count tracks how many times the StoreNode with the same RootNode
    /// has bailed out of the dependence check. Once the same pair has bailed
    /// out more than a fixed limit, the StoreNode is no longer considered a
    /// store merging candidate with that RootNode.
192 DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
193
194 // BatchAA - Used for DAG load/store alias analysis.
195 BatchAAResults *BatchAA;
196
197 /// This caches all chains that have already been processed in
198 /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
    /// store candidates.
200 SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;
201
    /// When a node is simplified, add all of its users to the worklist,
    /// because they might now be simplifiable as well.
204 void AddUsersToWorklist(SDNode *N) {
205 for (SDNode *Node : N->users())
206 AddToWorklist(N: Node);
207 }
208
    /// Convenient shorthand to add a node and all of its users to the worklist.
210 void AddToWorklistWithUsers(SDNode *N) {
211 AddUsersToWorklist(N);
212 AddToWorklist(N);
213 }
214
215 // Prune potentially dangling nodes. This is called after
216 // any visit to a node, but should also be called during a visit after any
217 // failed combine which may have created a DAG node.
218 void clearAddedDanglingWorklistEntries() {
219 // Check any nodes added to the worklist to see if they are prunable.
220 while (!PruningList.empty()) {
221 auto *N = PruningList.pop_back_val();
222 if (N->use_empty())
223 recursivelyDeleteUnusedNodes(N);
224 }
225 }
226
227 SDNode *getNextWorklistEntry() {
228 // Before we do any work, remove nodes that are not in use.
229 clearAddedDanglingWorklistEntries();
230 SDNode *N = nullptr;
231 // The Worklist holds the SDNodes in order, but it may contain null
232 // entries.
233 while (!N && !Worklist.empty()) {
234 N = Worklist.pop_back_val();
235 }
236
237 if (N) {
238 assert(N->getCombinerWorklistIndex() >= 0 &&
239 "Found a worklist entry without a corresponding map entry!");
240 // Set to -2 to indicate that we combined the node.
241 N->setCombinerWorklistIndex(-2);
242 }
243 return N;
244 }
245
246 /// Call the node-specific routine that folds each particular type of node.
247 SDValue visit(SDNode *N);
248
249 public:
250 DAGCombiner(SelectionDAG &D, BatchAAResults *BatchAA, CodeGenOptLevel OL)
251 : DAG(D), TLI(D.getTargetLoweringInfo()),
252 STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL),
253 BatchAA(BatchAA) {
254 ForCodeSize = DAG.shouldOptForSize();
255 DisableGenericCombines =
256 DisableCombines || (STI && STI->disableGenericCombines(OptLevel));
257
258 MaximumLegalStoreInBits = 0;
259 // We use the minimum store size here, since that's all we can guarantee
260 // for the scalable vector types.
261 for (MVT VT : MVT::all_valuetypes())
262 if (EVT(VT).isSimple() && VT != MVT::Other &&
263 TLI.isTypeLegal(VT: EVT(VT)) &&
264 VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
265 MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
266 }
267
268 void ConsiderForPruning(SDNode *N) {
269 // Mark this for potential pruning.
270 PruningList.insert(X: N);
271 }
272
    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
275 void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
276 bool SkipIfCombinedBefore = false) {
277 assert(N->getOpcode() != ISD::DELETED_NODE &&
278 "Deleted Node added to Worklist");
279
280 // Skip handle nodes as they can't usefully be combined and confuse the
281 // zero-use deletion strategy.
282 if (N->getOpcode() == ISD::HANDLENODE)
283 return;
284
285 if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
286 return;
287
288 if (IsCandidateForPruning)
289 ConsiderForPruning(N);
290
291 if (N->getCombinerWorklistIndex() < 0) {
292 N->setCombinerWorklistIndex(Worklist.size());
293 Worklist.push_back(Elt: N);
294 }
295 }
296
297 /// Remove all instances of N from the worklist.
298 void removeFromWorklist(SDNode *N) {
299 PruningList.remove(X: N);
300 StoreRootCountMap.erase(Val: N);
301
302 int WorklistIndex = N->getCombinerWorklistIndex();
303 // If not in the worklist, the index might be -1 or -2 (was combined
304 // before). As the node gets deleted anyway, there's no need to update
305 // the index.
306 if (WorklistIndex < 0)
307 return; // Not in the worklist.
308
309 // Null out the entry rather than erasing it to avoid a linear operation.
310 Worklist[WorklistIndex] = nullptr;
311 N->setCombinerWorklistIndex(-1);
312 }
313
314 void deleteAndRecombine(SDNode *N);
315 bool recursivelyDeleteUnusedNodes(SDNode *N);
316
317 /// Replaces all uses of the results of one DAG node with new values.
318 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
319 bool AddTo = true);
320
321 /// Replaces all uses of the results of one DAG node with new values.
322 SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
323 return CombineTo(N, To: &Res, NumTo: 1, AddTo);
324 }
325
326 /// Replaces all uses of the results of one DAG node with new values.
327 SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
328 bool AddTo = true) {
329 SDValue To[] = { Res0, Res1 };
330 return CombineTo(N, To, NumTo: 2, AddTo);
331 }
332
333 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
334
335 private:
336 unsigned MaximumLegalStoreInBits;
337
338 /// Check the specified integer node value to see if it can be simplified or
339 /// if things it uses can be simplified by bit propagation.
340 /// If so, return true.
341 bool SimplifyDemandedBits(SDValue Op) {
342 unsigned BitWidth = Op.getScalarValueSizeInBits();
343 APInt DemandedBits = APInt::getAllOnes(numBits: BitWidth);
344 return SimplifyDemandedBits(Op, DemandedBits);
345 }
346
347 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
348 EVT VT = Op.getValueType();
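      // Scalars and scalable vectors are modelled with a single demanded lane.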
349 APInt DemandedElts = VT.isFixedLengthVector()
350 ? APInt::getAllOnes(numBits: VT.getVectorNumElements())
351 : APInt(1, 1);
352 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, AssumeSingleUse: false);
353 }
354
355 /// Check the specified vector node value to see if it can be simplified or
356 /// if things it uses can be simplified as it only uses some of the
357 /// elements. If so, return true.
358 bool SimplifyDemandedVectorElts(SDValue Op) {
359 // TODO: For now just pretend it cannot be simplified.
360 if (Op.getValueType().isScalableVector())
361 return false;
362
363 unsigned NumElts = Op.getValueType().getVectorNumElements();
364 APInt DemandedElts = APInt::getAllOnes(numBits: NumElts);
365 return SimplifyDemandedVectorElts(Op, DemandedElts);
366 }
367
368 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
369 const APInt &DemandedElts,
370 bool AssumeSingleUse = false);
371 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
372 bool AssumeSingleUse = false);
373
374 bool CombineToPreIndexedLoadStore(SDNode *N);
375 bool CombineToPostIndexedLoadStore(SDNode *N);
376 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
377 bool SliceUpLoad(SDNode *N);
378
379 // Looks up the chain to find a unique (unaliased) store feeding the passed
380 // load. If no such store is found, returns a nullptr.
    // Note: This will look past a CALLSEQ_START if the load is chained to it,
    // so that it can find stack stores for byval params.
383 StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
384 // Scalars have size 0 to distinguish from singleton vectors.
385 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
386 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
387 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
388
389 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
390 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
391 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
392 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
393 SDValue PromoteIntBinOp(SDValue Op);
394 SDValue PromoteIntShiftOp(SDValue Op);
395 SDValue PromoteExtend(SDValue Op);
396 bool PromoteLoad(SDValue Op);
397
398 SDValue foldShiftToAvg(SDNode *N);
399 // Fold `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`
400 SDValue foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT);
401
402 SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
403 SDValue RHS, SDValue True, SDValue False,
404 ISD::CondCode CC);
405
406 /// Call the node-specific routine that knows how to fold each
407 /// particular type of node. If that doesn't do anything, try the
408 /// target-specific DAG combines.
409 SDValue combine(SDNode *N);
410
411 // Visitation implementation - Implement dag node combining for different
412 // node types. The semantics are as follows:
413 // Return Value:
414 // SDValue.getNode() == 0 - No change was made
415 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
416 // otherwise - N should be replaced by the returned Operand.
417 //
418 SDValue visitTokenFactor(SDNode *N);
419 SDValue visitMERGE_VALUES(SDNode *N);
420 SDValue visitADD(SDNode *N);
421 SDValue visitADDLike(SDNode *N);
422 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
423 SDNode *LocReference);
424 SDValue visitPTRADD(SDNode *N);
425 SDValue visitSUB(SDNode *N);
426 SDValue visitADDSAT(SDNode *N);
427 SDValue visitSUBSAT(SDNode *N);
428 SDValue visitADDC(SDNode *N);
429 SDValue visitADDO(SDNode *N);
430 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
431 SDValue visitSUBC(SDNode *N);
432 SDValue visitSUBO(SDNode *N);
433 SDValue visitADDE(SDNode *N);
434 SDValue visitUADDO_CARRY(SDNode *N);
435 SDValue visitSADDO_CARRY(SDNode *N);
436 SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
437 SDNode *N);
438 SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
439 SDNode *N);
440 SDValue visitSUBE(SDNode *N);
441 SDValue visitUSUBO_CARRY(SDNode *N);
442 SDValue visitSSUBO_CARRY(SDNode *N);
443 template <class MatchContextClass> SDValue visitMUL(SDNode *N);
444 SDValue visitMULFIX(SDNode *N);
445 SDValue useDivRem(SDNode *N);
446 SDValue visitSDIV(SDNode *N);
447 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
448 SDValue visitUDIV(SDNode *N);
449 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
450 SDValue visitREM(SDNode *N);
451 SDValue visitMULHU(SDNode *N);
452 SDValue visitMULHS(SDNode *N);
453 SDValue visitAVG(SDNode *N);
454 SDValue visitABD(SDNode *N);
455 SDValue visitSMUL_LOHI(SDNode *N);
456 SDValue visitUMUL_LOHI(SDNode *N);
457 SDValue visitMULO(SDNode *N);
458 SDValue visitIMINMAX(SDNode *N);
459 SDValue visitAND(SDNode *N);
460 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
461 SDValue visitOR(SDNode *N);
462 SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
463 SDValue visitXOR(SDNode *N);
464 SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
465 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
466 SDValue visitSHL(SDNode *N);
467 SDValue visitSRA(SDNode *N);
468 SDValue visitSRL(SDNode *N);
469 SDValue visitFunnelShift(SDNode *N);
470 SDValue visitSHLSAT(SDNode *N);
471 SDValue visitRotate(SDNode *N);
472 SDValue visitABS(SDNode *N);
473 SDValue visitBSWAP(SDNode *N);
474 SDValue visitBITREVERSE(SDNode *N);
475 SDValue visitCTLZ(SDNode *N);
476 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
477 SDValue visitCTTZ(SDNode *N);
478 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
479 SDValue visitCTPOP(SDNode *N);
480 SDValue visitSELECT(SDNode *N);
481 SDValue visitVSELECT(SDNode *N);
482 SDValue visitVP_SELECT(SDNode *N);
483 SDValue visitSELECT_CC(SDNode *N);
484 SDValue visitSETCC(SDNode *N);
485 SDValue visitSETCCCARRY(SDNode *N);
486 SDValue visitSIGN_EXTEND(SDNode *N);
487 SDValue visitZERO_EXTEND(SDNode *N);
488 SDValue visitANY_EXTEND(SDNode *N);
489 SDValue visitAssertExt(SDNode *N);
490 SDValue visitAssertAlign(SDNode *N);
491 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
492 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
493 SDValue visitTRUNCATE(SDNode *N);
494 SDValue visitTRUNCATE_USAT_U(SDNode *N);
495 SDValue visitBITCAST(SDNode *N);
496 SDValue visitFREEZE(SDNode *N);
497 SDValue visitBUILD_PAIR(SDNode *N);
498 SDValue visitFADD(SDNode *N);
499 SDValue visitVP_FADD(SDNode *N);
500 SDValue visitVP_FSUB(SDNode *N);
501 SDValue visitSTRICT_FADD(SDNode *N);
502 SDValue visitFSUB(SDNode *N);
503 SDValue visitFMUL(SDNode *N);
504 template <class MatchContextClass> SDValue visitFMA(SDNode *N);
505 SDValue visitFMAD(SDNode *N);
506 SDValue visitFDIV(SDNode *N);
507 SDValue visitFREM(SDNode *N);
508 SDValue visitFSQRT(SDNode *N);
509 SDValue visitFCOPYSIGN(SDNode *N);
510 SDValue visitFPOW(SDNode *N);
511 SDValue visitFCANONICALIZE(SDNode *N);
512 SDValue visitSINT_TO_FP(SDNode *N);
513 SDValue visitUINT_TO_FP(SDNode *N);
514 SDValue visitFP_TO_SINT(SDNode *N);
515 SDValue visitFP_TO_UINT(SDNode *N);
516 SDValue visitXROUND(SDNode *N);
517 SDValue visitFP_ROUND(SDNode *N);
518 SDValue visitFP_EXTEND(SDNode *N);
519 SDValue visitFNEG(SDNode *N);
520 SDValue visitFABS(SDNode *N);
521 SDValue visitFCEIL(SDNode *N);
522 SDValue visitFTRUNC(SDNode *N);
523 SDValue visitFFREXP(SDNode *N);
524 SDValue visitFFLOOR(SDNode *N);
525 SDValue visitFMinMax(SDNode *N);
526 SDValue visitBRCOND(SDNode *N);
527 SDValue visitBR_CC(SDNode *N);
528 SDValue visitLOAD(SDNode *N);
529
530 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
531 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
532 SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);
533
534 bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
535
536 SDValue visitSTORE(SDNode *N);
537 SDValue visitATOMIC_STORE(SDNode *N);
538 SDValue visitLIFETIME_END(SDNode *N);
539 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
540 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
541 SDValue visitBUILD_VECTOR(SDNode *N);
542 SDValue visitCONCAT_VECTORS(SDNode *N);
543 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
544 SDValue visitVECTOR_SHUFFLE(SDNode *N);
545 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
546 SDValue visitINSERT_SUBVECTOR(SDNode *N);
547 SDValue visitVECTOR_COMPRESS(SDNode *N);
548 SDValue visitMLOAD(SDNode *N);
549 SDValue visitMSTORE(SDNode *N);
550 SDValue visitMGATHER(SDNode *N);
551 SDValue visitMSCATTER(SDNode *N);
552 SDValue visitMHISTOGRAM(SDNode *N);
553 SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
554 SDValue visitVPGATHER(SDNode *N);
555 SDValue visitVPSCATTER(SDNode *N);
556 SDValue visitVP_STRIDED_LOAD(SDNode *N);
557 SDValue visitVP_STRIDED_STORE(SDNode *N);
558 SDValue visitFP_TO_FP16(SDNode *N);
559 SDValue visitFP16_TO_FP(SDNode *N);
560 SDValue visitFP_TO_BF16(SDNode *N);
561 SDValue visitBF16_TO_FP(SDNode *N);
562 SDValue visitVECREDUCE(SDNode *N);
563 SDValue visitVPOp(SDNode *N);
564 SDValue visitGET_FPENV_MEM(SDNode *N);
565 SDValue visitSET_FPENV_MEM(SDNode *N);
566
567 template <class MatchContextClass>
568 SDValue visitFADDForFMACombine(SDNode *N);
569 template <class MatchContextClass>
570 SDValue visitFSUBForFMACombine(SDNode *N);
571 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
572
573 SDValue XformToShuffleWithZero(SDNode *N);
574 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
575 const SDLoc &DL,
576 SDNode *N,
577 SDValue N0,
578 SDValue N1);
579 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
580 SDValue N1, SDNodeFlags Flags);
581 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
582 SDValue N1, SDNodeFlags Flags);
583 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
584 EVT VT, SDValue N0, SDValue N1,
585 SDNodeFlags Flags = SDNodeFlags());
586
587 SDValue visitShiftByConstant(SDNode *N);
588
589 SDValue foldSelectOfConstants(SDNode *N);
590 SDValue foldVSelectOfConstants(SDNode *N);
591 SDValue foldBinOpIntoSelect(SDNode *BO);
592 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
593 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
594 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
595 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
596 SDValue N2, SDValue N3, ISD::CondCode CC,
597 bool NotExtCompare = false);
598 SDValue convertSelectOfFPConstantsToLoadOffset(
599 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
600 ISD::CondCode CC);
601 SDValue foldSignChangeInBitcast(SDNode *N);
602 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
603 SDValue N2, SDValue N3, ISD::CondCode CC);
604 SDValue foldSelectOfBinops(SDNode *N);
605 SDValue foldSextSetcc(SDNode *N);
606 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
607 const SDLoc &DL);
608 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
609 SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
610 SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
611 SDValue False, ISD::CondCode CC, const SDLoc &DL);
612 SDValue unfoldMaskedMerge(SDNode *N);
613 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
614 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
615 const SDLoc &DL, bool foldBooleans);
616 SDValue rebuildSetCC(SDValue N);
617
618 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
619 SDValue &CC, bool MatchStrict = false) const;
620 bool isOneUseSetCC(SDValue N) const;
621
622 SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
623 SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);
624
625 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
626 unsigned HiOp);
627 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
628 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
629 const TargetLowering &TLI);
630 SDValue foldPartialReduceMLAMulOp(SDNode *N);
631 SDValue foldPartialReduceAdd(SDNode *N);
632
633 SDValue CombineExtLoad(SDNode *N);
634 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
635 SDValue combineRepeatedFPDivisors(SDNode *N);
636 SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
637 SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf);
638 SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
639 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
640 SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
641 SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
642 SDValue BuildSDIV(SDNode *N);
643 SDValue BuildSDIVPow2(SDNode *N);
644 SDValue BuildUDIV(SDNode *N);
645 SDValue BuildSREMPow2(SDNode *N);
646 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
647 SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
648 bool KnownNeverZero = false,
649 bool InexpensiveOnly = false,
650 std::optional<EVT> OutVT = std::nullopt);
651 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
652 SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
653 SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
654 SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
655 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
656 SDNodeFlags Flags, bool Reciprocal);
657 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
658 SDNodeFlags Flags, bool Reciprocal);
659 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
660 bool DemandHighBits = true);
661 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
662 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
663 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
664 bool HasPos, unsigned PosOpcode,
665 unsigned NegOpcode, const SDLoc &DL);
666 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
667 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
668 bool HasPos, unsigned PosOpcode,
669 unsigned NegOpcode, const SDLoc &DL);
670 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
671 bool FromAdd);
672 SDValue MatchLoadCombine(SDNode *N);
673 SDValue mergeTruncStores(StoreSDNode *N);
674 SDValue reduceLoadWidth(SDNode *N);
675 SDValue ReduceLoadOpStoreWidth(SDNode *N);
676 SDValue splitMergedValStore(StoreSDNode *ST);
677 SDValue TransformFPLoadStorePair(SDNode *N);
678 SDValue convertBuildVecZextToZext(SDNode *N);
679 SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
680 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
681 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
682 SDValue reduceBuildVecToShuffle(SDNode *N);
683 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
684 ArrayRef<int> VectorMask, SDValue VecIn1,
685 SDValue VecIn2, unsigned LeftIdx,
686 bool DidSplitVec);
687 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
688
689 /// Walk up chain skipping non-aliasing memory nodes,
690 /// looking for aliasing nodes and adding them to the Aliases vector.
691 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
692 SmallVectorImpl<SDValue> &Aliases);
693
694 /// Return true if there is any possibility that the two addresses overlap.
695 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
696
697 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node).
699 SDValue FindBetterChain(SDNode *N, SDValue Chain);
700
701 /// Try to replace a store and any possibly adjacent stores on
702 /// consecutive chains with better chains. Return true only if St is
703 /// replaced.
704 ///
705 /// Notice that other chains may still be replaced even if the function
706 /// returns false.
707 bool findBetterNeighborChains(StoreSDNode *St);
708
    // Helper for findBetterNeighborChains. Walk up the store chain and add
    // additional chained stores that do not overlap and can be parallelized.
711 bool parallelizeChainedStores(StoreSDNode *St);
712
713 /// Holds a pointer to an LSBaseSDNode as well as information on where it
714 /// is located in a sequence of memory operations connected by a chain.
715 struct MemOpLink {
716 // Ptr to the mem node.
717 LSBaseSDNode *MemNode;
718
719 // Offset from the base ptr.
720 int64_t OffsetFromBase;
721
722 MemOpLink(LSBaseSDNode *N, int64_t Offset)
723 : MemNode(N), OffsetFromBase(Offset) {}
724 };
725
726 // Classify the origin of a stored value.
727 enum class StoreSource { Unknown, Constant, Extract, Load };
728 StoreSource getStoreSource(SDValue StoreVal) {
729 switch (StoreVal.getOpcode()) {
730 case ISD::Constant:
731 case ISD::ConstantFP:
732 return StoreSource::Constant;
733 case ISD::BUILD_VECTOR:
734 if (ISD::isBuildVectorOfConstantSDNodes(N: StoreVal.getNode()) ||
735 ISD::isBuildVectorOfConstantFPSDNodes(N: StoreVal.getNode()))
736 return StoreSource::Constant;
737 return StoreSource::Unknown;
738 case ISD::EXTRACT_VECTOR_ELT:
739 case ISD::EXTRACT_SUBVECTOR:
740 return StoreSource::Extract;
741 case ISD::LOAD:
742 return StoreSource::Load;
743 default:
744 return StoreSource::Unknown;
745 }
746 }
747
748 /// This is a helper function for visitMUL to check the profitability
749 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
750 /// MulNode is the original multiply, AddNode is (add x, c1),
751 /// and ConstNode is c2.
752 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
753 SDValue ConstNode);
754
755 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
756 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
757 /// the type of the loaded value to be extended.
758 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
759 EVT LoadResultTy, EVT &ExtVT);
760
761 /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to MemVT.
763 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
764 EVT &MemVT, unsigned ShAmt = 0);
765
766 /// Used by BackwardsPropagateMask to find suitable loads.
767 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
768 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
769 ConstantSDNode *Mask, SDNode *&NodeToMask);
770 /// Attempt to propagate a given AND node back to load leaves so that they
771 /// can be combined into narrow loads.
772 bool BackwardsPropagateMask(SDNode *N);
773
774 /// Helper function for mergeConsecutiveStores which merges the component
775 /// store chains.
776 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
777 unsigned NumStores);
778
779 /// Helper function for mergeConsecutiveStores which checks if all the store
780 /// nodes have the same underlying object. We can still reuse the first
781 /// store's pointer info if all the stores are from the same object.
782 bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);
783
784 /// This is a helper function for mergeConsecutiveStores. When the source
785 /// elements of the consecutive stores are all constants or all extracted
786 /// vector elements, try to merge them into one larger store introducing
787 /// bitcasts if necessary. \return True if a merged store was created.
788 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
789 EVT MemVT, unsigned NumStores,
790 bool IsConstantSrc, bool UseVector,
791 bool UseTrunc);
792
    /// This is a helper function for mergeConsecutiveStores. Stores that may
    /// be merged with St are placed in StoreNodes. On success, returns a chain
    /// predecessor to all store candidates.
796 SDNode *getStoreMergeCandidates(StoreSDNode *St,
797 SmallVectorImpl<MemOpLink> &StoreNodes);
798
799 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
800 /// have indirect dependency through their operands. RootNode is the
801 /// predecessor to all stores calculated by getStoreMergeCandidates and is
802 /// used to prune the dependency check. \return True if safe to merge.
803 bool checkMergeStoreCandidatesForDependencies(
804 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
805 SDNode *RootNode);
806
807 /// Helper function for tryStoreMergeOfLoads. Checks if the load/store
808 /// chain has a call in it. \return True if a call is found.
809 bool hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld);
810
811 /// This is a helper function for mergeConsecutiveStores. Given a list of
812 /// store candidates, find the first N that are consecutive in memory.
813 /// Returns 0 if there are not at least 2 consecutive stores to try merging.
814 unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
815 int64_t ElementSizeBytes) const;
816
817 /// This is a helper function for mergeConsecutiveStores. It is used for
818 /// store chains that are composed entirely of constant values.
819 bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
820 unsigned NumConsecutiveStores,
821 EVT MemVT, SDNode *Root, bool AllowVectors);
822
823 /// This is a helper function for mergeConsecutiveStores. It is used for
824 /// store chains that are composed entirely of extracted vector elements.
825 /// When extracting multiple vector elements, try to store them in one
826 /// vector store rather than a sequence of scalar stores.
827 bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
828 unsigned NumConsecutiveStores, EVT MemVT,
829 SDNode *Root);
830
831 /// This is a helper function for mergeConsecutiveStores. It is used for
832 /// store chains that are composed entirely of loaded values.
833 bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
834 unsigned NumConsecutiveStores, EVT MemVT,
835 SDNode *Root, bool AllowVectors,
836 bool IsNonTemporalStore, bool IsNonTemporalLoad);
837
838 /// Merge consecutive store operations into a wide store.
839 /// This optimization uses wide integers or vectors when possible.
840 /// \return true if stores were merged.
841 bool mergeConsecutiveStores(StoreSDNode *St);
842
843 /// Try to transform a truncation where C is a constant:
844 /// (trunc (and X, C)) -> (and (trunc X), (trunc C))
845 ///
846 /// \p N needs to be a truncation and its first operand an AND. Other
847 /// requirements are checked by the function (e.g. that trunc is
    /// single-use); if they are not met, an empty SDValue is returned.
849 SDValue distributeTruncateThroughAnd(SDNode *N);
850
    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
855 bool hasOperation(unsigned Opcode, EVT VT) {
856 return TLI.isOperationLegalOrCustom(Op: Opcode, VT, LegalOnly: LegalOperations);
857 }
858
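    /// Return true if ISD::UMIN is legal for \p VT, either directly or on the
    /// integer type that \p VT promotes to.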
859 bool hasUMin(EVT VT) const {
860 auto LK = TLI.getTypeConversion(Context&: *DAG.getContext(), VT);
861 return (LK.first == TargetLoweringBase::TypeLegal ||
862 LK.first == TargetLoweringBase::TypePromoteInteger) &&
863 TLI.isOperationLegal(Op: ISD::UMIN, VT: LK.second);
864 }
865
866 public:
867 /// Runs the dag combiner on all nodes in the work list
868 void Run(CombineLevel AtLevel);
869
870 SelectionDAG &getDAG() const { return DAG; }
871
872 /// Convenience wrapper around TargetLowering::getShiftAmountTy.
873 EVT getShiftAmountTy(EVT LHSTy) {
874 return TLI.getShiftAmountTy(LHSTy, DL: DAG.getDataLayout());
875 }
876
877 /// This method returns true if we are running before type legalization or
878 /// if the specified VT is legal.
879 bool isTypeLegal(const EVT &VT) {
880 if (!LegalTypes) return true;
881 return TLI.isTypeLegal(VT);
882 }
883
884 /// Convenience wrapper around TargetLowering::getSetCCResultType
885 EVT getSetCCResultType(EVT VT) const {
886 return TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT);
887 }
888
889 void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
890 SDValue OrigLoad, SDValue ExtLoad,
891 ISD::NodeType ExtType);
892 };
893
894/// This class is a DAGUpdateListener that removes any deleted
895/// nodes from the worklist.
896class WorklistRemover : public SelectionDAG::DAGUpdateListener {
897 DAGCombiner &DC;
898
899public:
900 explicit WorklistRemover(DAGCombiner &dc)
901 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
902
903 void NodeDeleted(SDNode *N, SDNode *E) override {
904 DC.removeFromWorklist(N);
905 }
906};
907
908class WorklistInserter : public SelectionDAG::DAGUpdateListener {
909 DAGCombiner &DC;
910
911public:
912 explicit WorklistInserter(DAGCombiner &dc)
913 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
914
915 // FIXME: Ideally we could add N to the worklist, but this causes exponential
916 // compile time costs in large DAGs, e.g. Halide.
917 void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
918};
919
920} // end anonymous namespace
921
922//===----------------------------------------------------------------------===//
923// TargetLowering::DAGCombinerInfo implementation
924//===----------------------------------------------------------------------===//
925
926void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
927 ((DAGCombiner*)DC)->AddToWorklist(N);
928}
929
930SDValue TargetLowering::DAGCombinerInfo::
931CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
932 return ((DAGCombiner*)DC)->CombineTo(N, To: &To[0], NumTo: To.size(), AddTo);
933}
934
935SDValue TargetLowering::DAGCombinerInfo::
936CombineTo(SDNode *N, SDValue Res, bool AddTo) {
937 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
938}
939
940SDValue TargetLowering::DAGCombinerInfo::
941CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
942 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
943}
944
945bool TargetLowering::DAGCombinerInfo::
946recursivelyDeleteUnusedNodes(SDNode *N) {
947 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
948}
949
950void TargetLowering::DAGCombinerInfo::
951CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
952 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
953}
954
955//===----------------------------------------------------------------------===//
956// Helper Functions
957//===----------------------------------------------------------------------===//
958
959void DAGCombiner::deleteAndRecombine(SDNode *N) {
960 removeFromWorklist(N);
961
962 // If the operands of this node are only used by the node, they will now be
963 // dead. Make sure to re-visit them and recursively delete dead nodes.
964 for (const SDValue &Op : N->ops())
965 // For an operand generating multiple values, one of the values may
966 // become dead allowing further simplification (e.g. split index
967 // arithmetic from an indexed load).
968 if (Op->hasOneUse() || Op->getNumValues() > 1)
969 AddToWorklist(N: Op.getNode());
970
971 DAG.DeleteNode(N);
972}
973
// APInts must be the same size for most operations; this helper
// function zero-extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bit widths that won't overflow.
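// For example, an 8-bit LHS and a 16-bit RHS with Offset == 4 are both
// zero-extended to 20 bits.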
977static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
978 unsigned Bits = Offset + std::max(a: LHS.getBitWidth(), b: RHS.getBitWidth());
979 LHS = LHS.zext(width: Bits);
980 RHS = RHS.zext(width: Bits);
981}
982
983// Return true if this node is a setcc, or is a select_cc
984// that selects between the target values used for true and false, making it
985// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
986// the appropriate nodes based on the type of node we are checking. This
987// simplifies life a bit for the callers.
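// Note on operand layout: strict FP compares carry a leading chain operand, so
// their LHS/RHS/CC are at operand indices 1-3, while SELECT_CC keeps its
// condition code as the fifth operand (index 4).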
988bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
989 SDValue &CC, bool MatchStrict) const {
990 if (N.getOpcode() == ISD::SETCC) {
991 LHS = N.getOperand(i: 0);
992 RHS = N.getOperand(i: 1);
993 CC = N.getOperand(i: 2);
994 return true;
995 }
996
997 if (MatchStrict &&
998 (N.getOpcode() == ISD::STRICT_FSETCC ||
999 N.getOpcode() == ISD::STRICT_FSETCCS)) {
1000 LHS = N.getOperand(i: 1);
1001 RHS = N.getOperand(i: 2);
1002 CC = N.getOperand(i: 3);
1003 return true;
1004 }
1005
1006 if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N: N.getOperand(i: 2)) ||
1007 !TLI.isConstFalseVal(N: N.getOperand(i: 3)))
1008 return false;
1009
1010 if (TLI.getBooleanContents(Type: N.getValueType()) ==
1011 TargetLowering::UndefinedBooleanContent)
1012 return false;
1013
1014 LHS = N.getOperand(i: 0);
1015 RHS = N.getOperand(i: 1);
1016 CC = N.getOperand(i: 4);
1017 return true;
1018}
1019
1020/// Return true if this is a SetCC-equivalent operation with only one use.
1021/// If this is true, it allows the users to invert the operation for free when
1022/// it is profitable to do so.
1023bool DAGCombiner::isOneUseSetCC(SDValue N) const {
1024 SDValue N0, N1, N2;
1025 if (isSetCCEquivalent(N, LHS&: N0, RHS&: N1, CC&: N2) && N->hasOneUse())
1026 return true;
1027 return false;
1028}
1029
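// Returns true if N is a constant splat of the all-ones mask for the integer
// type ScalarTy, e.g. a splat of 0xFFFF when ScalarTy is i16. Only i8, i16 and
// i32 masks are recognized.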
1030static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
1031 if (!ScalarTy.isSimple())
1032 return false;
1033
1034 uint64_t MaskForTy = 0ULL;
1035 switch (ScalarTy.getSimpleVT().SimpleTy) {
1036 case MVT::i8:
1037 MaskForTy = 0xFFULL;
1038 break;
1039 case MVT::i16:
1040 MaskForTy = 0xFFFFULL;
1041 break;
1042 case MVT::i32:
1043 MaskForTy = 0xFFFFFFFFULL;
1044 break;
  default:
    return false;
1048 }
1049
1050 APInt Val;
1051 if (ISD::isConstantSplatVector(N, SplatValue&: Val))
1052 return Val.getLimitedValue() == MaskForTy;
1053
1054 return false;
1055}
1056
1057// Determines if it is a constant integer or a splat/build vector of constant
1058// integers (and undefs).
1059// Do not permit build vector implicit truncation.
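// For example, (build_vector 0, 1, undef, 3) qualifies as long as each
// constant already has the element's bit width.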
1060static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
1061 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N))
1062 return !(Const->isOpaque() && NoOpaques);
1063 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
1064 return false;
1065 unsigned BitWidth = N.getScalarValueSizeInBits();
1066 for (const SDValue &Op : N->op_values()) {
1067 if (Op.isUndef())
1068 continue;
1069 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val: Op);
1070 if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
1071 (Const->isOpaque() && NoOpaques))
1072 return false;
1073 }
1074 return true;
1075}
1076
// Determines if a BUILD_VECTOR is composed of all constants, possibly mixed
// with undefs.
1079static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1080 if (V.getOpcode() != ISD::BUILD_VECTOR)
1081 return false;
1082 return isConstantOrConstantVector(N: V, NoOpaques) ||
1083 ISD::isBuildVectorOfConstantFPSDNodes(N: V.getNode());
1084}
1085
// Determine if this is an indexed load with an opaque target constant index.
1087static bool canSplitIdx(LoadSDNode *LD) {
1088 return MaySplitLoadIndex &&
1089 (LD->getOperand(Num: 2).getOpcode() != ISD::TargetConstant ||
1090 !cast<ConstantSDNode>(Val: LD->getOperand(Num: 2))->isOpaque());
1091}
1092
1093bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1094 const SDLoc &DL,
1095 SDNode *N,
1096 SDValue N0,
1097 SDValue N1) {
1098 // Currently this only tries to ensure we don't undo the GEP splits done by
1099 // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if either of the following transformations would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)),
  // or
  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).
1106
1107 if (!N0.isAnyAdd())
1108 return false;
1109
1110 // Check for vscale addressing modes.
1111 // (load/store (add/sub (add x, y), vscale))
1112 // (load/store (add/sub (add x, y), (lsl vscale, C)))
1113 // (load/store (add/sub (add x, y), (mul vscale, C)))
1114 if ((N1.getOpcode() == ISD::VSCALE ||
1115 ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
1116 N1.getOperand(i: 0).getOpcode() == ISD::VSCALE &&
1117 isa<ConstantSDNode>(Val: N1.getOperand(i: 1)))) &&
1118 N1.getValueType().getFixedSizeInBits() <= 64) {
1119 int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
1120 ? N1.getConstantOperandVal(i: 0)
1121 : (N1.getOperand(i: 0).getConstantOperandVal(i: 0) *
1122 (N1.getOpcode() == ISD::SHL
1123 ? (1LL << N1.getConstantOperandVal(i: 1))
1124 : N1.getConstantOperandVal(i: 1)));
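    // For example, (shl (vscale 2), 3) above yields ScalableOffset == 16; the
    // sign is flipped below when the outer operation is a subtraction.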
1125 if (Opc == ISD::SUB)
1126 ScalableOffset = -ScalableOffset;
1127 if (all_of(Range: N->users(), P: [&](SDNode *Node) {
1128 if (auto *LoadStore = dyn_cast<MemSDNode>(Val: Node);
1129 LoadStore && LoadStore->getBasePtr().getNode() == N) {
1130 TargetLoweringBase::AddrMode AM;
1131 AM.HasBaseReg = true;
1132 AM.ScalableOffset = ScalableOffset;
1133 EVT VT = LoadStore->getMemoryVT();
1134 unsigned AS = LoadStore->getAddressSpace();
1135 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1136 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy,
1137 AddrSpace: AS);
1138 }
1139 return false;
1140 }))
1141 return true;
1142 }
1143
1144 if (Opc != ISD::ADD && Opc != ISD::PTRADD)
1145 return false;
1146
1147 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N1);
1148 if (!C2)
1149 return false;
1150
1151 const APInt &C2APIntVal = C2->getAPIntValue();
1152 if (C2APIntVal.getSignificantBits() > 64)
1153 return false;
1154
1155 if (auto *C1 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
1156 if (N0.hasOneUse())
1157 return false;
1158
1159 const APInt &C1APIntVal = C1->getAPIntValue();
1160 const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1161 if (CombinedValueIntVal.getSignificantBits() > 64)
1162 return false;
1163 const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1164
1165 for (SDNode *Node : N->users()) {
1166 if (auto *LoadStore = dyn_cast<MemSDNode>(Val: Node)) {
1167 // Is x[offset2] already not a legal addressing mode? If so then
1168 // reassociating the constants breaks nothing (we test offset2 because
1169 // that's the one we hope to fold into the load or store).
1170 TargetLoweringBase::AddrMode AM;
1171 AM.HasBaseReg = true;
1172 AM.BaseOffs = C2APIntVal.getSExtValue();
1173 EVT VT = LoadStore->getMemoryVT();
1174 unsigned AS = LoadStore->getAddressSpace();
1175 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1176 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1177 continue;
1178
1179 // Would x[offset1+offset2] still be a legal addressing mode?
1180 AM.BaseOffs = CombinedValue;
1181 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1182 return true;
1183 }
1184 }
1185 } else {
1186 if (auto *GA = dyn_cast<GlobalAddressSDNode>(Val: N0.getOperand(i: 1)))
1187 if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1188 return false;
1189
1190 for (SDNode *Node : N->users()) {
1191 auto *LoadStore = dyn_cast<MemSDNode>(Val: Node);
1192 if (!LoadStore)
1193 return false;
1194
      // Is x[offset2] a legal addressing mode? If so, then reassociating the
      // constants breaks the address pattern.
1197 TargetLoweringBase::AddrMode AM;
1198 AM.HasBaseReg = true;
1199 AM.BaseOffs = C2APIntVal.getSExtValue();
1200 EVT VT = LoadStore->getMemoryVT();
1201 unsigned AS = LoadStore->getAddressSpace();
1202 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1203 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1204 return false;
1205 }
1206 return true;
1207 }
1208
1209 return false;
1210}
1211
1212/// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
1213/// \p N0 is the same kind of operation as \p Opc.
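/// This handles constant folding of (Opc (Opc x, c1), c2), hoisting a constant
/// out of (Opc (Opc x, c1), y), repeated-operand simplifications for
/// AND/OR/XOR, reuse of already-existing nodes, and regrouping of setcc
/// operands under AND/OR so compares with the same predicate become adjacent.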
1214SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1215 SDValue N0, SDValue N1,
1216 SDNodeFlags Flags) {
1217 EVT VT = N0.getValueType();
1218
1219 if (N0.getOpcode() != Opc)
1220 return SDValue();
1221
1222 SDValue N00 = N0.getOperand(i: 0);
1223 SDValue N01 = N0.getOperand(i: 1);
1224
1225 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N01)) {
1226 SDNodeFlags NewFlags;
1227 if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1228 Flags.hasNoUnsignedWrap())
1229 NewFlags |= SDNodeFlags::NoUnsignedWrap;
1230
1231 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N1)) {
1232 // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1233 if (SDValue OpNode = DAG.FoldConstantArithmetic(Opcode: Opc, DL, VT, Ops: {N01, N1})) {
1234 NewFlags.setDisjoint(Flags.hasDisjoint() &&
1235 N0->getFlags().hasDisjoint());
1236 return DAG.getNode(Opcode: Opc, DL, VT, N1: N00, N2: OpNode, Flags: NewFlags);
1237 }
1238 return SDValue();
1239 }
1240 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1241 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1242 // iff (op x, c1) has one use
1243 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N00, N2: N1, Flags: NewFlags);
1244 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N01, Flags: NewFlags);
1245 }
1246 }
1247
1248 // Check for repeated operand logic simplifications.
1249 if (Opc == ISD::AND || Opc == ISD::OR) {
1250 // (N00 & N01) & N00 --> N00 & N01
1251 // (N00 & N01) & N01 --> N00 & N01
1252 // (N00 | N01) | N00 --> N00 | N01
1253 // (N00 | N01) | N01 --> N00 | N01
1254 if (N1 == N00 || N1 == N01)
1255 return N0;
1256 }
1257 if (Opc == ISD::XOR) {
1258 // (N00 ^ N01) ^ N00 --> N01
1259 if (N1 == N00)
1260 return N01;
1261 // (N00 ^ N01) ^ N01 --> N00
1262 if (N1 == N01)
1263 return N00;
1264 }
1265
1266 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1267 if (N1 != N01) {
1268 // Reassociate if (op N00, N1) already exist
1269 if (SDNode *NE = DAG.getNodeIfExists(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {N00, N1})) {
        // If (Op (Op N00, N1), N01) already exists,
        // we need to stop reassociating to avoid an infinite loop.
1272 if (!DAG.doesNodeExist(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {SDValue(NE, 0), N01}))
1273 return DAG.getNode(Opcode: Opc, DL, VT, N1: SDValue(NE, 0), N2: N01);
1274 }
1275 }
1276
1277 if (N1 != N00) {
1278 // Reassociate if (op N01, N1) already exist
1279 if (SDNode *NE = DAG.getNodeIfExists(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {N01, N1})) {
        // If (Op (Op N01, N1), N00) already exists,
        // we need to stop reassociating to avoid an infinite loop.
1282 if (!DAG.doesNodeExist(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {SDValue(NE, 0), N00}))
1283 return DAG.getNode(Opcode: Opc, DL, VT, N1: SDValue(NE, 0), N2: N00);
1284 }
1285 }
1286
    // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to (OR/AND
    // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
    // predicate, or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
    // comparisons with the same predicate. This enables optimizations such as
    // the following:
1292 // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1293 // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1294 if (Opc == ISD::AND || Opc == ISD::OR) {
1295 if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1296 N01->getOpcode() == ISD::SETCC) {
1297 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val: N1.getOperand(i: 2))->get();
1298 ISD::CondCode CC00 = cast<CondCodeSDNode>(Val: N00.getOperand(i: 2))->get();
1299 ISD::CondCode CC01 = cast<CondCodeSDNode>(Val: N01.getOperand(i: 2))->get();
1300 if (CC1 == CC00 && CC1 != CC01) {
1301 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N00, N2: N1, Flags);
1302 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N01, Flags);
1303 }
1304 if (CC1 == CC01 && CC1 != CC00) {
1305 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N01, N2: N1, Flags);
1306 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N00, Flags);
1307 }
1308 }
1309 }
1310 }
1311
1312 return SDValue();
1313}
1314
1315/// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1316/// same kind of operation as \p Opc.
1317SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1318 SDValue N1, SDNodeFlags Flags) {
1319 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1320
1321 // Floating-point reassociation is not allowed without loose FP math.
1322 if (N0.getValueType().isFloatingPoint() ||
1323 N1.getValueType().isFloatingPoint())
1324 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1325 return SDValue();
1326
1327 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1328 return Combined;
1329 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0: N1, N1: N0, Flags))
1330 return Combined;
1331 return SDValue();
1332}
1333
1334// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1335// Note that we only expect Flags to be passed from FP operations. For integer
1336// operations they need to be dropped.
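// For the FADD/FMUL case, the second reassociation below also requires the
// reassoc fast-math flag on both operations and both reductions.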
1337SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1338 const SDLoc &DL, EVT VT, SDValue N0,
1339 SDValue N1, SDNodeFlags Flags) {
1340 if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1341 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType() &&
1342 N0->hasOneUse() && N1->hasOneUse() &&
1343 TLI.isOperationLegalOrCustom(Op: Opc, VT: N0.getOperand(i: 0).getValueType()) &&
1344 TLI.shouldReassociateReduction(RedOpc, VT: N0.getOperand(i: 0).getValueType())) {
1345 SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1346 return DAG.getNode(Opcode: RedOpc, DL, VT,
1347 Operand: DAG.getNode(Opcode: Opc, DL, VT: N0.getOperand(i: 0).getValueType(),
1348 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0)));
1349 }
1350
1351 // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
1352 // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
1353 // single node.
1354 SDValue A, B, C, D, RedA, RedB;
1355 if (sd_match(N: N0, P: m_OneUse(P: m_c_BinOp(
1356 Opc,
1357 L: m_AllOf(preds: m_OneUse(P: m_UnaryOp(Opc: RedOpc, Op: m_Value(N&: A))),
1358 preds: m_Value(N&: RedA)),
1359 R: m_Value(N&: B)))) &&
1360 sd_match(N: N1, P: m_OneUse(P: m_c_BinOp(
1361 Opc,
1362 L: m_AllOf(preds: m_OneUse(P: m_UnaryOp(Opc: RedOpc, Op: m_Value(N&: C))),
1363 preds: m_Value(N&: RedB)),
1364 R: m_Value(N&: D)))) &&
1365 !sd_match(N: B, P: m_UnaryOp(Opc: RedOpc, Op: m_Value())) &&
1366 !sd_match(N: D, P: m_UnaryOp(Opc: RedOpc, Op: m_Value())) &&
1367 A.getValueType() == C.getValueType() &&
1368 hasOperation(Opcode: Opc, VT: A.getValueType()) &&
1369 TLI.shouldReassociateReduction(RedOpc, VT)) {
1370 if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
1371 (!N0->getFlags().hasAllowReassociation() ||
1372 !N1->getFlags().hasAllowReassociation() ||
1373 !RedA->getFlags().hasAllowReassociation() ||
1374 !RedB->getFlags().hasAllowReassociation()))
1375 return SDValue();
1376 SelectionDAG::FlagInserter FlagsInserter(
1377 DAG, Flags & N0->getFlags() & N1->getFlags() & RedA->getFlags() &
1378 RedB->getFlags());
1379 SDValue Op = DAG.getNode(Opcode: Opc, DL, VT: A.getValueType(), N1: A, N2: C);
1380 SDValue Red = DAG.getNode(Opcode: RedOpc, DL, VT, Operand: Op);
1381 SDValue Op2 = DAG.getNode(Opcode: Opc, DL, VT, N1: B, N2: D);
1382 return DAG.getNode(Opcode: Opc, DL, VT, N1: Red, N2: Op2);
1383 }
1384 return SDValue();
1385}
1386
1387SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1388 bool AddTo) {
1389 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1390 ++NodesCombined;
1391 LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1392 To[0].dump(&DAG);
1393 dbgs() << " and " << NumTo - 1 << " other values\n");
1394 for (unsigned i = 0, e = NumTo; i != e; ++i)
1395 assert((!To[i].getNode() ||
1396 N->getValueType(i) == To[i].getValueType()) &&
1397 "Cannot combine value to value of different type!");
1398
1399 WorklistRemover DeadNodes(*this);
1400 DAG.ReplaceAllUsesWith(From: N, To);
1401 if (AddTo) {
1402 // Push the new nodes and any users onto the worklist
1403 for (unsigned i = 0, e = NumTo; i != e; ++i) {
1404 if (To[i].getNode())
1405 AddToWorklistWithUsers(N: To[i].getNode());
1406 }
1407 }
1408
1409 // Finally, if the node is now dead, remove it from the graph. The node
1410 // may not be dead if the replacement process recursively simplified to
1411 // something else needing this node.
1412 if (N->use_empty())
1413 deleteAndRecombine(N);
1414 return SDValue(N, 0);
1415}
1416
1417void DAGCombiner::
1418CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1419 // Replace the old value with the new one.
1420 ++NodesCombined;
1421 LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1422 dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1423
1424 // Replace all uses.
1425 DAG.ReplaceAllUsesOfValueWith(From: TLO.Old, To: TLO.New);
1426
1427 // Push the new node and any (possibly new) users onto the worklist.
1428 AddToWorklistWithUsers(N: TLO.New.getNode());
1429
1430 // Finally, if the node is now dead, remove it from the graph.
1431 recursivelyDeleteUnusedNodes(N: TLO.Old.getNode());
1432}
1433
1434/// Check the specified integer node value to see if it can be simplified or if
1435/// things it uses can be simplified by bit propagation. If so, return true.
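/// For example (illustrative): if a user only demands the low 8 bits of
/// (and X, 0xFF), the mask contributes nothing to the demanded bits and the
/// 'and' can be replaced by X.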
1436bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1437 const APInt &DemandedElts,
1438 bool AssumeSingleUse) {
1439 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1440 KnownBits Known;
1441 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth: 0,
1442 AssumeSingleUse))
1443 return false;
1444
1445 // Revisit the node.
1446 AddToWorklist(N: Op.getNode());
1447
1448 CommitTargetLoweringOpt(TLO);
1449 return true;
1450}
1451
1452/// Check the specified vector node value to see if it can be simplified or
1453/// if things it uses can be simplified as it only uses some of the elements.
1454/// If so, return true.
1455bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1456 const APInt &DemandedElts,
1457 bool AssumeSingleUse) {
1458 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1459 APInt KnownUndef, KnownZero;
1460 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedEltMask: DemandedElts, KnownUndef, KnownZero,
1461 TLO, Depth: 0, AssumeSingleUse))
1462 return false;
1463
1464 // Revisit the node.
1465 AddToWorklist(N: Op.getNode());
1466
1467 CommitTargetLoweringOpt(TLO);
1468 return true;
1469}
1470
1471void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1472 SDLoc DL(Load);
1473 EVT VT = Load->getValueType(ResNo: 0);
1474 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SDValue(ExtLoad, 0));
1475
1476 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1477 Trunc.dump(&DAG); dbgs() << '\n');
1478
1479 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: Trunc);
1480 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: SDValue(ExtLoad, 1));
1481
1482 AddToWorklist(N: Trunc.getNode());
1483 recursivelyDeleteUnusedNodes(N: Load);
1484}
1485
1486SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1487 Replace = false;
1488 SDLoc DL(Op);
1489 if (ISD::isUNINDEXEDLoad(N: Op.getNode())) {
1490 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
1491 EVT MemVT = LD->getMemoryVT();
1492 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1493 : LD->getExtensionType();
1494 Replace = true;
1495 return DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1496 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1497 MemVT, MMO: LD->getMemOperand());
1498 }
1499
1500 unsigned Opc = Op.getOpcode();
1501 switch (Opc) {
1502 default: break;
1503 case ISD::AssertSext:
1504 if (SDValue Op0 = SExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1505 return DAG.getNode(Opcode: ISD::AssertSext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1506 break;
1507 case ISD::AssertZext:
1508 if (SDValue Op0 = ZExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1509 return DAG.getNode(Opcode: ISD::AssertZext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1510 break;
1511 case ISD::Constant: {
1512 unsigned ExtOpc =
1513 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1514 return DAG.getNode(Opcode: ExtOpc, DL, VT: PVT, Operand: Op);
1515 }
1516 }
1517
1518 if (!TLI.isOperationLegal(Op: ISD::ANY_EXTEND, VT: PVT))
1519 return SDValue();
1520 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: PVT, Operand: Op);
1521}
1522
1523SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1524 if (!TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG, VT: PVT))
1525 return SDValue();
1526 EVT OldVT = Op.getValueType();
1527 SDLoc DL(Op);
1528 bool Replace = false;
1529 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1530 if (!NewOp.getNode())
1531 return SDValue();
1532 AddToWorklist(N: NewOp.getNode());
1533
1534 if (Replace)
1535 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1536 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: NewOp.getValueType(), N1: NewOp,
1537 N2: DAG.getValueType(OldVT));
1538}
1539
1540SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1541 EVT OldVT = Op.getValueType();
1542 SDLoc DL(Op);
1543 bool Replace = false;
1544 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1545 if (!NewOp.getNode())
1546 return SDValue();
1547 AddToWorklist(N: NewOp.getNode());
1548
1549 if (Replace)
1550 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1551 return DAG.getZeroExtendInReg(Op: NewOp, DL, VT: OldVT);
1552}
1553
1554/// Promote the specified integer binary operation if the target indicates it is
1555/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1556/// i32 since i16 instructions are longer.
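/// For example (illustrative): an undesirable i16 add may be rewritten as
///   (i16 (truncate (i32 add (any_extend X), (any_extend Y))))
/// when the target reports i32 as the type to promote to.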
1557SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1558 if (!LegalOperations)
1559 return SDValue();
1560
1561 EVT VT = Op.getValueType();
1562 if (VT.isVector() || !VT.isInteger())
1563 return SDValue();
1564
1565 // If operation type is 'undesirable', e.g. i16 on x86, consider
1566 // promoting it.
1567 unsigned Opc = Op.getOpcode();
1568 if (TLI.isTypeDesirableForOp(Opc, VT))
1569 return SDValue();
1570
1571 EVT PVT = VT;
1572 // Consult target whether it is a good idea to promote this operation and
1573 // what's the right type to promote it to.
1574 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1575 assert(PVT != VT && "Don't know what type to promote to!");
1576
1577 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1578
1579 bool Replace0 = false;
1580 SDValue N0 = Op.getOperand(i: 0);
1581 SDValue NN0 = PromoteOperand(Op: N0, PVT, Replace&: Replace0);
1582
1583 bool Replace1 = false;
1584 SDValue N1 = Op.getOperand(i: 1);
1585 SDValue NN1 = PromoteOperand(Op: N1, PVT, Replace&: Replace1);
1586 SDLoc DL(Op);
1587
1588 SDValue RV =
1589 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: NN0, N2: NN1));
1590
1591 // We are always replacing N0/N1's use in N and only need additional
1592 // replacements if there are additional uses.
1593 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1594 // (SDValue) here because the node may reference multiple values
1595 // (for example, the chain value of a load node).
1596 Replace0 &= !N0->hasOneUse();
1597 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1598
1599 // Combine Op here so it is preserved past replacements.
1600 CombineTo(N: Op.getNode(), Res: RV);
1601
1602 // If operands have a use ordering, make sure we deal with
1603 // predecessor first.
1604 if (Replace0 && Replace1 && N0->isPredecessorOf(N: N1.getNode())) {
1605 std::swap(a&: N0, b&: N1);
1606 std::swap(a&: NN0, b&: NN1);
1607 }
1608
1609 if (Replace0) {
1610 AddToWorklist(N: NN0.getNode());
1611 ReplaceLoadWithPromotedLoad(Load: N0.getNode(), ExtLoad: NN0.getNode());
1612 }
1613 if (Replace1) {
1614 AddToWorklist(N: NN1.getNode());
1615 ReplaceLoadWithPromotedLoad(Load: N1.getNode(), ExtLoad: NN1.getNode());
1616 }
1617 return Op;
1618 }
1619 return SDValue();
1620}
1621
1622/// Promote the specified integer shift operation if the target indicates it is
1623/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1624/// i32 since i16 instructions are longer.
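/// For example (illustrative): an undesirable i16 srl may be rewritten as
///   (i16 (truncate (i32 srl (zero-extended X), C)))
/// where zero-extending X first keeps the bits shifted down into the low 16
/// bits identical to the original i16 result.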
1625SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1626 if (!LegalOperations)
1627 return SDValue();
1628
1629 EVT VT = Op.getValueType();
1630 if (VT.isVector() || !VT.isInteger())
1631 return SDValue();
1632
1633 // If operation type is 'undesirable', e.g. i16 on x86, consider
1634 // promoting it.
1635 unsigned Opc = Op.getOpcode();
1636 if (TLI.isTypeDesirableForOp(Opc, VT))
1637 return SDValue();
1638
1639 EVT PVT = VT;
1640 // Consult target whether it is a good idea to promote this operation and
1641 // what's the right type to promote it to.
1642 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1643 assert(PVT != VT && "Don't know what type to promote to!");
1644
1645 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1646
1647 bool Replace = false;
1648 SDValue N0 = Op.getOperand(i: 0);
1649 if (Opc == ISD::SRA)
1650 N0 = SExtPromoteOperand(Op: N0, PVT);
1651 else if (Opc == ISD::SRL)
1652 N0 = ZExtPromoteOperand(Op: N0, PVT);
1653 else
1654 N0 = PromoteOperand(Op: N0, PVT, Replace);
1655
1656 if (!N0.getNode())
1657 return SDValue();
1658
1659 SDLoc DL(Op);
1660 SDValue N1 = Op.getOperand(i: 1);
1661 SDValue RV =
1662 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: N0, N2: N1));
1663
1664 if (Replace)
1665 ReplaceLoadWithPromotedLoad(Load: Op.getOperand(i: 0).getNode(), ExtLoad: N0.getNode());
1666
1667 // Deal with Op being deleted.
1668 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1669 return RV;
1670 }
1671 return SDValue();
1672}
1673
1674SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1675 if (!LegalOperations)
1676 return SDValue();
1677
1678 EVT VT = Op.getValueType();
1679 if (VT.isVector() || !VT.isInteger())
1680 return SDValue();
1681
1682 // If operation type is 'undesirable', e.g. i16 on x86, consider
1683 // promoting it.
1684 unsigned Opc = Op.getOpcode();
1685 if (TLI.isTypeDesirableForOp(Opc, VT))
1686 return SDValue();
1687
1688 EVT PVT = VT;
1689 // Consult target whether it is a good idea to promote this operation and
1690 // what's the right type to promote it to.
1691 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1692 assert(PVT != VT && "Don't know what type to promote to!");
1693 // fold (aext (aext x)) -> (aext x)
1694 // fold (aext (zext x)) -> (zext x)
1695 // fold (aext (sext x)) -> (sext x)
1696 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1697 return DAG.getNode(Opcode: Op.getOpcode(), DL: SDLoc(Op), VT, Operand: Op.getOperand(i: 0));
1698 }
1699 return SDValue();
1700}
1701
1702bool DAGCombiner::PromoteLoad(SDValue Op) {
1703 if (!LegalOperations)
1704 return false;
1705
1706 if (!ISD::isUNINDEXEDLoad(N: Op.getNode()))
1707 return false;
1708
1709 EVT VT = Op.getValueType();
1710 if (VT.isVector() || !VT.isInteger())
1711 return false;
1712
1713 // If operation type is 'undesirable', e.g. i16 on x86, consider
1714 // promoting it.
1715 unsigned Opc = Op.getOpcode();
1716 if (TLI.isTypeDesirableForOp(Opc, VT))
1717 return false;
1718
1719 EVT PVT = VT;
1720 // Consult target whether it is a good idea to promote this operation and
1721 // what's the right type to promote it to.
1722 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1723 assert(PVT != VT && "Don't know what type to promote to!");
1724
1725 SDLoc DL(Op);
1726 SDNode *N = Op.getNode();
1727 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
1728 EVT MemVT = LD->getMemoryVT();
1729 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1730 : LD->getExtensionType();
1731 SDValue NewLD = DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1732 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1733 MemVT, MMO: LD->getMemOperand());
1734 SDValue Result = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewLD);
1735
1736 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1737 Result.dump(&DAG); dbgs() << '\n');
1738
1739 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
1740 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: NewLD.getValue(R: 1));
1741
1742 AddToWorklist(N: Result.getNode());
1743 recursivelyDeleteUnusedNodes(N);
1744 return true;
1745 }
1746
1747 return false;
1748}
1749
1750/// Recursively delete a node which has no uses and any operands for
1751/// which it is the only use.
1752///
1753/// Note that this both deletes the nodes and removes them from the worklist.
1754/// It also adds any nodes that have had a user deleted to the worklist, as they
1755/// may now have only one use and be subject to other combines.
1756bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1757 if (!N->use_empty())
1758 return false;
1759
1760 SmallSetVector<SDNode *, 16> Nodes;
1761 Nodes.insert(X: N);
1762 do {
1763 N = Nodes.pop_back_val();
1764 if (!N)
1765 continue;
1766
1767 if (N->use_empty()) {
1768 for (const SDValue &ChildN : N->op_values())
1769 Nodes.insert(X: ChildN.getNode());
1770
1771 removeFromWorklist(N);
1772 DAG.DeleteNode(N);
1773 } else {
1774 AddToWorklist(N);
1775 }
1776 } while (!Nodes.empty());
1777 return true;
1778}
1779
1780//===----------------------------------------------------------------------===//
1781// Main DAG Combiner implementation
1782//===----------------------------------------------------------------------===//
1783
1784void DAGCombiner::Run(CombineLevel AtLevel) {
1785 // Set the instance variables so that the various visit routines may use them.
1786 Level = AtLevel;
1787 LegalDAG = Level >= AfterLegalizeDAG;
1788 LegalOperations = Level >= AfterLegalizeVectorOps;
1789 LegalTypes = Level >= AfterLegalizeTypes;
1790
1791 WorklistInserter AddNodes(*this);
1792
1793 // Add all the dag nodes to the worklist.
1794 //
1795 // Note: Not all nodes are added to the PruningList here. This is because the
1796 // only nodes which can be deleted are those which have no uses, and all other
1797 // nodes which would otherwise be added to the worklist by the first call to
1798 // getNextWorklistEntry are already present in it.
1799 for (SDNode &Node : DAG.allnodes())
1800 AddToWorklist(N: &Node, /* IsCandidateForPruning */ Node.use_empty());
1801
1802 // Create a dummy node (which is not added to allnodes) that adds a reference
1803 // to the root node, preventing it from being deleted and tracking any
1804 // changes of the root.
1805 HandleSDNode Dummy(DAG.getRoot());
1806
1807 // While we have a valid worklist entry node, try to combine it.
1808 while (SDNode *N = getNextWorklistEntry()) {
1809 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1810 // N is deleted from the DAG, since they too may now be dead or may have a
1811 // reduced number of uses, allowing other xforms.
1812 if (recursivelyDeleteUnusedNodes(N))
1813 continue;
1814
1815 WorklistRemover DeadNodes(*this);
1816
1817 // If this combine is running after legalizing the DAG, re-legalize any
1818 // nodes pulled off the worklist.
1819 if (LegalDAG) {
1820 SmallSetVector<SDNode *, 16> UpdatedNodes;
1821 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1822
1823 for (SDNode *LN : UpdatedNodes)
1824 AddToWorklistWithUsers(N: LN);
1825
1826 if (!NIsValid)
1827 continue;
1828 }
1829
1830 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1831
1832 // Add any operands of the new node which have not yet been combined to the
1833 // worklist as well. getNextWorklistEntry flags nodes that have been
1834 // combined before. Because the worklist uniques things already, this won't
1835 // repeatedly process the same operand.
1836 for (const SDValue &ChildN : N->op_values())
1837 AddToWorklist(N: ChildN.getNode(), /*IsCandidateForPruning=*/true,
1838 /*SkipIfCombinedBefore=*/true);
1839
1840 SDValue RV = combine(N);
1841
1842 if (!RV.getNode())
1843 continue;
1844
1845 ++NodesCombined;
1846
1847 // Invalidate cached info.
1848 ChainsWithoutMergeableStores.clear();
1849
1850 // If we get back the same node we passed in, rather than a new node or
1851 // zero, we know that the node must have defined multiple values and
1852 // CombineTo was used. Since CombineTo takes care of the worklist
1853 // mechanics for us, we have no work to do in this case.
1854 if (RV.getNode() == N)
1855 continue;
1856
1857 assert(N->getOpcode() != ISD::DELETED_NODE &&
1858 RV.getOpcode() != ISD::DELETED_NODE &&
1859 "Node was deleted but visit returned new node!");
1860
1861 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1862
1863 if (N->getNumValues() == RV->getNumValues())
1864 DAG.ReplaceAllUsesWith(From: N, To: RV.getNode());
1865 else {
1866 assert(N->getValueType(0) == RV.getValueType() &&
1867 N->getNumValues() == 1 && "Type mismatch");
1868 DAG.ReplaceAllUsesWith(From: N, To: &RV);
1869 }
1870
1871 // Push the new node and any users onto the worklist. Omit this if the
1872 // new node is the EntryToken (e.g. if a store managed to get optimized
1873 // out), because re-visiting the EntryToken and its users will not uncover
1874 // any additional opportunities, but there may be a large number of such
1875 // users, potentially causing compile time explosion.
1876 if (RV.getOpcode() != ISD::EntryToken)
1877 AddToWorklistWithUsers(N: RV.getNode());
1878
1879 // Finally, if the node is now dead, remove it from the graph. The node
1880 // may not be dead if the replacement process recursively simplified to
1881 // something else needing this node. This will also take care of adding any
1882 // operands which have lost a user to the worklist.
1883 recursivelyDeleteUnusedNodes(N);
1884 }
1885
1886 // If the root changed (e.g. it was a dead load), update the root.
1887 DAG.setRoot(Dummy.getValue());
1888 DAG.RemoveDeadNodes();
1889}
1890
1891SDValue DAGCombiner::visit(SDNode *N) {
1892 // clang-format off
1893 switch (N->getOpcode()) {
1894 default: break;
1895 case ISD::TokenFactor: return visitTokenFactor(N);
1896 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1897 case ISD::ADD: return visitADD(N);
1898 case ISD::PTRADD: return visitPTRADD(N);
1899 case ISD::SUB: return visitSUB(N);
1900 case ISD::SADDSAT:
1901 case ISD::UADDSAT: return visitADDSAT(N);
1902 case ISD::SSUBSAT:
1903 case ISD::USUBSAT: return visitSUBSAT(N);
1904 case ISD::ADDC: return visitADDC(N);
1905 case ISD::SADDO:
1906 case ISD::UADDO: return visitADDO(N);
1907 case ISD::SUBC: return visitSUBC(N);
1908 case ISD::SSUBO:
1909 case ISD::USUBO: return visitSUBO(N);
1910 case ISD::ADDE: return visitADDE(N);
1911 case ISD::UADDO_CARRY: return visitUADDO_CARRY(N);
1912 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1913 case ISD::SUBE: return visitSUBE(N);
1914 case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N);
1915 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1916 case ISD::SMULFIX:
1917 case ISD::SMULFIXSAT:
1918 case ISD::UMULFIX:
1919 case ISD::UMULFIXSAT: return visitMULFIX(N);
1920 case ISD::MUL: return visitMUL<EmptyMatchContext>(N);
1921 case ISD::SDIV: return visitSDIV(N);
1922 case ISD::UDIV: return visitUDIV(N);
1923 case ISD::SREM:
1924 case ISD::UREM: return visitREM(N);
1925 case ISD::MULHU: return visitMULHU(N);
1926 case ISD::MULHS: return visitMULHS(N);
1927 case ISD::AVGFLOORS:
1928 case ISD::AVGFLOORU:
1929 case ISD::AVGCEILS:
1930 case ISD::AVGCEILU: return visitAVG(N);
1931 case ISD::ABDS:
1932 case ISD::ABDU: return visitABD(N);
1933 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1934 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1935 case ISD::SMULO:
1936 case ISD::UMULO: return visitMULO(N);
1937 case ISD::SMIN:
1938 case ISD::SMAX:
1939 case ISD::UMIN:
1940 case ISD::UMAX: return visitIMINMAX(N);
1941 case ISD::AND: return visitAND(N);
1942 case ISD::OR: return visitOR(N);
1943 case ISD::XOR: return visitXOR(N);
1944 case ISD::SHL: return visitSHL(N);
1945 case ISD::SRA: return visitSRA(N);
1946 case ISD::SRL: return visitSRL(N);
1947 case ISD::ROTR:
1948 case ISD::ROTL: return visitRotate(N);
1949 case ISD::FSHL:
1950 case ISD::FSHR: return visitFunnelShift(N);
1951 case ISD::SSHLSAT:
1952 case ISD::USHLSAT: return visitSHLSAT(N);
1953 case ISD::ABS: return visitABS(N);
1954 case ISD::BSWAP: return visitBSWAP(N);
1955 case ISD::BITREVERSE: return visitBITREVERSE(N);
1956 case ISD::CTLZ: return visitCTLZ(N);
1957 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1958 case ISD::CTTZ: return visitCTTZ(N);
1959 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1960 case ISD::CTPOP: return visitCTPOP(N);
1961 case ISD::SELECT: return visitSELECT(N);
1962 case ISD::VSELECT: return visitVSELECT(N);
1963 case ISD::SELECT_CC: return visitSELECT_CC(N);
1964 case ISD::SETCC: return visitSETCC(N);
1965 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1966 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1967 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
1968 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
1969 case ISD::AssertSext:
1970 case ISD::AssertZext: return visitAssertExt(N);
1971 case ISD::AssertAlign: return visitAssertAlign(N);
1972 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
1973 case ISD::SIGN_EXTEND_VECTOR_INREG:
1974 case ISD::ZERO_EXTEND_VECTOR_INREG:
1975 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1976 case ISD::TRUNCATE: return visitTRUNCATE(N);
1977 case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT_U(N);
1978 case ISD::BITCAST: return visitBITCAST(N);
1979 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
1980 case ISD::FADD: return visitFADD(N);
1981 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
1982 case ISD::FSUB: return visitFSUB(N);
1983 case ISD::FMUL: return visitFMUL(N);
1984 case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
1985 case ISD::FMAD: return visitFMAD(N);
1986 case ISD::FDIV: return visitFDIV(N);
1987 case ISD::FREM: return visitFREM(N);
1988 case ISD::FSQRT: return visitFSQRT(N);
1989 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
1990 case ISD::FPOW: return visitFPOW(N);
1991 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
1992 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
1993 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
1994 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
1995 case ISD::LROUND:
1996 case ISD::LLROUND:
1997 case ISD::LRINT:
1998 case ISD::LLRINT: return visitXROUND(N);
1999 case ISD::FP_ROUND: return visitFP_ROUND(N);
2000 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
2001 case ISD::FNEG: return visitFNEG(N);
2002 case ISD::FABS: return visitFABS(N);
2003 case ISD::FFLOOR: return visitFFLOOR(N);
2004 case ISD::FMINNUM:
2005 case ISD::FMAXNUM:
2006 case ISD::FMINIMUM:
2007 case ISD::FMAXIMUM:
2008 case ISD::FMINIMUMNUM:
2009 case ISD::FMAXIMUMNUM: return visitFMinMax(N);
2010 case ISD::FCEIL: return visitFCEIL(N);
2011 case ISD::FTRUNC: return visitFTRUNC(N);
2012 case ISD::FFREXP: return visitFFREXP(N);
2013 case ISD::BRCOND: return visitBRCOND(N);
2014 case ISD::BR_CC: return visitBR_CC(N);
2015 case ISD::LOAD: return visitLOAD(N);
2016 case ISD::STORE: return visitSTORE(N);
2017 case ISD::ATOMIC_STORE: return visitATOMIC_STORE(N);
2018 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
2019 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
2020 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
2021 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
2022 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
2023 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
2024 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
2025 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
2026 case ISD::MGATHER: return visitMGATHER(N);
2027 case ISD::MLOAD: return visitMLOAD(N);
2028 case ISD::MSCATTER: return visitMSCATTER(N);
2029 case ISD::MSTORE: return visitMSTORE(N);
2030 case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
2031 case ISD::PARTIAL_REDUCE_SMLA:
2032 case ISD::PARTIAL_REDUCE_UMLA:
2033 case ISD::PARTIAL_REDUCE_SUMLA:
2034 return visitPARTIAL_REDUCE_MLA(N);
2035 case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
2036 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
2037 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
2038 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
2039 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
2040 case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
2041 case ISD::FREEZE: return visitFREEZE(N);
2042 case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
2043 case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
2044 case ISD::FCANONICALIZE: return visitFCANONICALIZE(N);
2045 case ISD::VECREDUCE_FADD:
2046 case ISD::VECREDUCE_FMUL:
2047 case ISD::VECREDUCE_ADD:
2048 case ISD::VECREDUCE_MUL:
2049 case ISD::VECREDUCE_AND:
2050 case ISD::VECREDUCE_OR:
2051 case ISD::VECREDUCE_XOR:
2052 case ISD::VECREDUCE_SMAX:
2053 case ISD::VECREDUCE_SMIN:
2054 case ISD::VECREDUCE_UMAX:
2055 case ISD::VECREDUCE_UMIN:
2056 case ISD::VECREDUCE_FMAX:
2057 case ISD::VECREDUCE_FMIN:
2058 case ISD::VECREDUCE_FMAXIMUM:
2059 case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
2060#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
2061#include "llvm/IR/VPIntrinsics.def"
2062 return visitVPOp(N);
2063 }
2064 // clang-format on
2065 return SDValue();
2066}
2067
2068SDValue DAGCombiner::combine(SDNode *N) {
2069 if (!DebugCounter::shouldExecute(CounterName: DAGCombineCounter))
2070 return SDValue();
2071
2072 SDValue RV;
2073 if (!DisableGenericCombines)
2074 RV = visit(N);
2075
2076 // If nothing happened, try a target-specific DAG combine.
2077 if (!RV.getNode()) {
2078 assert(N->getOpcode() != ISD::DELETED_NODE &&
2079 "Node was deleted but visit returned NULL!");
2080
2081 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2082 TLI.hasTargetDAGCombine(NT: (ISD::NodeType)N->getOpcode())) {
2083
2084 // Expose the DAG combiner to the target combiner impls.
2085 TargetLowering::DAGCombinerInfo
2086 DagCombineInfo(DAG, Level, false, this);
2087
2088 RV = TLI.PerformDAGCombine(N, DCI&: DagCombineInfo);
2089 }
2090 }
2091
2092 // If nothing happened still, try promoting the operation.
2093 if (!RV.getNode()) {
2094 switch (N->getOpcode()) {
2095 default: break;
2096 case ISD::ADD:
2097 case ISD::SUB:
2098 case ISD::MUL:
2099 case ISD::AND:
2100 case ISD::OR:
2101 case ISD::XOR:
2102 RV = PromoteIntBinOp(Op: SDValue(N, 0));
2103 break;
2104 case ISD::SHL:
2105 case ISD::SRA:
2106 case ISD::SRL:
2107 RV = PromoteIntShiftOp(Op: SDValue(N, 0));
2108 break;
2109 case ISD::SIGN_EXTEND:
2110 case ISD::ZERO_EXTEND:
2111 case ISD::ANY_EXTEND:
2112 RV = PromoteExtend(Op: SDValue(N, 0));
2113 break;
2114 case ISD::LOAD:
2115 if (PromoteLoad(Op: SDValue(N, 0)))
2116 RV = SDValue(N, 0);
2117 break;
2118 }
2119 }
2120
2121 // If N is a commutative binary node, try to eliminate it if the commuted
2122 // version is already present in the DAG.
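  // For example (illustrative): if this node is (xor X, Y) and (xor Y, X)
  // already exists in the DAG, the existing node is returned so the duplicate
  // can be eliminated.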
2123 if (!RV.getNode() && TLI.isCommutativeBinOp(Opcode: N->getOpcode())) {
2124 SDValue N0 = N->getOperand(Num: 0);
2125 SDValue N1 = N->getOperand(Num: 1);
2126
2127 // Constant operands are canonicalized to RHS.
2128 if (N0 != N1 && (isa<ConstantSDNode>(Val: N0) || !isa<ConstantSDNode>(Val: N1))) {
2129 SDValue Ops[] = {N1, N0};
2130 SDNode *CSENode = DAG.getNodeIfExists(Opcode: N->getOpcode(), VTList: N->getVTList(), Ops,
2131 Flags: N->getFlags());
2132 if (CSENode)
2133 return SDValue(CSENode, 0);
2134 }
2135 }
2136
2137 return RV;
2138}
2139
2140/// Given a node, return its input chain if it has one, otherwise return a null
2141/// sd operand.
2142static SDValue getInputChainForNode(SDNode *N) {
2143 if (unsigned NumOps = N->getNumOperands()) {
2144 if (N->getOperand(Num: 0).getValueType() == MVT::Other)
2145 return N->getOperand(Num: 0);
2146 if (N->getOperand(Num: NumOps-1).getValueType() == MVT::Other)
2147 return N->getOperand(Num: NumOps-1);
2148 for (unsigned i = 1; i < NumOps-1; ++i)
2149 if (N->getOperand(Num: i).getValueType() == MVT::Other)
2150 return N->getOperand(Num: i);
2151 }
2152 return SDValue();
2153}
2154
2155SDValue DAGCombiner::visitFCANONICALIZE(SDNode *N) {
2156 SDValue Operand = N->getOperand(Num: 0);
2157 EVT VT = Operand.getValueType();
2158 SDLoc dl(N);
2159
2160 // Canonicalize undef to quiet NaN.
2161 if (Operand.isUndef()) {
2162 APFloat CanonicalQNaN = APFloat::getQNaN(Sem: VT.getFltSemantics());
2163 return DAG.getConstantFP(Val: CanonicalQNaN, DL: dl, VT);
2164 }
2165 return SDValue();
2166}
2167
2168SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2169 // If N has two operands, where one has an input chain equal to the other,
2170 // the 'other' chain is redundant.
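  // For example (illustrative): a TokenFactor of a load's output chain and the
  // chain Ch that feeds that same load is redundant; the load's output chain
  // alone already orders after Ch.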
2171 if (N->getNumOperands() == 2) {
2172 if (getInputChainForNode(N: N->getOperand(Num: 0).getNode()) == N->getOperand(Num: 1))
2173 return N->getOperand(Num: 0);
2174 if (getInputChainForNode(N: N->getOperand(Num: 1).getNode()) == N->getOperand(Num: 0))
2175 return N->getOperand(Num: 1);
2176 }
2177
2178 // Don't simplify token factors if optnone.
2179 if (OptLevel == CodeGenOptLevel::None)
2180 return SDValue();
2181
2182 // Don't simplify the token factor if the node itself has too many operands.
2183 if (N->getNumOperands() > TokenFactorInlineLimit)
2184 return SDValue();
2185
2186 // If the sole user is a token factor, we should make sure we have a
2187 // chance to merge them together. This prevents TF chains from inhibiting
2188 // optimizations.
2189 if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TokenFactor)
2190 AddToWorklist(N: *(N->user_begin()));
2191
2192 SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
2193 SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
2194 SmallPtrSet<SDNode*, 16> SeenOps;
2195 bool Changed = false; // If we should replace this token factor.
2196
2197 // Start out with this token factor.
2198 TFs.push_back(Elt: N);
2199
2200 // Iterate through token factors. The TFs list grows when new token factors
2201 // are encountered.
2202 for (unsigned i = 0; i < TFs.size(); ++i) {
2203 // Limit number of nodes to inline, to avoid quadratic compile times.
2204 // We have to add the outstanding Token Factors to Ops, otherwise we might
2205 // drop Ops from the resulting Token Factors.
2206 if (Ops.size() > TokenFactorInlineLimit) {
2207 for (unsigned j = i; j < TFs.size(); j++)
2208 Ops.emplace_back(Args&: TFs[j], Args: 0);
2209 // Drop unprocessed Token Factors from TFs, so we do not add them to the
2210 // combiner worklist later.
2211 TFs.resize(N: i);
2212 break;
2213 }
2214
2215 SDNode *TF = TFs[i];
2216 // Check each of the operands.
2217 for (const SDValue &Op : TF->op_values()) {
2218 switch (Op.getOpcode()) {
2219 case ISD::EntryToken:
2220 // Entry tokens don't need to be added to the list. They are
2221 // redundant.
2222 Changed = true;
2223 break;
2224
2225 case ISD::TokenFactor:
2226 if (Op.hasOneUse() && !is_contained(Range&: TFs, Element: Op.getNode())) {
2227 // Queue up for processing.
2228 TFs.push_back(Elt: Op.getNode());
2229 Changed = true;
2230 break;
2231 }
2232 [[fallthrough]];
2233
2234 default:
2235 // Only add if it isn't already in the list.
2236 if (SeenOps.insert(Ptr: Op.getNode()).second)
2237 Ops.push_back(Elt: Op);
2238 else
2239 Changed = true;
2240 break;
2241 }
2242 }
2243 }
2244
2245 // Re-visit inlined Token Factors, to clean them up in case they have been
2246 // removed. Skip the first Token Factor, as this is the current node.
2247 for (unsigned i = 1, e = TFs.size(); i < e; i++)
2248 AddToWorklist(N: TFs[i]);
2249
2250 // Remove Nodes that are chained to another node in the list. Do so
2251 // by walking up chains breadth-first, stopping when we've seen
2252 // another operand. In general we must climb to the EntryNode, but we can exit
2253 // early if we find all remaining work is associated with just one operand as
2254 // no further pruning is possible.
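  // For example (illustrative): if Ops contains chains A and B and the walk up
  // from A reaches B, then B is redundant, since A already orders after
  // everything B does, and B can be dropped from the rebuilt TokenFactor.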
2255
2256 // List of nodes to search through and original Ops from which they originate.
2257 SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2258 SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2259 SmallPtrSet<SDNode *, 16> SeenChains;
2260 bool DidPruneOps = false;
2261
2262 unsigned NumLeftToConsider = 0;
2263 for (const SDValue &Op : Ops) {
2264 Worklist.push_back(Elt: std::make_pair(x: Op.getNode(), y: NumLeftToConsider++));
2265 OpWorkCount.push_back(Elt: 1);
2266 }
2267
2268 auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2269 // If this is an Op, we can remove the op from the list. Re-mark any
2270 // search associated with it as coming from the current OpNumber.
2271 if (SeenOps.contains(Ptr: Op)) {
2272 Changed = true;
2273 DidPruneOps = true;
2274 unsigned OrigOpNumber = 0;
2275 while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2276 OrigOpNumber++;
2277 assert((OrigOpNumber != Ops.size()) &&
2278 "expected to find TokenFactor Operand");
2279 // Re-mark worklist from OrigOpNumber to OpNumber
2280 for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2281 if (Worklist[i].second == OrigOpNumber) {
2282 Worklist[i].second = OpNumber;
2283 }
2284 }
2285 OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2286 OpWorkCount[OrigOpNumber] = 0;
2287 NumLeftToConsider--;
2288 }
2289 // Add if it's a new chain
2290 if (SeenChains.insert(Ptr: Op).second) {
2291 OpWorkCount[OpNumber]++;
2292 Worklist.push_back(Elt: std::make_pair(x&: Op, y&: OpNumber));
2293 }
2294 };
2295
2296 for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2297 // We need to consider at least 2 Ops to prune.
2298 if (NumLeftToConsider <= 1)
2299 break;
2300 auto CurNode = Worklist[i].first;
2301 auto CurOpNumber = Worklist[i].second;
2302 assert((OpWorkCount[CurOpNumber] > 0) &&
2303 "Node should not appear in worklist");
2304 switch (CurNode->getOpcode()) {
2305 case ISD::EntryToken:
2306 // Hitting EntryToken is the only way for the search to terminate without
2307 // hitting another operand's search. Prevent us from marking this operand
2308 // considered.
2310 NumLeftToConsider++;
2311 break;
2312 case ISD::TokenFactor:
2313 for (const SDValue &Op : CurNode->op_values())
2314 AddToWorklist(i, Op.getNode(), CurOpNumber);
2315 break;
2316 case ISD::LIFETIME_START:
2317 case ISD::LIFETIME_END:
2318 case ISD::CopyFromReg:
2319 case ISD::CopyToReg:
2320 AddToWorklist(i, CurNode->getOperand(Num: 0).getNode(), CurOpNumber);
2321 break;
2322 default:
2323 if (auto *MemNode = dyn_cast<MemSDNode>(Val: CurNode))
2324 AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2325 break;
2326 }
2327 OpWorkCount[CurOpNumber]--;
2328 if (OpWorkCount[CurOpNumber] == 0)
2329 NumLeftToConsider--;
2330 }
2331
2332 // If we've changed things around then replace token factor.
2333 if (Changed) {
2334 SDValue Result;
2335 if (Ops.empty()) {
2336 // The entry token is the only possible outcome.
2337 Result = DAG.getEntryNode();
2338 } else {
2339 if (DidPruneOps) {
2340 SmallVector<SDValue, 8> PrunedOps;
2342 for (const SDValue &Op : Ops) {
2343 if (SeenChains.count(Ptr: Op.getNode()) == 0)
2344 PrunedOps.push_back(Elt: Op);
2345 }
2346 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: PrunedOps);
2347 } else {
2348 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: Ops);
2349 }
2350 }
2351 return Result;
2352 }
2353 return SDValue();
2354}
2355
2356/// MERGE_VALUES can always be eliminated.
2357SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2358 WorklistRemover DeadNodes(*this);
2359 // Replacing results may cause a different MERGE_VALUES to suddenly
2360 // be CSE'd with N, and carry its uses with it. Iterate until no
2361 // uses remain, to ensure that the node can be safely deleted.
2362 // First add the users of this node to the work list so that they
2363 // can be tried again once they have new operands.
2364 AddUsersToWorklist(N);
2365 do {
2366 // Do as a single replacement to avoid rewalking use lists.
2367 SmallVector<SDValue, 8> Ops(N->ops());
2368 DAG.ReplaceAllUsesWith(From: N, To: Ops.data());
2369 } while (!N->use_empty());
2370 deleteAndRecombine(N);
2371 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2372}
2373
2374/// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2375/// ConstantSDNode pointer else nullptr.
2376static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2377 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N);
2378 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2379}
2380
2381// isTruncateOf - If N is a truncate of some other value, return true and record
2382// the value being truncated in Op and which of Op's bits are zero/one in Known.
2383// This function computes KnownBits to avoid a duplicated call to
2384// computeKnownBits in the caller.
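// For example (illustrative): an i1 (setcc X, 0, setne) where X is known to be
// either 0 or 1 behaves exactly like (truncate i1 X), so it is reported here as
// a truncate of X.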
2385static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2386 KnownBits &Known) {
2387 if (N->getOpcode() == ISD::TRUNCATE) {
2388 Op = N->getOperand(Num: 0);
2389 Known = DAG.computeKnownBits(Op);
2390 if (N->getFlags().hasNoUnsignedWrap())
2391 Known.Zero.setBitsFrom(N.getScalarValueSizeInBits());
2392 return true;
2393 }
2394
2395 if (N.getValueType().getScalarType() != MVT::i1 ||
2396 !sd_match(
2397 N, P: m_c_SetCC(LHS: m_Value(N&: Op), RHS: m_Zero(), CC: m_SpecificCondCode(CC: ISD::SETNE))))
2398 return false;
2399
2400 Known = DAG.computeKnownBits(Op);
2401 return (Known.Zero | 1).isAllOnes();
2402}
2403
2404/// Return true if 'Use' is a load or a store that uses N as its base pointer
2405/// and that N may be folded in the load / store addressing mode.
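/// For example (illustrative): on a target with reg+imm addressing modes, an
/// (add BasePtr, 16) feeding the address of a load can typically be folded into
/// the load as [BasePtr + 16], so this would return true for that pair.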
2406static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2407 const TargetLowering &TLI) {
2408 EVT VT;
2409 unsigned AS;
2410
2411 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: Use)) {
2412 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2413 return false;
2414 VT = LD->getMemoryVT();
2415 AS = LD->getAddressSpace();
2416 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: Use)) {
2417 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2418 return false;
2419 VT = ST->getMemoryVT();
2420 AS = ST->getAddressSpace();
2421 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: Use)) {
2422 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2423 return false;
2424 VT = LD->getMemoryVT();
2425 AS = LD->getAddressSpace();
2426 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: Use)) {
2427 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2428 return false;
2429 VT = ST->getMemoryVT();
2430 AS = ST->getAddressSpace();
2431 } else {
2432 return false;
2433 }
2434
2435 TargetLowering::AddrMode AM;
2436 if (N->isAnyAdd()) {
2437 AM.HasBaseReg = true;
2438 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2439 if (Offset)
2440 // [reg +/- imm]
2441 AM.BaseOffs = Offset->getSExtValue();
2442 else
2443 // [reg +/- reg]
2444 AM.Scale = 1;
2445 } else if (N->getOpcode() == ISD::SUB) {
2446 AM.HasBaseReg = true;
2447 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2448 if (Offset)
2449 // [reg +/- imm]
2450 AM.BaseOffs = -Offset->getSExtValue();
2451 else
2452 // [reg +/- reg]
2453 AM.Scale = 1;
2454 } else {
2455 return false;
2456 }
2457
2458 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM,
2459 Ty: VT.getTypeForEVT(Context&: *DAG.getContext()), AddrSpace: AS);
2460}
2461
2462/// This inverts a canonicalization in IR that replaces a variable select arm
2463/// with an identity constant. Codegen improves if we re-use the variable
2464/// operand rather than load a constant. This can also be converted into a
2465/// masked vector operation if the target supports it.
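/// For example (illustrative):
///   add X, (vselect Cond, 0, Y) --> vselect Cond, X, (add X, Y)
/// since 0 is the identity constant for add on the true arm.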
2466static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2467 bool ShouldCommuteOperands) {
2468 // Match a select as operand 1. The identity constant that we are looking for
2469 // is only valid as operand 1 of a non-commutative binop.
2470 SDValue N0 = N->getOperand(Num: 0);
2471 SDValue N1 = N->getOperand(Num: 1);
2472 if (ShouldCommuteOperands)
2473 std::swap(a&: N0, b&: N1);
2474
2475 unsigned SelOpcode = N1.getOpcode();
2476 if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
2477 !N1.hasOneUse())
2478 return SDValue();
2479
2480 // We can't hoist all instructions because of immediate UB (not speculatable).
2481 // For example div/rem by zero.
2482 if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2483 return SDValue();
2484
2485 unsigned Opcode = N->getOpcode();
2486 EVT VT = N->getValueType(ResNo: 0);
2487 SDValue Cond = N1.getOperand(i: 0);
2488 SDValue TVal = N1.getOperand(i: 1);
2489 SDValue FVal = N1.getOperand(i: 2);
2490 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2491
2492 // This transform increases uses of N0, so freeze it to be safe.
2493 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2494 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2495 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: TVal, OperandNo: OpNo) &&
2496 TLI.shouldFoldSelectWithIdentityConstant(BinOpcode: Opcode, VT, SelectOpcode: SelOpcode, X: N0,
2497 Y: FVal)) {
2498 SDValue F0 = DAG.getFreeze(V: N0);
2499 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: FVal, Flags: N->getFlags());
2500 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: F0, RHS: NewBO);
2501 }
2502 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2503 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: FVal, OperandNo: OpNo) &&
2504 TLI.shouldFoldSelectWithIdentityConstant(BinOpcode: Opcode, VT, SelectOpcode: SelOpcode, X: N0,
2505 Y: TVal)) {
2506 SDValue F0 = DAG.getFreeze(V: N0);
2507 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: TVal, Flags: N->getFlags());
2508 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: NewBO, RHS: F0);
2509 }
2510
2511 return SDValue();
2512}
2513
2514SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2515 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2516 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2517 "Unexpected binary operator");
2518
2519 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: false))
2520 return Sel;
2521
2522 if (TLI.isCommutativeBinOp(Opcode: BO->getOpcode()))
2523 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: true))
2524 return Sel;
2525
2526 // Don't do this unless the old select is going away. We want to eliminate the
2527 // binary operator, not replace a binop with a select.
2528 // TODO: Handle ISD::SELECT_CC.
2529 unsigned SelOpNo = 0;
2530 SDValue Sel = BO->getOperand(Num: 0);
2531 auto BinOpcode = BO->getOpcode();
2532 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2533 SelOpNo = 1;
2534 Sel = BO->getOperand(Num: 1);
2535
2536 // Peek through trunc to shift amount type.
2537 if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2538 BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2539 // This is valid when the truncated bits of x are already zero.
2540 SDValue Op;
2541 KnownBits Known;
2542 if (isTruncateOf(DAG, N: Sel, Op, Known) &&
2543 Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2544 Sel = Op;
2545 }
2546 }
2547
2548 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2549 return SDValue();
2550
2551 SDValue CT = Sel.getOperand(i: 1);
2552 if (!isConstantOrConstantVector(N: CT, NoOpaques: true) &&
2553 !DAG.isConstantFPBuildVectorOrConstantFP(N: CT))
2554 return SDValue();
2555
2556 SDValue CF = Sel.getOperand(i: 2);
2557 if (!isConstantOrConstantVector(N: CF, NoOpaques: true) &&
2558 !DAG.isConstantFPBuildVectorOrConstantFP(N: CF))
2559 return SDValue();
2560
2561 // Bail out if any constants are opaque because we can't constant fold those.
2562 // The exception is "and" and "or" with either 0 or -1 in which case we can
2563 // propagate non constant operands into select. I.e.:
2564 // and (select Cond, 0, -1), X --> select Cond, 0, X
2565 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2566 bool CanFoldNonConst =
2567 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2568 ((isNullOrNullSplat(V: CT) && isAllOnesOrAllOnesSplat(V: CF)) ||
2569 (isNullOrNullSplat(V: CF) && isAllOnesOrAllOnesSplat(V: CT)));
2570
2571 SDValue CBO = BO->getOperand(Num: SelOpNo ^ 1);
2572 if (!CanFoldNonConst &&
2573 !isConstantOrConstantVector(N: CBO, NoOpaques: true) &&
2574 !DAG.isConstantFPBuildVectorOrConstantFP(N: CBO))
2575 return SDValue();
2576
2577 SDLoc DL(Sel);
2578 SDValue NewCT, NewCF;
2579 EVT VT = BO->getValueType(ResNo: 0);
2580
2581 if (CanFoldNonConst) {
2582 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2583 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CT)) ||
2584 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CT)))
2585 NewCT = CT;
2586 else
2587 NewCT = CBO;
2588
2589 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CF)) ||
2590 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CF)))
2591 NewCF = CF;
2592 else
2593 NewCF = CBO;
2594 } else {
2595 // We have a select-of-constants followed by a binary operator with a
2596 // constant. Eliminate the binop by pulling the constant math into the
2597 // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT +
2598 // CBO, CF + CBO
2599 NewCT = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CT})
2600 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CT, CBO});
2601 if (!NewCT)
2602 return SDValue();
2603
2604 NewCF = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CF})
2605 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CF, CBO});
2606 if (!NewCF)
2607 return SDValue();
2608 }
2609
2610 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: Sel.getOperand(i: 0), LHS: NewCT, RHS: NewCF);
2611 SelectOp->setFlags(BO->getFlags());
2612 return SelectOp;
2613}
2614
2615static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2616 SelectionDAG &DAG) {
2617 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2618 "Expecting add or sub");
2619
2620 // Match a constant operand and a zext operand for the math instruction:
2621 // add Z, C
2622 // sub C, Z
2623 bool IsAdd = N->getOpcode() == ISD::ADD;
2624 SDValue C = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2625 SDValue Z = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2626 auto *CN = dyn_cast<ConstantSDNode>(Val&: C);
2627 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2628 return SDValue();
2629
2630 // Match the zext operand as a setcc of a boolean.
2631 if (Z.getOperand(i: 0).getValueType() != MVT::i1)
2632 return SDValue();
2633
2634 // Match the compare as: setcc (X & 1), 0, eq.
2635 if (!sd_match(N: Z.getOperand(i: 0), P: m_SetCC(LHS: m_And(L: m_Value(), R: m_One()), RHS: m_Zero(),
2636 CC: m_SpecificCondCode(CC: ISD::SETEQ))))
2637 return SDValue();
2638
2639 // We are adding/subtracting a constant and an inverted low bit. Turn that
2640 // into a subtract/add of the low bit with incremented/decremented constant:
2641 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2642 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2643 EVT VT = C.getValueType();
2644 SDValue LowBit = DAG.getZExtOrTrunc(Op: Z.getOperand(i: 0).getOperand(i: 0), DL, VT);
2645 SDValue C1 = IsAdd ? DAG.getConstant(Val: CN->getAPIntValue() + 1, DL, VT)
2646 : DAG.getConstant(Val: CN->getAPIntValue() - 1, DL, VT);
2647 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: C1, N2: LowBit);
2648}
2649
2650// Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
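// For example (an illustrative u8 check): with A = 255 and B = 1,
// (A | B) - ((A ^ B) >> 1) = 255 - 127 = 128, which matches
// avgceilu(255, 1) = ceil((255 + 1) / 2) = 128 without widening the add.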
2651SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2652 SDValue N0 = N->getOperand(Num: 0);
2653 EVT VT = N0.getValueType();
2654 SDValue A, B;
2655
2656 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT)) &&
2657 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2658 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One())))) {
2659 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: A, N2: B);
2660 }
2661 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILS, VT)) &&
2662 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2663 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One())))) {
2664 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: A, N2: B);
2665 }
2666 return SDValue();
2667}
2668
2669/// Try to fold a pointer arithmetic node.
2670/// This needs to be done separately from normal addition, because pointer
2671/// addition is not commutative.
2672SDValue DAGCombiner::visitPTRADD(SDNode *N) {
2673 SDValue N0 = N->getOperand(Num: 0);
2674 SDValue N1 = N->getOperand(Num: 1);
2675 EVT PtrVT = N0.getValueType();
2676 EVT IntVT = N1.getValueType();
2677 SDLoc DL(N);
2678
2679 // This is already ensured by an assert in SelectionDAG::getNode(). Several
2680 // combines here depend on this assumption.
2681 assert(PtrVT == IntVT &&
2682 "PTRADD with different operand types is not supported");
2683
2684 // fold (ptradd x, 0) -> x
2685 if (isNullConstant(V: N1))
2686 return N0;
2687
2688 // fold (ptradd 0, x) -> x
2689 if (PtrVT == IntVT && isNullConstant(V: N0))
2690 return N1;
2691
2692 if (N0.getOpcode() != ISD::PTRADD ||
2693 reassociationCanBreakAddressingModePattern(Opc: ISD::PTRADD, DL, N, N0, N1))
2694 return SDValue();
2695
2696 SDValue X = N0.getOperand(i: 0);
2697 SDValue Y = N0.getOperand(i: 1);
2698 SDValue Z = N1;
2699 bool N0OneUse = N0.hasOneUse();
2700 bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(N: Y);
2701 bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(N: Z);
2702
2703 // (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
2704 // * y is a constant and (ptradd x, y) has one use; or
2705 // * y and z are both constants.
2706 if ((YIsConstant && N0OneUse) || (YIsConstant && ZIsConstant)) {
2707 // If both additions in the original were NUW, the new ones are as well.
2708 SDNodeFlags Flags =
2709 (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2710 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: IntVT, Ops: {Y, Z}, Flags);
2711 AddToWorklist(N: Add.getNode());
2712 return DAG.getMemBasePlusOffset(Base: X, Offset: Add, DL, Flags);
2713 }
2714
2715 // TODO: There is another possible fold here that was proven useful.
2716 // It would be this:
2717 //
2718 // (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y) if:
2719 // * (ptradd x, y) has one use; and
2720 // * y is a constant; and
2721 // * z is not a constant.
2722 //
2723 // In some cases, specifically in AArch64's FEAT_CPA, it exposes the
2724 // opportunity to select more complex instructions such as SUBPT and
2725 // MSUBPT. However, a hypothetical corner case has been found that we could
2726 // not avoid. Consider this (pseudo-POSIX C):
2727 //
2728 // char *foo(char *x, int z) {return (x + LARGE_CONSTANT) + z;}
2729 // char *p = mmap(LARGE_CONSTANT);
2730 // char *q = foo(p, -LARGE_CONSTANT);
2731 //
2732 // Then x + LARGE_CONSTANT is one-past-the-end, so valid, and a
2733 // further + z takes it back to the start of the mapping, so valid,
2734 // regardless of the address mmap gave back. However, if mmap gives you an
2735 // address < LARGE_CONSTANT (ignoring high bits), x - LARGE_CONSTANT will
2736 // borrow from the high bits (with the subsequent + z carrying back into
2737 // the high bits to give you a well-defined pointer) and thus trip
2738 // FEAT_CPA's pointer corruption checks.
2739 //
2740 // We leave this fold as an opportunity for future work, addressing the
2741 // corner case for FEAT_CPA, as well as reconciling the solution with the
2742 // more general application of pointer arithmetic in other future targets.
2743 // For now each architecture that wants this fold must implement it in the
2744 // target-specific code (see e.g. SITargetLowering::performPtrAddCombine)
2745
2746 return SDValue();
2747}
2748
2749/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2750/// a shift and add with a different constant.
2751static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2752 SelectionDAG &DAG) {
2753 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2754 "Expecting add or sub");
2755
2756 // We need a constant operand for the add/sub, and the other operand is a
2757 // logical shift right: add (srl), C or sub C, (srl).
2758 bool IsAdd = N->getOpcode() == ISD::ADD;
2759 SDValue ConstantOp = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2760 SDValue ShiftOp = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2761 if (!DAG.isConstantIntBuildVectorOrConstantInt(N: ConstantOp) ||
2762 ShiftOp.getOpcode() != ISD::SRL)
2763 return SDValue();
2764
2765 // The shift must be of a 'not' value.
2766 SDValue Not = ShiftOp.getOperand(i: 0);
2767 if (!Not.hasOneUse() || !isBitwiseNot(V: Not))
2768 return SDValue();
2769
2770 // The shift must be moving the sign bit to the least-significant-bit.
2771 EVT VT = ShiftOp.getValueType();
2772 SDValue ShAmt = ShiftOp.getOperand(i: 1);
2773 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
2774 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2775 return SDValue();
2776
2777 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2778 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2779 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2780 if (SDValue NewC = DAG.FoldConstantArithmetic(
2781 Opcode: IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2782 Ops: {ConstantOp, DAG.getConstant(Val: 1, DL, VT)})) {
2783 SDValue NewShift = DAG.getNode(Opcode: IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2784 N1: Not.getOperand(i: 0), N2: ShAmt);
2785 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: NewShift, N2: NewC);
2786 }
2787
2788 return SDValue();
2789}
2790
2791static bool
2792areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2793 return (isBitwiseNot(V: Op0) && Op0.getOperand(i: 0) == Op1) ||
2794 (isBitwiseNot(V: Op1) && Op1.getOperand(i: 0) == Op0);
2795}
2796
2797/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2798/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2799/// are no common bits set in the operands).
2800SDValue DAGCombiner::visitADDLike(SDNode *N) {
2801 SDValue N0 = N->getOperand(Num: 0);
2802 SDValue N1 = N->getOperand(Num: 1);
2803 EVT VT = N0.getValueType();
2804 SDLoc DL(N);
2805
2806 // fold (add x, undef) -> undef
2807 if (N0.isUndef())
2808 return N0;
2809 if (N1.isUndef())
2810 return N1;
2811
2812 // fold (add c1, c2) -> c1+c2
2813 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N0, N1}))
2814 return C;
2815
2816 // canonicalize constant to RHS
2817 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
2818 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
2819 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
2820
2821 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
2822 return DAG.getConstant(Val: APInt::getAllOnes(numBits: VT.getScalarSizeInBits()), DL, VT);
2823
2824 // fold vector ops
2825 if (VT.isVector()) {
2826 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2827 return FoldedVOp;
2828
2829 // fold (add x, 0) -> x, vector edition
2830 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
2831 return N0;
2832 }
2833
2834 // fold (add x, 0) -> x
2835 if (isNullConstant(V: N1))
2836 return N0;
2837
2838 if (N0.getOpcode() == ISD::SUB) {
2839 SDValue N00 = N0.getOperand(i: 0);
2840 SDValue N01 = N0.getOperand(i: 1);
2841
2842 // fold ((A-c1)+c2) -> (A+(c2-c1))
2843 if (SDValue Sub = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N1, N01}))
2844 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Sub);
2845
2846 // fold ((c1-A)+c2) -> (c1+c2)-A
2847 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N00}))
2848 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
2849 }
2850
2851 // add (sext i1 X), 1 -> zext (not i1 X)
2852 // We don't transform this pattern:
2853 // add (zext i1 X), -1 -> sext (not i1 X)
2854 // because most (?) targets generate better code for the zext form.
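// Note: (sext i1 X) is 0 or -1, so adding 1 yields 1 or 0, which is exactly
// (zext (not i1 X)).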
2855 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2856 isOneOrOneSplat(V: N1)) {
2857 SDValue X = N0.getOperand(i: 0);
2858 if ((!LegalOperations ||
2859 (TLI.isOperationLegal(Op: ISD::XOR, VT: X.getValueType()) &&
2860 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) &&
2861 X.getScalarValueSizeInBits() == 1) {
2862 SDValue Not = DAG.getNOT(DL, Val: X, VT: X.getValueType());
2863 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Not);
2864 }
2865 }
2866
2867 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2868 // iff (or x, c0) is equivalent to (add x, c0).
2869 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2870 // iff (xor x, c0) is equivalent to (add x, c0).
2871 if (DAG.isADDLike(Op: N0)) {
2872 SDValue N01 = N0.getOperand(i: 1);
2873 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N01}))
2874 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
2875 }
2876
2877 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
2878 return NewSel;
2879
2880 // reassociate add
2881 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::ADD, DL, N, N0, N1)) {
2882 if (SDValue RADD = reassociateOps(Opc: ISD::ADD, DL, N0, N1, Flags: N->getFlags()))
2883 return RADD;
2884
2885 // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
2886 // equivalent to (add x, c).
2887 // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
2888 // equivalent to (add x, c).
2889 // Do this optimization only when adding c does not introduce instructions
2890 // for adding carries.
2891 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2892 if (DAG.isADDLike(Op: N0) && N0.hasOneUse() &&
2893 isConstantOrConstantVector(N: N0.getOperand(i: 1), /* NoOpaque */ NoOpaques: true)) {
2894 // If N0's type does not get split during legalization, or its constant
2895 // operand is the sign mask, this does not introduce an add carry.
2896 auto TyActn = TLI.getTypeAction(Context&: *DAG.getContext(), VT: N0.getValueType());
2897 bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2898 TyActn == TargetLoweringBase::TypePromoteInteger ||
2899 isMinSignedConstant(V: N0.getOperand(i: 1));
2900 if (NoAddCarry)
2901 return DAG.getNode(
2902 Opcode: ISD::ADD, DL, VT,
2903 N1: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0.getOperand(i: 0)),
2904 N2: N0.getOperand(i: 1));
2905 }
2906 return SDValue();
2907 };
2908 if (SDValue Add = ReassociateAddOr(N0, N1))
2909 return Add;
2910 if (SDValue Add = ReassociateAddOr(N1, N0))
2911 return Add;
2912
2913 // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2914 if (SDValue SD =
2915 reassociateReduction(RedOpc: ISD::VECREDUCE_ADD, Opc: ISD::ADD, DL, VT, N0, N1))
2916 return SD;
2917 }
2918
2919 SDValue A, B, C, D;
2920
2921 // fold ((0-A) + B) -> B-A
2922 if (sd_match(N: N0, P: m_Neg(V: m_Value(N&: A))))
2923 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: A);
2924
2925 // fold (A + (0-B)) -> A-B
2926 if (sd_match(N: N1, P: m_Neg(V: m_Value(N&: B))))
2927 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: B);
2928
2929 // fold (A+(B-A)) -> B
2930 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0))))
2931 return B;
2932
2933 // fold ((B-A)+A) -> B
2934 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1))))
2935 return B;
2936
2937 // fold ((A-B)+(C-A)) -> (C-B)
2938 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
2939 sd_match(N: N1, P: m_Sub(L: m_Value(N&: C), R: m_Specific(N: A))))
2940 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B);
2941
2942 // fold ((A-B)+(B-C)) -> (A-C)
2943 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
2944 sd_match(N: N1, P: m_Sub(L: m_Specific(N: B), R: m_Value(N&: C))))
2945 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
2946
2947 // fold (A+(B-(A+C))) to (B-C)
2948 // fold (A+(B-(C+A))) to (B-C)
2949 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Add(L: m_Specific(N: N0), R: m_Value(N&: C)))))
2950 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: B, N2: C);
2951
2952 // fold (A+((B-A)+or-C)) to (B+or-C)
2953 if (sd_match(N: N1,
2954 P: m_AnyOf(preds: m_Add(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)),
2955 preds: m_Sub(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)))))
2956 return DAG.getNode(Opcode: N1.getOpcode(), DL, VT, N1: B, N2: C);
2957
2958 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2959 if (sd_match(N: N0, P: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B)))) &&
2960 sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: C), R: m_Value(N&: D)))) &&
2961 (isConstantOrConstantVector(N: A) || isConstantOrConstantVector(N: C)))
2962 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
2963 N1: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT, N1: A, N2: C),
2964 N2: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: B, N2: D));
2965
2966 // fold (add (umax X, C), -C) --> (usubsat X, C)
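// Why this holds: if X <=u C then (umax X, C) + (-C) == C - C == 0, otherwise
// it is X - C with no unsigned underflow; both results match usubsat(X, C).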
2967 if (N0.getOpcode() == ISD::UMAX && hasOperation(Opcode: ISD::USUBSAT, VT)) {
2968 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2969 return (!Max && !Op) ||
2970 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2971 };
2972 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchUSUBSAT,
2973 /*AllowUndefs*/ true))
2974 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: N0.getOperand(i: 0),
2975 N2: N0.getOperand(i: 1));
2976 }
2977
2978 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
2979 return SDValue(N, 0);
2980
2981 if (isOneOrOneSplat(V: N1)) {
2982 // fold (add (xor a, -1), 1) -> (sub 0, a)
2983 if (isBitwiseNot(V: N0))
2984 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: 0, DL, VT),
2985 N2: N0.getOperand(i: 0));
2986
2987 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2988 if (N0.getOpcode() == ISD::ADD) {
2989 SDValue A, Xor;
2990
2991 if (isBitwiseNot(V: N0.getOperand(i: 0))) {
2992 A = N0.getOperand(i: 1);
2993 Xor = N0.getOperand(i: 0);
2994 } else if (isBitwiseNot(V: N0.getOperand(i: 1))) {
2995 A = N0.getOperand(i: 0);
2996 Xor = N0.getOperand(i: 1);
2997 }
2998
2999 if (Xor)
3000 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: Xor.getOperand(i: 0));
3001 }
3002
3003 // Look for:
3004 // add (add x, y), 1
3005 // And if the target does not like this form then turn into:
3006 // sub y, (xor x, -1)
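// These are equivalent because (xor x, -1) == -x - 1, so
// sub y, (xor x, -1) == y + x + 1 == add (add x, y), 1.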
3007 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3008 N0.hasOneUse() &&
3009 // Limit this to after legalization if the add has wrap flags
3010 (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
3011 !N->getFlags().hasNoSignedWrap()))) {
3012 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
3013 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 1), N2: Not);
3014 }
3015 }
3016
3017 // (x - y) + -1 -> add (xor y, -1), x
3018 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3019 isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs=*/true)) {
3020 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 1), VT);
3021 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Not, N2: N0.getOperand(i: 0));
3022 }
3023
3024 // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
3025 // This can help if the inner add has multiple uses.
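// Why this holds: mul(add(A, CA), CM) + CB == A*CM + CA*CM + CB by
// distributing the multiply, hence the folded constant CM*CA+CB.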
3026 APInt CM, CA;
3027 if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(Val&: N1)) {
3028 if (VT.getScalarSizeInBits() <= 64) {
3029 if (sd_match(N: N0, P: m_OneUse(P: m_Mul(L: m_Add(L: m_Value(N&: A), R: m_ConstInt(V&: CA)),
3030 R: m_ConstInt(V&: CM)))) &&
3031 TLI.isLegalAddImmediate(
3032 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3033 SDNodeFlags Flags;
3034 // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3035 // are _also_ nsw the outputs can be too.
3036 if (N->getFlags().hasNoUnsignedWrap() &&
3037 N0->getFlags().hasNoUnsignedWrap() &&
3038 N0.getOperand(i: 0)->getFlags().hasNoUnsignedWrap()) {
3039 Flags |= SDNodeFlags::NoUnsignedWrap;
3040 if (N->getFlags().hasNoSignedWrap() &&
3041 N0->getFlags().hasNoSignedWrap() &&
3042 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap())
3043 Flags |= SDNodeFlags::NoSignedWrap;
3044 }
3045 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: A,
3046 N2: DAG.getConstant(Val: CM, DL, VT), Flags);
3047 return DAG.getNode(
3048 Opcode: ISD::ADD, DL, VT, N1: Mul,
3049 N2: DAG.getConstant(Val: CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3050 }
3051 // Also handle the case where there is an intermediate add.
3052 if (sd_match(N: N0, P: m_OneUse(P: m_Add(
3053 L: m_OneUse(P: m_Mul(L: m_Add(L: m_Value(N&: A), R: m_ConstInt(V&: CA)),
3054 R: m_ConstInt(V&: CM))),
3055 R: m_Value(N&: B)))) &&
3056 TLI.isLegalAddImmediate(
3057 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3058 SDNodeFlags Flags;
3059 // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3060 // are _also_ nsw the outputs can be too.
3061 SDValue OMul =
3062 N0.getOperand(i: 0) == B ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
3063 if (N->getFlags().hasNoUnsignedWrap() &&
3064 N0->getFlags().hasNoUnsignedWrap() &&
3065 OMul->getFlags().hasNoUnsignedWrap() &&
3066 OMul.getOperand(i: 0)->getFlags().hasNoUnsignedWrap()) {
3067 Flags |= SDNodeFlags::NoUnsignedWrap;
3068 if (N->getFlags().hasNoSignedWrap() &&
3069 N0->getFlags().hasNoSignedWrap() &&
3070 OMul->getFlags().hasNoSignedWrap() &&
3071 OMul.getOperand(i: 0)->getFlags().hasNoSignedWrap())
3072 Flags |= SDNodeFlags::NoSignedWrap;
3073 }
3074 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: A,
3075 N2: DAG.getConstant(Val: CM, DL, VT), Flags);
3076 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: Mul, N2: B, Flags);
3077 return DAG.getNode(
3078 Opcode: ISD::ADD, DL, VT, N1: Add,
3079 N2: DAG.getConstant(Val: CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3080 }
3081 }
3082 }
3083
3084 if (SDValue Combined = visitADDLikeCommutative(N0, N1, LocReference: N))
3085 return Combined;
3086
3087 if (SDValue Combined = visitADDLikeCommutative(N0: N1, N1: N0, LocReference: N))
3088 return Combined;
3089
3090 return SDValue();
3091}
3092
3093// Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
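// This works because A + B == 2*(A & B) + (A ^ B) (carry bits plus sum bits),
// so (A & B) + ((A ^ B) >> 1) computes the floor of the average without the
// intermediate overflow of a widened add; the signed form uses an arithmetic
// shift of the xor.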
3094SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
3095 SDValue N0 = N->getOperand(Num: 0);
3096 EVT VT = N0.getValueType();
3097 SDValue A, B;
3098
3099 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGFLOORU, VT)) &&
3100 sd_match(N, P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
3101 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One())))) {
3102 return DAG.getNode(Opcode: ISD::AVGFLOORU, DL, VT, N1: A, N2: B);
3103 }
3104 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGFLOORS, VT)) &&
3105 sd_match(N, P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
3106 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One())))) {
3107 return DAG.getNode(Opcode: ISD::AVGFLOORS, DL, VT, N1: A, N2: B);
3108 }
3109
3110 return SDValue();
3111}
3112
3113SDValue DAGCombiner::visitADD(SDNode *N) {
3114 SDValue N0 = N->getOperand(Num: 0);
3115 SDValue N1 = N->getOperand(Num: 1);
3116 EVT VT = N0.getValueType();
3117 SDLoc DL(N);
3118
3119 if (SDValue Combined = visitADDLike(N))
3120 return Combined;
3121
3122 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3123 return V;
3124
3125 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3126 return V;
3127
3128 if (SDValue V = MatchRotate(LHS: N0, RHS: N1, DL: SDLoc(N), /*FromAdd=*/true))
3129 return V;
3130
3131 // Try to match AVGFLOOR fixedwidth pattern
3132 if (SDValue V = foldAddToAvg(N, DL))
3133 return V;
3134
3135 // fold (a+b) -> (a|b) iff a and b share no bits.
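// e.g. 0xF0 + 0x0F == 0xF0 | 0x0F == 0xFF; with no common bits set no carries
// are produced, so the add and the or agree.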
3136 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
3137 DAG.haveNoCommonBitsSet(A: N0, B: N1))
3138 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags: SDNodeFlags::Disjoint);
3139
3140 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
3141 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
3142 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
3143 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
3144 return DAG.getVScale(DL, VT, MulImm: C0 + C1);
3145 }
3146
3147 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
3148 if (N0.getOpcode() == ISD::ADD &&
3149 N0.getOperand(i: 1).getOpcode() == ISD::VSCALE &&
3150 N1.getOpcode() == ISD::VSCALE) {
3151 const APInt &VS0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
3152 const APInt &VS1 = N1->getConstantOperandAPInt(Num: 0);
3153 SDValue VS = DAG.getVScale(DL, VT, MulImm: VS0 + VS1);
3154 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: VS);
3155 }
3156
3157 // Fold (add step_vector(c1), step_vector(c2) to step_vector(c1+c2))
3158 if (N0.getOpcode() == ISD::STEP_VECTOR &&
3159 N1.getOpcode() == ISD::STEP_VECTOR) {
3160 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
3161 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
3162 APInt NewStep = C0 + C1;
3163 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3164 }
3165
3166 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3167 if (N0.getOpcode() == ISD::ADD &&
3168 N0.getOperand(i: 1).getOpcode() == ISD::STEP_VECTOR &&
3169 N1.getOpcode() == ISD::STEP_VECTOR) {
3170 const APInt &SV0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
3171 const APInt &SV1 = N1->getConstantOperandAPInt(Num: 0);
3172 APInt NewStep = SV0 + SV1;
3173 SDValue SV = DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3174 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: SV);
3175 }
3176
3177 return SDValue();
3178}
3179
3180SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3181 unsigned Opcode = N->getOpcode();
3182 SDValue N0 = N->getOperand(Num: 0);
3183 SDValue N1 = N->getOperand(Num: 1);
3184 EVT VT = N0.getValueType();
3185 bool IsSigned = Opcode == ISD::SADDSAT;
3186 SDLoc DL(N);
3187
3188 // fold (add_sat x, undef) -> -1
3189 if (N0.isUndef() || N1.isUndef())
3190 return DAG.getAllOnesConstant(DL, VT);
3191
3192 // fold (add_sat c1, c2) -> c3
3193 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
3194 return C;
3195
3196 // canonicalize constant to RHS
3197 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3198 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3199 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
3200
3201 // fold vector ops
3202 if (VT.isVector()) {
3203 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3204 return FoldedVOp;
3205
3206 // fold (add_sat x, 0) -> x, vector edition
3207 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3208 return N0;
3209 }
3210
3211 // fold (add_sat x, 0) -> x
3212 if (isNullConstant(V: N1))
3213 return N0;
3214
3215 // If it cannot overflow, transform into an add.
3216 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3217 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1);
3218
3219 return SDValue();
3220}
3221
3222static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3223 bool ForceCarryReconstruction = false) {
3224 bool Masked = false;
3225
3226 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3227 while (true) {
3228 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3229 V = V.getOperand(i: 0);
3230 continue;
3231 }
3232
3233 if (V.getOpcode() == ISD::AND && isOneConstant(V: V.getOperand(i: 1))) {
3234 if (ForceCarryReconstruction)
3235 return V;
3236
3237 Masked = true;
3238 V = V.getOperand(i: 0);
3239 continue;
3240 }
3241
3242 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3243 return V;
3244
3245 break;
3246 }
3247
3248 // If this is not a carry, return.
3249 if (V.getResNo() != 1)
3250 return SDValue();
3251
3252 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3253 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3254 return SDValue();
3255
3256 EVT VT = V->getValueType(ResNo: 0);
3257 if (!TLI.isOperationLegalOrCustom(Op: V.getOpcode(), VT))
3258 return SDValue();
3259
3260 // If the result is masked, then no matter what kind of bool it is we can
3261 // return. If it isn't, then we need to make sure the bool type is either 0 or
3262 // 1 and not other values.
3263 if (Masked ||
3264 TLI.getBooleanContents(Type: V.getValueType()) ==
3265 TargetLoweringBase::ZeroOrOneBooleanContent)
3266 return V;
3267
3268 return SDValue();
3269}
3270
3271/// Given the operands of an add/sub operation, see if the 2nd operand is a
3272/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3273/// the opcode and bypass the mask operation.
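/// For example, if the masked source X is known to be 0 or -1 (all sign
/// bits), then (X & 1) is 0 or 1 and equals -X, so N0 + (X & 1) == N0 - X and
/// N0 - (X & 1) == N0 + X.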
3274static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3275 SelectionDAG &DAG, const SDLoc &DL) {
3276 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3277 N1 = N1.getOperand(i: 0);
3278
3279 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(V: N1->getOperand(Num: 1)))
3280 return SDValue();
3281
3282 EVT VT = N0.getValueType();
3283 SDValue N10 = N1.getOperand(i: 0);
3284 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3285 N10 = N10.getOperand(i: 0);
3286
3287 if (N10.getValueType() != VT)
3288 return SDValue();
3289
3290 if (DAG.ComputeNumSignBits(Op: N10) != VT.getScalarSizeInBits())
3291 return SDValue();
3292
3293 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3294 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
3295 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: N0, N2: N10);
3296}
3297
3298/// Helper for doing combines based on N0 and N1 being added to each other.
3299SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3300 SDNode *LocReference) {
3301 EVT VT = N0.getValueType();
3302 SDLoc DL(LocReference);
3303
3304 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3305 SDValue Y, N;
3306 if (sd_match(N: N1, P: m_Shl(L: m_Neg(V: m_Value(N&: Y)), R: m_Value(N))))
3307 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0,
3308 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: N));
3309
3310 if (SDValue V = foldAddSubMasked1(IsAdd: true, N0, N1, DAG, DL))
3311 return V;
3312
3313 // Look for:
3314 // add (add x, 1), y
3315 // And if the target does not like this form then turn into:
3316 // sub y, (xor x, -1)
3317 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3318 N0.hasOneUse() && isOneOrOneSplat(V: N0.getOperand(i: 1)) &&
3319 // Limit this to after legalization if the add has wrap flags
3320 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3321 !N0->getFlags().hasNoSignedWrap()))) {
3322 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
3323 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: Not);
3324 }
3325
3326 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3327 // Hoist one-use subtraction by non-opaque constant:
3328 // (x - C) + y -> (x + y) - C
3329 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3330 if (isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3331 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3332 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
3333 }
3334 // Hoist one-use subtraction from non-opaque constant:
3335 // (C - x) + y -> (y - x) + C
3336 if (isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
3337 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: N0.getOperand(i: 1));
3338 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 0));
3339 }
3340 }
3341
3342 // add (mul x, C), x -> mul x, C+1
3343 if (N0.getOpcode() == ISD::MUL && N0.getOperand(i: 0) == N1 &&
3344 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true) &&
3345 N0.hasOneUse()) {
3346 SDValue NewC = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
3347 N2: DAG.getConstant(Val: 1, DL, VT));
3348 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3349 }
3350
3351 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3352 // rather than 'add 0/-1' (the zext should get folded).
3353 // add (sext i1 Y), X --> sub X, (zext i1 Y)
3354 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3355 N0.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
3356 TLI.getBooleanContents(Type: VT) == TargetLowering::ZeroOrOneBooleanContent) {
3357 SDValue ZExt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
3358 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: ZExt);
3359 }
3360
3361 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3362 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3363 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
3364 if (TN->getVT() == MVT::i1) {
3365 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
3366 N2: DAG.getConstant(Val: 1, DL, VT));
3367 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: ZExt);
3368 }
3369 }
3370
3371 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3372 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1)) &&
3373 N1.getResNo() == 0)
3374 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N1->getVTList(),
3375 N1: N0, N2: N1.getOperand(i: 0), N3: N1.getOperand(i: 2));
3376
3377 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3378 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3379 if (SDValue Carry = getAsCarry(TLI, V: N1))
3380 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
3381 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: N0,
3382 N2: DAG.getConstant(Val: 0, DL, VT), N3: Carry);
3383
3384 return SDValue();
3385}
3386
3387SDValue DAGCombiner::visitADDC(SDNode *N) {
3388 SDValue N0 = N->getOperand(Num: 0);
3389 SDValue N1 = N->getOperand(Num: 1);
3390 EVT VT = N0.getValueType();
3391 SDLoc DL(N);
3392
3393 // If the flag result is dead, turn this into an ADD.
3394 if (!N->hasAnyUseOfValue(Value: 1))
3395 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3396 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
3397
3398 // canonicalize constant to RHS.
3399 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3400 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3401 if (N0C && !N1C)
3402 return DAG.getNode(Opcode: ISD::ADDC, DL, VTList: N->getVTList(), N1, N2: N0);
3403
3404 // fold (addc x, 0) -> x + no carry out
3405 if (isNullConstant(V: N1))
3406 return CombineTo(N, Res0: N0, Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE,
3407 DL, VT: MVT::Glue));
3408
3409 // If it cannot overflow, transform into an add.
3410 if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3411 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3412 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
3413
3414 return SDValue();
3415}
3416
3417/**
3418 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
3419 * then the flip also occurs if computing the inverse is the same cost.
3420 * This function returns an empty SDValue in case it cannot flip the boolean
3421 * without increasing the cost of the computation. If you want to flip a boolean
3422 * no matter what, use DAG.getLogicalNOT.
3423 */
3424static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3425 const TargetLowering &TLI,
3426 bool Force) {
3427 if (Force && isa<ConstantSDNode>(Val: V))
3428 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3429
3430 if (V.getOpcode() != ISD::XOR)
3431 return SDValue();
3432
3433 if (DAG.isBoolConstant(N: V.getOperand(i: 1)) == true)
3434 return V.getOperand(i: 0);
3435 if (Force && isConstOrConstSplat(N: V.getOperand(i: 1), AllowUndefs: false))
3436 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3437 return SDValue();
3438}
3439
3440SDValue DAGCombiner::visitADDO(SDNode *N) {
3441 SDValue N0 = N->getOperand(Num: 0);
3442 SDValue N1 = N->getOperand(Num: 1);
3443 EVT VT = N0.getValueType();
3444 bool IsSigned = (ISD::SADDO == N->getOpcode());
3445
3446 EVT CarryVT = N->getValueType(ResNo: 1);
3447 SDLoc DL(N);
3448
3449 // If the flag result is dead, turn this into an ADD.
3450 if (!N->hasAnyUseOfValue(Value: 1))
3451 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3452 Res1: DAG.getUNDEF(VT: CarryVT));
3453
3454 // canonicalize constant to RHS.
3455 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3456 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3457 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
3458
3459 // fold (addo x, 0) -> x + no carry out
3460 if (isNullOrNullSplat(V: N1))
3461 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3462
3463 // If it cannot overflow, transform into an add.
3464 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3465 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3466 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3467
3468 if (IsSigned) {
3469 // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3470 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1))
3471 return DAG.getNode(Opcode: ISD::SSUBO, DL, VTList: N->getVTList(),
3472 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3473 } else {
3474 // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
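// This works because (xor a, -1) + 1 == -a == 0 - a; the uaddo carries only
// when a == 0 while the usubo borrows only when a != 0, hence the carry flip.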
3475 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1)) {
3476 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO, DL, VTList: N->getVTList(),
3477 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3478 return CombineTo(
3479 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3480 }
3481
3482 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3483 return Combined;
3484
3485 if (SDValue Combined = visitUADDOLike(N0: N1, N1: N0, N))
3486 return Combined;
3487 }
3488
3489 return SDValue();
3490}
3491
3492SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3493 EVT VT = N0.getValueType();
3494 if (VT.isVector())
3495 return SDValue();
3496
3497 // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3498 // If Y + 1 cannot overflow.
3499 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1))) {
3500 SDValue Y = N1.getOperand(i: 0);
3501 SDValue One = DAG.getConstant(Val: 1, DL: SDLoc(N), VT: Y.getValueType());
3502 if (DAG.computeOverflowForUnsignedAdd(N0: Y, N1: One) == SelectionDAG::OFK_Never)
3503 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: Y,
3504 N3: N1.getOperand(i: 2));
3505 }
3506
3507 // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3508 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3509 if (SDValue Carry = getAsCarry(TLI, V: N1))
3510 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0,
3511 N2: DAG.getConstant(Val: 0, DL: SDLoc(N), VT), N3: Carry);
3512
3513 return SDValue();
3514}
3515
3516SDValue DAGCombiner::visitADDE(SDNode *N) {
3517 SDValue N0 = N->getOperand(Num: 0);
3518 SDValue N1 = N->getOperand(Num: 1);
3519 SDValue CarryIn = N->getOperand(Num: 2);
3520
3521 // canonicalize constant to RHS
3522 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3523 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3524 if (N0C && !N1C)
3525 return DAG.getNode(Opcode: ISD::ADDE, DL: SDLoc(N), VTList: N->getVTList(),
3526 N1, N2: N0, N3: CarryIn);
3527
3528 // fold (adde x, y, false) -> (addc x, y)
3529 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3530 return DAG.getNode(Opcode: ISD::ADDC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
3531
3532 return SDValue();
3533}
3534
3535SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3536 SDValue N0 = N->getOperand(Num: 0);
3537 SDValue N1 = N->getOperand(Num: 1);
3538 SDValue CarryIn = N->getOperand(Num: 2);
3539 SDLoc DL(N);
3540
3541 // canonicalize constant to RHS
3542 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3543 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3544 if (N0C && !N1C)
3545 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3546
3547 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3548 if (isNullConstant(V: CarryIn)) {
3549 if (!LegalOperations ||
3550 TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT: N->getValueType(ResNo: 0)))
3551 return DAG.getNode(Opcode: ISD::UADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3552 }
3553
3554 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3555 if (isNullConstant(V: N0) && isNullConstant(V: N1)) {
3556 EVT VT = N0.getValueType();
3557 EVT CarryVT = CarryIn.getValueType();
3558 SDValue CarryExt = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT, OpVT: CarryVT);
3559 AddToWorklist(N: CarryExt.getNode());
3560 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CarryExt,
3561 N2: DAG.getConstant(Val: 1, DL, VT)),
3562 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3563 }
3564
3565 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3566 return Combined;
3567
3568 if (SDValue Combined = visitUADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3569 return Combined;
3570
3571 // We want to avoid useless duplication.
3572 // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3573 // not a binary operation, it is not really possible to leverage this existing
3574 // mechanism for it. However, if more operations require the same
3575 // deduplication logic, then it may be worth generalizing.
3576 SDValue Ops[] = {N1, N0, CarryIn};
3577 SDNode *CSENode =
3578 DAG.getNodeIfExists(Opcode: ISD::UADDO_CARRY, VTList: N->getVTList(), Ops, Flags: N->getFlags());
3579 if (CSENode)
3580 return SDValue(CSENode, 0);
3581
3582 return SDValue();
3583}
3584
3585/**
3586 * If we are facing some sort of diamond carry propagation pattern try to
3587 * break it up to generate something like:
3588 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3589 *
3590 * The end result is usually an increase in the number of operations required,
3591 * but because the carry is now linearized, other transforms can kick in and optimize the DAG.
3592 *
3593 * Patterns typically look something like
3594 * (uaddo A, B)
3595 * / \
3596 * Carry Sum
3597 * | \
3598 * | (uaddo_carry *, 0, Z)
3599 * | /
3600 * \ Carry
3601 * | /
3602 * (uaddo_carry X, *, *)
3603 *
3604 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3605 * produce a combine with a single path for carry propagation.
3606 */
3607static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3608 SelectionDAG &DAG, SDValue X,
3609 SDValue Carry0, SDValue Carry1,
3610 SDNode *N) {
3611 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3612 return SDValue();
3613 if (Carry1.getOpcode() != ISD::UADDO)
3614 return SDValue();
3615
3616 SDValue Z;
3617
3618 /**
3619 * First look for a suitable Z. It will present itself in the form of
3620 * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3621 */
3622 if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3623 isNullConstant(V: Carry0.getOperand(i: 1))) {
3624 Z = Carry0.getOperand(i: 2);
3625 } else if (Carry0.getOpcode() == ISD::UADDO &&
3626 isOneConstant(V: Carry0.getOperand(i: 1))) {
3627 EVT VT = Carry0->getValueType(ResNo: 1);
3628 Z = DAG.getConstant(Val: 1, DL: SDLoc(Carry0.getOperand(i: 1)), VT);
3629 } else {
3630 // We couldn't find a suitable Z.
3631 return SDValue();
3632 }
3633
3634
3635 auto cancelDiamond = [&](SDValue A,SDValue B) {
3636 SDLoc DL(N);
3637 SDValue NewY =
3638 DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: Carry0->getVTList(), N1: A, N2: B, N3: Z);
3639 Combiner.AddToWorklist(N: NewY.getNode());
3640 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1: X,
3641 N2: DAG.getConstant(Val: 0, DL, VT: X.getValueType()),
3642 N3: NewY.getValue(R: 1));
3643 };
3644
3645 /**
3646 * (uaddo A, B)
3647 * |
3648 * Sum
3649 * |
3650 * (uaddo_carry *, 0, Z)
3651 */
3652 if (Carry0.getOperand(i: 0) == Carry1.getValue(R: 0)) {
3653 return cancelDiamond(Carry1.getOperand(i: 0), Carry1.getOperand(i: 1));
3654 }
3655
3656 /**
3657 * (uaddo_carry A, 0, Z)
3658 * |
3659 * Sum
3660 * |
3661 * (uaddo *, B)
3662 */
3663 if (Carry1.getOperand(i: 0) == Carry0.getValue(R: 0)) {
3664 return cancelDiamond(Carry0.getOperand(i: 0), Carry1.getOperand(i: 1));
3665 }
3666
3667 if (Carry1.getOperand(i: 1) == Carry0.getValue(R: 0)) {
3668 return cancelDiamond(Carry1.getOperand(i: 0), Carry0.getOperand(i: 0));
3669 }
3670
3671 return SDValue();
3672}
3673
3674// If we are facing some sort of diamond carry/borrow in/out pattern try to
3675// match patterns like:
3676//
3677// (uaddo A, B) CarryIn
3678// | \ |
3679// | \ |
3680// PartialSum PartialCarryOutX /
3681// | | /
3682// | ____|____________/
3683// | / |
3684// (uaddo *, *) \________
3685// | \ \
3686// | \ |
3687// | PartialCarryOutY |
3688// | \ |
3689// | \ /
3690// AddCarrySum | ______/
3691// | /
3692// CarryOut = (or *, *)
3693//
3694// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3695//
3696// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3697//
3698// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3699// with a single path for carry/borrow out propagation.
3700static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3701 SDValue N0, SDValue N1, SDNode *N) {
3702 SDValue Carry0 = getAsCarry(TLI, V: N0);
3703 if (!Carry0)
3704 return SDValue();
3705 SDValue Carry1 = getAsCarry(TLI, V: N1);
3706 if (!Carry1)
3707 return SDValue();
3708
3709 unsigned Opcode = Carry0.getOpcode();
3710 if (Opcode != Carry1.getOpcode())
3711 return SDValue();
3712 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3713 return SDValue();
3714 // Guarantee identical type of CarryOut
3715 EVT CarryOutType = N->getValueType(ResNo: 0);
3716 if (CarryOutType != Carry0.getValue(R: 1).getValueType() ||
3717 CarryOutType != Carry1.getValue(R: 1).getValueType())
3718 return SDValue();
3719
3720 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3721 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3722 if (Carry1.getNode()->isOperandOf(N: Carry0.getNode()))
3723 std::swap(a&: Carry0, b&: Carry1);
3724
3725 // Check if nodes are connected in expected way.
3726 if (Carry1.getOperand(i: 0) != Carry0.getValue(R: 0) &&
3727 Carry1.getOperand(i: 1) != Carry0.getValue(R: 0))
3728 return SDValue();
3729
3730 // The carry-in value must be on the right-hand side for subtraction.
3731 unsigned CarryInOperandNum =
3732 Carry1.getOperand(i: 0) == Carry0.getValue(R: 0) ? 1 : 0;
3733 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3734 return SDValue();
3735 SDValue CarryIn = Carry1.getOperand(i: CarryInOperandNum);
3736
3737 unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3738 if (!TLI.isOperationLegalOrCustom(Op: NewOp, VT: Carry0.getValue(R: 0).getValueType()))
3739 return SDValue();
3740
3741 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3742 CarryIn = getAsCarry(TLI, V: CarryIn, ForceCarryReconstruction: true);
3743 if (!CarryIn)
3744 return SDValue();
3745
3746 SDLoc DL(N);
3747 CarryIn = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT: Carry1->getValueType(ResNo: 1),
3748 OpVT: Carry1->getValueType(ResNo: 0));
3749 SDValue Merged =
3750 DAG.getNode(Opcode: NewOp, DL, VTList: Carry1->getVTList(), N1: Carry0.getOperand(i: 0),
3751 N2: Carry0.getOperand(i: 1), N3: CarryIn);
3752
3753 // Note that because we have proven that the result of the UADDO/USUBO of A
3754 // and B feeds into the UADDO/USUBO that consumes the carry/borrow in, it
3755 // follows that if the first UADDO/USUBO overflows, the second UADDO/USUBO
3756 // cannot. For example, consider 8-bit numbers where 0xFF is the
3757 // maximum value.
3758 //
3759 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3760 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3761 //
3762 // This is important because it means that OR and XOR can be used to merge
3763 // carry flags; and that AND can return a constant zero.
3764 //
3765 // TODO: match other operations that can merge flags (ADD, etc)
3766 DAG.ReplaceAllUsesOfValueWith(From: Carry1.getValue(R: 0), To: Merged.getValue(R: 0));
3767 if (N->getOpcode() == ISD::AND)
3768 return DAG.getConstant(Val: 0, DL, VT: CarryOutType);
3769 return Merged.getValue(R: 1);
3770}
3771
3772SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3773 SDValue CarryIn, SDNode *N) {
3774 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3775 // carry.
3776 if (isBitwiseNot(V: N0))
3777 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true)) {
3778 SDLoc DL(N);
3779 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N->getVTList(), N1,
3780 N2: N0.getOperand(i: 0), N3: NotC);
3781 return CombineTo(
3782 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3783 }
3784
3785 // Iff the flag result is dead:
3786 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3787 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3788 // or the dependency between the instructions.
3789 if ((N0.getOpcode() == ISD::ADD ||
3790 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3791 N0.getValue(R: 1) != CarryIn)) &&
3792 isNullConstant(V: N1) && !N->hasAnyUseOfValue(Value: 1))
3793 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(),
3794 N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1), N3: CarryIn);
3795
3796 /**
3797 * When one of the uaddo_carry arguments is itself a carry, we may be facing
3798 * a diamond carry propagation pattern, in which case we try to transform the
3799 * DAG to ensure linear carry propagation if that is possible.
3800 */
3801 if (auto Y = getAsCarry(TLI, V: N1)) {
3802 // Because both are carries, Y and Z can be swapped.
3803 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: Y, Carry1: CarryIn, N))
3804 return R;
3805 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: CarryIn, Carry1: Y, N))
3806 return R;
3807 }
3808
3809 return SDValue();
3810}
3811
3812SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3813 SDValue CarryIn, SDNode *N) {
3814 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3815 if (isBitwiseNot(V: N0)) {
3816 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true))
3817 return DAG.getNode(Opcode: ISD::SSUBO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1,
3818 N2: N0.getOperand(i: 0), N3: NotC);
3819 }
3820
3821 return SDValue();
3822}
3823
3824SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3825 SDValue N0 = N->getOperand(Num: 0);
3826 SDValue N1 = N->getOperand(Num: 1);
3827 SDValue CarryIn = N->getOperand(Num: 2);
3828 SDLoc DL(N);
3829
3830 // canonicalize constant to RHS
3831 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3832 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3833 if (N0C && !N1C)
3834 return DAG.getNode(Opcode: ISD::SADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3835
3836 // fold (saddo_carry x, y, false) -> (saddo x, y)
3837 if (isNullConstant(V: CarryIn)) {
3838 if (!LegalOperations ||
3839 TLI.isOperationLegalOrCustom(Op: ISD::SADDO, VT: N->getValueType(ResNo: 0)))
3840 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3841 }
3842
3843 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3844 return Combined;
3845
3846 if (SDValue Combined = visitSADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3847 return Combined;
3848
3849 return SDValue();
3850}
3851
3852// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3853// clamp/truncation if necessary.
3854static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3855 SDValue RHS, SelectionDAG &DAG,
3856 const SDLoc &DL) {
3857 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3858 "Illegal truncation");
3859
3860 if (DstVT == SrcVT)
3861 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3862
3863 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3864 // clamping RHS.
3865 APInt UpperBits = APInt::getBitsSetFrom(numBits: SrcVT.getScalarSizeInBits(),
3866 loBit: DstVT.getScalarSizeInBits());
3867 if (!DAG.MaskedValueIsZero(Op: LHS, Mask: UpperBits))
3868 return SDValue();
3869
3870 SDValue SatLimit =
3871 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: SrcVT.getScalarSizeInBits(),
3872 loBitsSet: DstVT.getScalarSizeInBits()),
3873 DL, VT: SrcVT);
3874 RHS = DAG.getNode(Opcode: ISD::UMIN, DL, VT: SrcVT, N1: RHS, N2: SatLimit);
3875 RHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: RHS);
3876 LHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: LHS);
3877 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3878}
3879
3880// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3881// usubsat(a,b), optionally as a truncated type.
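// Why this holds: umax(a,b) - b is a - b when a >=u b and 0 otherwise, which
// is exactly usubsat(a,b); a - umin(a,b) behaves the same way.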
3882SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3883 if (N->getOpcode() != ISD::SUB ||
3884 !(!LegalOperations || hasOperation(Opcode: ISD::USUBSAT, VT: DstVT)))
3885 return SDValue();
3886
3887 EVT SubVT = N->getValueType(ResNo: 0);
3888 SDValue Op0 = N->getOperand(Num: 0);
3889 SDValue Op1 = N->getOperand(Num: 1);
3890
3891 // Try to find umax(a,b) - b or a - umin(a,b) patterns
3892 // that may be converted to usubsat(a,b).
3893 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3894 SDValue MaxLHS = Op0.getOperand(i: 0);
3895 SDValue MaxRHS = Op0.getOperand(i: 1);
3896 if (MaxLHS == Op1)
3897 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxRHS, RHS: Op1, DAG, DL);
3898 if (MaxRHS == Op1)
3899 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxLHS, RHS: Op1, DAG, DL);
3900 }
3901
3902 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3903 SDValue MinLHS = Op1.getOperand(i: 0);
3904 SDValue MinRHS = Op1.getOperand(i: 1);
3905 if (MinLHS == Op0)
3906 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinRHS, DAG, DL);
3907 if (MinRHS == Op0)
3908 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinLHS, DAG, DL);
3909 }
3910
3911 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3912 if (Op1.getOpcode() == ISD::TRUNCATE &&
3913 Op1.getOperand(i: 0).getOpcode() == ISD::UMIN &&
3914 Op1.getOperand(i: 0).hasOneUse()) {
3915 SDValue MinLHS = Op1.getOperand(i: 0).getOperand(i: 0);
3916 SDValue MinRHS = Op1.getOperand(i: 0).getOperand(i: 1);
3917 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(i: 0) == Op0)
3918 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinLHS, RHS: MinRHS,
3919 DAG, DL);
3920 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(i: 0) == Op0)
3921 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinRHS, RHS: MinLHS,
3922 DAG, DL);
3923 }
3924
3925 return SDValue();
3926}
3927
3928// Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
3929 // counting leading ones. Broadly, it replaces the subtraction with a left
3930// shift.
3931//
3932// * DAG Legalisation Pattern:
3933//
3934// (sub (ctlz (zeroextend (not Src)))
3935// BitWidthDiff)
3936//
3937// if BitWidthDiff == BitWidth(Node) - BitWidth(Src)
3938// -->
3939//
3940// (ctlz_zero_undef (not (shl (anyextend Src)
3941// BitWidthDiff)))
3942//
3943// * Type Legalisation Pattern:
3944//
3945// (sub (ctlz (and (xor Src XorMask)
3946// AndMask))
3947// BitWidthDiff)
3948//
3949// if AndMask has only trailing ones
3950// and MaskBitWidth(AndMask) == BitWidth(Node) - BitWidthDiff
3951// and XorMask has more trailing ones than AndMask
3952// -->
3953//
3954// (ctlz_zero_undef (not (shl Src BitWidthDiff)))
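//
// Worked example (assuming an i8 Src promoted to an i32 node, BitWidthDiff
// == 24): for Src = 0b11100101, ctlz(zext(not Src)) - 24 == 27 - 24 == 3, and
// ctlz_zero_undef(not (shl (anyext Src), 24)) == ctlz(0x1AFFFFFF) == 3;
// both count the 3 leading ones of Src.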
3955template <class MatchContextClass>
3956static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
3957 const SDLoc DL(N);
3958 SDValue N0 = N->getOperand(Num: 0);
3959 EVT VT = N0.getValueType();
3960 unsigned BitWidth = VT.getScalarSizeInBits();
3961
3962 MatchContextClass Matcher(DAG, DAG.getTargetLoweringInfo(), N);
3963
3964 APInt AndMask;
3965 APInt XorMask;
3966 APInt BitWidthDiff;
3967
3968 SDValue CtlzOp;
3969 SDValue Src;
3970
3971 if (!sd_context_match(
3972 N, Matcher, m_Sub(L: m_Ctlz(Op: m_Value(N&: CtlzOp)), R: m_ConstInt(V&: BitWidthDiff))))
3973 return SDValue();
3974
3975 if (sd_context_match(CtlzOp, Matcher, m_ZExt(Op: m_Not(V: m_Value(N&: Src))))) {
3976 // DAG Legalisation Pattern:
3977 // (sub (ctlz (zero_extend (not Op)) BitWidthDiff))
3978 if ((BitWidth - Src.getValueType().getScalarSizeInBits()) != BitWidthDiff)
3979 return SDValue();
3980
3981 Src = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: Src);
3982 } else if (sd_context_match(CtlzOp, Matcher,
3983 m_And(L: m_Xor(L: m_Value(N&: Src), R: m_ConstInt(V&: XorMask)),
3984 R: m_ConstInt(V&: AndMask)))) {
3985 // Type Legalisation Pattern:
3986 // (sub (ctlz (and (xor Op XorMask) AndMask)) BitWidthDiff)
3987 unsigned AndMaskWidth = BitWidth - BitWidthDiff.getZExtValue();
3988 if (!(AndMask.isMask(numBits: AndMaskWidth) && XorMask.countr_one() >= AndMaskWidth))
3989 return SDValue();
3990 } else
3991 return SDValue();
3992
3993 SDValue ShiftConst = DAG.getShiftAmountConstant(Val: BitWidthDiff, VT, DL);
3994 SDValue LShift = Matcher.getNode(ISD::SHL, DL, VT, Src, ShiftConst);
3995 SDValue Not =
3996 Matcher.getNode(ISD::XOR, DL, VT, LShift, DAG.getAllOnesConstant(DL, VT));
3997
3998 return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
3999}
4000
4001// Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
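// This works because mul(divrem(x,y)[0], y) == (x / y) * y, so
// x - (x / y) * y == x % y, which is precisely divrem(x,y)[1].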
4002static SDValue foldRemainderIdiom(SDNode *N, SelectionDAG &DAG,
4003 const SDLoc &DL) {
4004 assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
4005 SDValue Sub0 = N->getOperand(Num: 0);
4006 SDValue Sub1 = N->getOperand(Num: 1);
4007
4008 auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
4009 if ((DivRem.getOpcode() == ISD::SDIVREM ||
4010 DivRem.getOpcode() == ISD::UDIVREM) &&
4011 DivRem.getResNo() == 0 && DivRem.getOperand(i: 0) == Sub0 &&
4012 DivRem.getOperand(i: 1) == MaybeY) {
4013 return SDValue(DivRem.getNode(), 1);
4014 }
4015 return SDValue();
4016 };
4017
4018 if (Sub1.getOpcode() == ISD::MUL) {
4019 // (sub x, (mul divrem(x,y)[0], y))
4020 SDValue Mul0 = Sub1.getOperand(i: 0);
4021 SDValue Mul1 = Sub1.getOperand(i: 1);
4022
4023 if (SDValue Res = CheckAndFoldMulCase(Mul0, Mul1))
4024 return Res;
4025
4026 if (SDValue Res = CheckAndFoldMulCase(Mul1, Mul0))
4027 return Res;
4028
4029 } else if (Sub1.getOpcode() == ISD::SHL) {
4030 // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
4031 SDValue Shl0 = Sub1.getOperand(i: 0);
4032 SDValue Shl1 = Sub1.getOperand(i: 1);
4033 // Check if Shl0 is divrem(x, Y)[0]
4034 if ((Shl0.getOpcode() == ISD::SDIVREM ||
4035 Shl0.getOpcode() == ISD::UDIVREM) &&
4036 Shl0.getResNo() == 0 && Shl0.getOperand(i: 0) == Sub0) {
4037
4038 SDValue Divisor = Shl0.getOperand(i: 1);
4039
4040 ConstantSDNode *DivC = isConstOrConstSplat(N: Divisor);
4041 ConstantSDNode *ShC = isConstOrConstSplat(N: Shl1);
4042 if (!DivC || !ShC)
4043 return SDValue();
4044
4045 if (DivC->getAPIntValue().isPowerOf2() &&
4046 DivC->getAPIntValue().logBase2() == ShC->getAPIntValue())
4047 return SDValue(Shl0.getNode(), 1);
4048 }
4049 }
4050 return SDValue();
4051}
4052
4053 // Since it may not be valid to emit a fold to zero for vector initializers,
4054// check if we can before folding.
4055static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
4056 SelectionDAG &DAG, bool LegalOperations) {
4057 if (!VT.isVector())
4058 return DAG.getConstant(Val: 0, DL, VT);
4059 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT))
4060 return DAG.getConstant(Val: 0, DL, VT);
4061 return SDValue();
4062}
4063
4064SDValue DAGCombiner::visitSUB(SDNode *N) {
4065 SDValue N0 = N->getOperand(Num: 0);
4066 SDValue N1 = N->getOperand(Num: 1);
4067 EVT VT = N0.getValueType();
4068 unsigned BitWidth = VT.getScalarSizeInBits();
4069 SDLoc DL(N);
4070
4071 auto PeekThroughFreeze = [](SDValue N) {
4072 if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
4073 return N->getOperand(Num: 0);
4074 return N;
4075 };
4076
4077 if (SDValue V = foldSubCtlzNot<EmptyMatchContext>(N, DAG))
4078 return V;
4079
4080 // fold (sub x, x) -> 0
4081 // FIXME: Refactor this and xor and other similar operations together.
4082 if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
4083 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4084
4085 // fold (sub c1, c2) -> c3
4086 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N1}))
4087 return C;
4088
4089 // fold vector ops
4090 if (VT.isVector()) {
4091 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4092 return FoldedVOp;
4093
4094 // fold (sub x, 0) -> x, vector edition
4095 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4096 return N0;
4097 }
4098
4099 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4100 return NewSel;
4101
4102 // fold (sub x, c) -> (add x, -c)
4103 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
4104 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4105 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4106
4107 if (isNullOrNullSplat(V: N0)) {
4108 // Right-shifting everything out but the sign bit followed by negation is
4109 // the same as flipping arithmetic/logical shift type without the negation:
4110 // -(X >>u 31) -> (X >>s 31)
4111 // -(X >>s 31) -> (X >>u 31)
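// This works because (X >>u 31) is the sign bit as 0 or 1, and its negation
// 0 or -1 is exactly (X >>s 31); the reverse direction works the same way.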
4112 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
4113 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N: N1.getOperand(i: 1));
4114 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
4115 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
4116 if (!LegalOperations || TLI.isOperationLegal(Op: NewSh, VT))
4117 return DAG.getNode(Opcode: NewSh, DL, VT, N1: N1.getOperand(i: 0), N2: N1.getOperand(i: 1));
4118 }
4119 }
4120
4121 // 0 - X --> 0 if the sub is NUW.
4122 if (N->getFlags().hasNoUnsignedWrap())
4123 return N0;
4124
4125 if (DAG.MaskedValueIsZero(Op: N1, Mask: ~APInt::getSignMask(BitWidth))) {
4126 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
4127 // N1 must be 0 because negating the minimum signed value is undefined.
4128 if (N->getFlags().hasNoSignedWrap())
4129 return N0;
4130
4131 // 0 - X --> X if X is 0 or the minimum signed value.
4132 return N1;
4133 }
4134
4135 // Convert 0 - abs(x).
4136 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
4137 !TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
4138 if (SDValue Result = TLI.expandABS(N: N1.getNode(), DAG, IsNegative: true))
4139 return Result;
4140
4141 // Similar to the previous rule, but this time targeting an expanded abs.
4142 // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
4143 // as well as
4144 // (sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))
4145 // Note that these two are applicable to both signed and unsigned min/max.
4146 SDValue X;
4147 SDValue S0;
4148 auto NegPat = m_AllOf(preds: m_Neg(V: m_Deferred(V&: X)), preds: m_Value(N&: S0));
4149 if (sd_match(N: N1, P: m_OneUse(P: m_AnyOf(preds: m_SMax(L: m_Value(N&: X), R: NegPat),
4150 preds: m_UMax(L: m_Value(N&: X), R: NegPat),
4151 preds: m_SMin(L: m_Value(N&: X), R: NegPat),
4152 preds: m_UMin(L: m_Value(N&: X), R: NegPat))))) {
4153 unsigned NewOpc = ISD::getInverseMinMaxOpcode(MinMaxOpc: N1->getOpcode());
4154 if (hasOperation(Opcode: NewOpc, VT))
4155 return DAG.getNode(Opcode: NewOpc, DL, VT, N1: X, N2: S0);
4156 }
4157
4158 // Fold neg(splat(neg(x))) -> splat(x)
4159 if (VT.isVector()) {
4160 SDValue N1S = DAG.getSplatValue(V: N1, LegalTypes: true);
4161 if (N1S && N1S.getOpcode() == ISD::SUB &&
4162 isNullConstant(V: N1S.getOperand(i: 0)))
4163 return DAG.getSplat(VT, DL, Op: N1S.getOperand(i: 1));
4164 }
4165
4166 // sub 0, (and x, 1) --> SIGN_EXTEND_INREG x, i1
4167 if (N1.getOpcode() == ISD::AND && N1.hasOneUse() &&
4168 isOneOrOneSplat(V: N1->getOperand(Num: 1))) {
4169 EVT ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: 1);
4170 if (VT.isVector())
4171 ExtVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: ExtVT,
4172 EC: VT.getVectorElementCount());
4173 if (TLI.getOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: ExtVT) ==
4174 TargetLowering::Legal) {
4175 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: N1->getOperand(Num: 0),
4176 N2: DAG.getValueType(ExtVT));
4177 }
4178 }
4179 }
4180
4181 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
4182 if (isAllOnesOrAllOnesSplat(V: N0))
4183 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4184
4185 // fold (A - (0-B)) -> A+B
4186 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N1.getOperand(i: 0)))
4187 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
4188
4189 // fold A-(A-B) -> B
4190 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(i: 0))
4191 return N1.getOperand(i: 1);
4192
4193 // fold (A+B)-A -> B
4194 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N1)
4195 return N0.getOperand(i: 1);
4196
4197 // fold (A+B)-B -> A
4198 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 1) == N1)
4199 return N0.getOperand(i: 0);
4200
4201 // fold (A+C1)-C2 -> A+(C1-C2)
4202 if (N0.getOpcode() == ISD::ADD) {
4203 SDValue N01 = N0.getOperand(i: 1);
4204 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N01, N1}))
4205 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
4206 }
4207
4208 // fold C2-(A+C1) -> (C2-C1)-A
4209 if (N1.getOpcode() == ISD::ADD) {
4210 SDValue N11 = N1.getOperand(i: 1);
4211 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N11}))
4212 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N1.getOperand(i: 0));
4213 }
4214
4215 // fold (A-C1)-C2 -> A-(C1+C2)
4216 if (N0.getOpcode() == ISD::SUB) {
4217 SDValue N01 = N0.getOperand(i: 1);
4218 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N01, N1}))
4219 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
4220 }
4221
4222 // fold (c1-A)-c2 -> (c1-c2)-A
4223 if (N0.getOpcode() == ISD::SUB) {
4224 SDValue N00 = N0.getOperand(i: 0);
4225 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N00, N1}))
4226 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N0.getOperand(i: 1));
4227 }
4228
4229 SDValue A, B, C;
4230
4231 // fold ((A+(B+C))-B) -> A+C
4232 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Add(L: m_Specific(N: N1), R: m_Value(N&: C)))))
4233 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: C);
4234
4235 // fold ((A+(B-C))-B) -> A-C
4236 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Sub(L: m_Specific(N: N1), R: m_Value(N&: C)))))
4237 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
4238
4239 // fold ((A-(B-C))-C) -> A-B
4240 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1)))))
4241 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: B);
4242
4243 // fold (A-(B-C)) -> A+(C-B)
4244 if (sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: B), R: m_Value(N&: C)))))
4245 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4246 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B));
4247
4248 // A - (A & B) -> A & (~B)
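// Why this holds: every bit of (A & B) is also set in A, so the subtraction
// never borrows and simply clears those bits, e.g.
// 0b1101 - (0b1101 & 0b0110) == 0b1101 - 0b0100 == 0b1001 == 0b1101 & ~0b0110.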
4249 if (sd_match(N: N1, P: m_And(L: m_Specific(N: N0), R: m_Value(N&: B))) &&
4250 (N1.hasOneUse() || isConstantOrConstantVector(N: B, /*NoOpaques=*/true)))
4251 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getNOT(DL, Val: B, VT));
4252
4253 // fold (A - (-B * C)) -> (A + (B * C))
4254 if (sd_match(N: N1, P: m_OneUse(P: m_Mul(L: m_Neg(V: m_Value(N&: B)), R: m_Value(N&: C)))))
4255 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4256 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: B, N2: C));
4257
4258 // If either operand of a sub is undef, the result is undef
4259 if (N0.isUndef())
4260 return N0;
4261 if (N1.isUndef())
4262 return N1;
4263
4264 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
4265 return V;
4266
4267 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
4268 return V;
4269
4270 // Try to match AVGCEIL fixedwidth pattern
4271 if (SDValue V = foldSubToAvg(N, DL))
4272 return V;
4273
4274 if (SDValue V = foldAddSubMasked1(IsAdd: false, N0, N1, DAG, DL))
4275 return V;
4276
4277 if (SDValue V = foldSubToUSubSat(DstVT: VT, N, DL))
4278 return V;
4279
4280 if (SDValue V = foldRemainderIdiom(N, DAG, DL))
4281 return V;
4282
4283 // (A - B) - 1 -> add (xor B, -1), A
4284 if (sd_match(N, P: m_Sub(L: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))), R: m_One())))
4285 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: DAG.getNOT(DL, Val: B, VT));
4286
4287 // Look for:
4288 // sub y, (xor x, -1)
4289 // And if the target does not like this form then turn into:
4290 // add (add x, y), 1
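  // (In two's complement ~x == -x - 1, so y - ~x == y + x + 1 == (x + y) + 1.)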
4291 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(V: N1)) {
4292 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4293 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Add, N2: DAG.getConstant(Val: 1, DL, VT));
4294 }
4295
4296 // Hoist one-use addition by non-opaque constant:
4297 // (x + C) - y -> (x - y) + C
4298 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::SUB, DL, N, N0, N1) &&
4299 N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
4300 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4301 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4302 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4303 }
4304 // y - (x + C) -> (y - x) - C
4305 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4306 isConstantOrConstantVector(N: N1.getOperand(i: 1), /*NoOpaques=*/true)) {
4307 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4308 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N1.getOperand(i: 1));
4309 }
4310 // (x - C) - y -> (x - y) - C
4311 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4312 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4313 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4314 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4315 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4316 }
4317 // (C - x) - y -> C - (x + y)
4318 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4319 isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
4320 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
4321 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
4322 }
4323
4324 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4325 // rather than 'sub 0/1' (the sext should get folded).
4326 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
4327 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4328 N1.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
4329 TLI.getBooleanContents(Type: VT) ==
4330 TargetLowering::ZeroOrNegativeOneBooleanContent) {
4331 SDValue SExt = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N1.getOperand(i: 0));
4332 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SExt);
4333 }
4334
4335 // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
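  // (B is 0 for non-negative A and -1 for negative A, so this computes either
  //  A - 0 or ~A + 1 == -A, i.e. the classic branchless abs idiom.)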
4336 if ((!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) &&
4337 sd_match(N: N1, P: m_Sra(L: m_Value(N&: A), R: m_SpecificInt(V: BitWidth - 1))) &&
4338 sd_match(N: N0, P: m_Xor(L: m_Specific(N: A), R: m_Specific(N: N1))))
4339 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: A);
4340
4341 // If the relocation model supports it, consider symbol offsets.
4342 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Val&: N0))
4343 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4344 // fold (sub Sym+c1, Sym+c2) -> c1-c2
4345 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(Val&: N1))
4346 if (GA->getGlobal() == GB->getGlobal())
4347 return DAG.getConstant(Val: (uint64_t)GA->getOffset() - GB->getOffset(),
4348 DL, VT);
4349 }
4350
4351 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4352 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4353 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
4354 if (TN->getVT() == MVT::i1) {
4355 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
4356 N2: DAG.getConstant(Val: 1, DL, VT));
4357 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: ZExt);
4358 }
4359 }
4360
4361 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4362 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4363 const APInt &IntVal = N1.getConstantOperandAPInt(i: 0);
4364 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getVScale(DL, VT, MulImm: -IntVal));
4365 }
4366
4367 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4368 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4369 APInt NewStep = -N1.getConstantOperandAPInt(i: 0);
4370 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4371 N2: DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep));
4372 }
4373
4374 // Prefer an add for more folding potential and possibly better codegen:
4375 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
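  // (A logical shift right by width-1 yields 0 or 1 while the arithmetic shift
  //  yields 0 or -1, and subtracting 0/1 is the same as adding 0/-1.)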
4376 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4377 SDValue ShAmt = N1.getOperand(i: 1);
4378 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
4379 if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
4380 SDValue SRA = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N1.getOperand(i: 0), N2: ShAmt);
4381 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SRA);
4382 }
4383 }
4384
4385 // As with the previous fold, prefer add for more folding potential.
4386 // Subtracting SMIN/0 is the same as adding SMIN/0:
4387 // N0 - (X << BW-1) --> N0 + (X << BW-1)
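  // ((X << BW-1) is either 0 or SMIN, and SMIN is its own negation in two's
  //  complement, so the sub and the add agree.)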
4388 if (N1.getOpcode() == ISD::SHL) {
4389 ConstantSDNode *ShlC = isConstOrConstSplat(N: N1.getOperand(i: 1));
4390 if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4391 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
4392 }
4393
4394 // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4395 if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(V: N0.getOperand(i: 1)) &&
4396 N0.getResNo() == 0 && N0.hasOneUse())
4397 return DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N0->getVTList(),
4398 N1: N0.getOperand(i: 0), N2: N1, N3: N0.getOperand(i: 2));
4399
4400 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT)) {
4401 // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry)
4402 if (SDValue Carry = getAsCarry(TLI, V: N0)) {
4403 SDValue X = N1;
4404 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4405 SDValue NegX = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: X);
4406 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
4407 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: NegX, N2: Zero,
4408 N3: Carry);
4409 }
4410 }
4411
4412 // If there's no chance of borrowing from adjacent bits, then sub is xor:
4413 // sub C0, X --> xor X, C0
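  // (If subtracting the maximal value X could take (every possibly-set bit)
  //  from C0 borrows nothing, then no actual value of X can borrow either, so
  //  the subtraction behaves exactly like an xor.)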
4414 if (ConstantSDNode *C0 = isConstOrConstSplat(N: N0)) {
4415 if (!C0->isOpaque()) {
4416 const APInt &C0Val = C0->getAPIntValue();
4417 const APInt &MaybeOnes = ~DAG.computeKnownBits(Op: N1).Zero;
4418 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4419 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4420 }
4421 }
4422
4423 // smax(a,b) - smin(a,b) --> abds(a,b)
4424 if ((!LegalOperations || hasOperation(Opcode: ISD::ABDS, VT)) &&
4425 sd_match(N: N0, P: m_SMaxLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4426 sd_match(N: N1, P: m_SMinLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4427 return DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: A, N2: B);
4428
4429 // smin(a,b) - smax(a,b) --> neg(abds(a,b))
4430 if (hasOperation(Opcode: ISD::ABDS, VT) &&
4431 sd_match(N: N0, P: m_SMinLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4432 sd_match(N: N1, P: m_SMaxLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4433 return DAG.getNegative(Val: DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: A, N2: B), DL, VT);
4434
4435 // umax(a,b) - umin(a,b) --> abdu(a,b)
4436 if ((!LegalOperations || hasOperation(Opcode: ISD::ABDU, VT)) &&
4437 sd_match(N: N0, P: m_UMaxLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4438 sd_match(N: N1, P: m_UMinLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4439 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: A, N2: B);
4440
4441 // umin(a,b) - umax(a,b) --> neg(abdu(a,b))
4442 if (hasOperation(Opcode: ISD::ABDU, VT) &&
4443 sd_match(N: N0, P: m_UMinLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4444 sd_match(N: N1, P: m_UMaxLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4445 return DAG.getNegative(Val: DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: A, N2: B), DL, VT);
4446
4447 // (sub x, (select (ult x, y), 0, y)) -> (umin x, (sub x, y))
4448 // (sub x, (select (uge x, y), y, 0)) -> (umin x, (sub x, y))
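  // (When x <u y the wrapped (x - y) exceeds x, so the umin picks x just like
  //  the select producing 0; otherwise x - y <= x and the umin picks x - y.)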
4449 if (hasUMin(VT)) {
4450 SDValue Y;
4451 if (sd_match(N: N1, P: m_OneUse(P: m_Select(Cond: m_SetCC(LHS: m_Specific(N: N0), RHS: m_Value(N&: Y),
4452 CC: m_SpecificCondCode(CC: ISD::SETULT)),
4453 T: m_Zero(), F: m_Deferred(V&: Y)))) ||
4454 sd_match(N: N1, P: m_OneUse(P: m_Select(Cond: m_SetCC(LHS: m_Specific(N: N0), RHS: m_Value(N&: Y),
4455 CC: m_SpecificCondCode(CC: ISD::SETUGE)),
4456 T: m_Deferred(V&: Y), F: m_Zero()))))
4457 return DAG.getNode(Opcode: ISD::UMIN, DL, VT, N1: N0,
4458 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Y));
4459 }
4460
4461 return SDValue();
4462}
4463
4464SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4465 unsigned Opcode = N->getOpcode();
4466 SDValue N0 = N->getOperand(Num: 0);
4467 SDValue N1 = N->getOperand(Num: 1);
4468 EVT VT = N0.getValueType();
4469 bool IsSigned = Opcode == ISD::SSUBSAT;
4470 SDLoc DL(N);
4471
4472 // fold (sub_sat x, undef) -> 0
4473 if (N0.isUndef() || N1.isUndef())
4474 return DAG.getConstant(Val: 0, DL, VT);
4475
4476 // fold (sub_sat x, x) -> 0
4477 if (N0 == N1)
4478 return DAG.getConstant(Val: 0, DL, VT);
4479
4480 // fold (sub_sat c1, c2) -> c3
4481 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4482 return C;
4483
4484 // fold vector ops
4485 if (VT.isVector()) {
4486 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4487 return FoldedVOp;
4488
4489 // fold (sub_sat x, 0) -> x, vector edition
4490 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4491 return N0;
4492 }
4493
4494 // fold (sub_sat x, 0) -> x
4495 if (isNullConstant(V: N1))
4496 return N0;
4497
4498   // If it cannot overflow, transform into a sub.
4499 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4500 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1);
4501
4502 return SDValue();
4503}
4504
4505SDValue DAGCombiner::visitSUBC(SDNode *N) {
4506 SDValue N0 = N->getOperand(Num: 0);
4507 SDValue N1 = N->getOperand(Num: 1);
4508 EVT VT = N0.getValueType();
4509 SDLoc DL(N);
4510
4511   // If the flag result is dead, turn this into a SUB.
4512 if (!N->hasAnyUseOfValue(Value: 1))
4513 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4514 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4515
4516 // fold (subc x, x) -> 0 + no borrow
4517 if (N0 == N1)
4518 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4519 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4520
4521 // fold (subc x, 0) -> x + no borrow
4522 if (isNullConstant(V: N1))
4523 return CombineTo(N, Res0: N0, Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4524
4525 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4526 if (isAllOnesConstant(V: N0))
4527 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4528 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4529
4530 return SDValue();
4531}
4532
4533SDValue DAGCombiner::visitSUBO(SDNode *N) {
4534 SDValue N0 = N->getOperand(Num: 0);
4535 SDValue N1 = N->getOperand(Num: 1);
4536 EVT VT = N0.getValueType();
4537 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4538
4539 EVT CarryVT = N->getValueType(ResNo: 1);
4540 SDLoc DL(N);
4541
4542   // If the flag result is dead, turn this into a SUB.
4543 if (!N->hasAnyUseOfValue(Value: 1))
4544 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4545 Res1: DAG.getUNDEF(VT: CarryVT));
4546
4547 // fold (subo x, x) -> 0 + no borrow
4548 if (N0 == N1)
4549 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4550 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4551
4552   // fold (subo x, c) -> (addo x, -c)
4553 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
4554 if (IsSigned && !N1C->isMinSignedValue())
4555 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0,
4556 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4557
4558 // fold (subo x, 0) -> x + no borrow
4559 if (isNullOrNullSplat(V: N1))
4560 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4561
4562   // If it cannot overflow, transform into a sub.
4563 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4564 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4565 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4566
4567 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4568 if (!IsSigned && isAllOnesOrAllOnesSplat(V: N0))
4569 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4570 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4571
4572 return SDValue();
4573}
4574
4575SDValue DAGCombiner::visitSUBE(SDNode *N) {
4576 SDValue N0 = N->getOperand(Num: 0);
4577 SDValue N1 = N->getOperand(Num: 1);
4578 SDValue CarryIn = N->getOperand(Num: 2);
4579
4580 // fold (sube x, y, false) -> (subc x, y)
4581 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4582 return DAG.getNode(Opcode: ISD::SUBC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4583
4584 return SDValue();
4585}
4586
4587SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4588 SDValue N0 = N->getOperand(Num: 0);
4589 SDValue N1 = N->getOperand(Num: 1);
4590 SDValue CarryIn = N->getOperand(Num: 2);
4591
4592 // fold (usubo_carry x, y, false) -> (usubo x, y)
4593 if (isNullConstant(V: CarryIn)) {
4594 if (!LegalOperations ||
4595 TLI.isOperationLegalOrCustom(Op: ISD::USUBO, VT: N->getValueType(ResNo: 0)))
4596 return DAG.getNode(Opcode: ISD::USUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4597 }
4598
4599 return SDValue();
4600}
4601
4602SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4603 SDValue N0 = N->getOperand(Num: 0);
4604 SDValue N1 = N->getOperand(Num: 1);
4605 SDValue CarryIn = N->getOperand(Num: 2);
4606
4607 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4608 if (isNullConstant(V: CarryIn)) {
4609 if (!LegalOperations ||
4610 TLI.isOperationLegalOrCustom(Op: ISD::SSUBO, VT: N->getValueType(ResNo: 0)))
4611 return DAG.getNode(Opcode: ISD::SSUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4612 }
4613
4614 return SDValue();
4615}
4616
4617// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4618// UMULFIXSAT here.
4619SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4620 SDValue N0 = N->getOperand(Num: 0);
4621 SDValue N1 = N->getOperand(Num: 1);
4622 SDValue Scale = N->getOperand(Num: 2);
4623 EVT VT = N0.getValueType();
4624
4625 // fold (mulfix x, undef, scale) -> 0
4626 if (N0.isUndef() || N1.isUndef())
4627 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4628
4629 // Canonicalize constant to RHS (vector doesn't have to splat)
4630 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4631 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4632 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0, N3: Scale);
4633
4634 // fold (mulfix x, 0, scale) -> 0
4635 if (isNullConstant(V: N1))
4636 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4637
4638 return SDValue();
4639}
4640
4641template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4642 SDValue N0 = N->getOperand(Num: 0);
4643 SDValue N1 = N->getOperand(Num: 1);
4644 EVT VT = N0.getValueType();
4645 unsigned BitWidth = VT.getScalarSizeInBits();
4646 SDLoc DL(N);
4647 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4648 MatchContextClass Matcher(DAG, TLI, N);
4649
4650 // fold (mul x, undef) -> 0
4651 if (N0.isUndef() || N1.isUndef())
4652 return DAG.getConstant(Val: 0, DL, VT);
4653
4654 // fold (mul c1, c2) -> c1*c2
4655 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MUL, DL, VT, Ops: {N0, N1}))
4656 return C;
4657
4658 // canonicalize constant to RHS (vector doesn't have to splat)
4659 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4660 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4661 return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);
4662
4663 bool N1IsConst = false;
4664 bool N1IsOpaqueConst = false;
4665 APInt ConstValue1;
4666
4667 // fold vector ops
4668 if (VT.isVector()) {
4669 // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4670 if (!UseVP)
4671 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4672 return FoldedVOp;
4673
4674 N1IsConst = ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ConstValue1);
4675 assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
4676 "Splat APInt should be element width");
4677 } else {
4678 N1IsConst = isa<ConstantSDNode>(Val: N1);
4679 if (N1IsConst) {
4680 ConstValue1 = N1->getAsAPIntVal();
4681 N1IsOpaqueConst = cast<ConstantSDNode>(Val&: N1)->isOpaque();
4682 }
4683 }
4684
4685 // fold (mul x, 0) -> 0
4686 if (N1IsConst && ConstValue1.isZero())
4687 return N1;
4688
4689 // fold (mul x, 1) -> x
4690 if (N1IsConst && ConstValue1.isOne())
4691 return N0;
4692
4693 if (!UseVP)
4694 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4695 return NewSel;
4696
4697 // fold (mul x, -1) -> 0-x
4698 if (N1IsConst && ConstValue1.isAllOnes())
4699 return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(Val: 0, DL, VT), N0);
4700
4701 // fold (mul x, (1 << c)) -> x << c
4702 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
4703 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4704 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4705 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4706 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4707 return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
4708 }
4709 }
4710
4711 // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4712 if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4713 unsigned Log2Val = (-ConstValue1).logBase2();
4714
4715 // FIXME: If the input is something that is easily negated (e.g. a
4716 // single-use add), we should put the negate there.
4717 return Matcher.getNode(
4718 ISD::SUB, DL, VT, DAG.getConstant(Val: 0, DL, VT),
4719 Matcher.getNode(ISD::SHL, DL, VT, N0,
4720 DAG.getShiftAmountConstant(Val: Log2Val, VT, DL)));
4721 }
4722
4723 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4724   // hi result is in use, in case we hit this mid-legalization.
4725 if (!UseVP) {
4726 for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4727 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: LoHiOpc, VT)) {
4728 SDVTList LoHiVT = DAG.getVTList(VT1: VT, VT2: VT);
4729 // TODO: Can we match commutable operands with getNodeIfExists?
4730 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N0, N1}))
4731 if (LoHi->hasAnyUseOfValue(Value: 1))
4732 return SDValue(LoHi, 0);
4733 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N1, N0}))
4734 if (LoHi->hasAnyUseOfValue(Value: 1))
4735 return SDValue(LoHi, 0);
4736 }
4737 }
4738 }
4739
4740 // Try to transform:
4741 // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4742 // mul x, (2^N + 1) --> add (shl x, N), x
4743 // mul x, (2^N - 1) --> sub (shl x, N), x
4744 // Examples: x * 33 --> (x << 5) + x
4745 // x * 15 --> (x << 4) - x
4746 // x * -33 --> -((x << 5) + x)
4747 // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4748 // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4749 // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4750 // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4751 // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4752 // x * 0xf800 --> (x << 16) - (x << 11)
4753 // x * -0x8800 --> -((x << 15) + (x << 11))
4754 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
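  // For example, x * 0x8800 = x * (0x11 << 11): 0x11 is 2^4 + 1, so this
  // becomes (x << 15) + (x << 11), matching the examples above.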
4755 if (!UseVP && N1IsConst &&
4756 TLI.decomposeMulByConstant(Context&: *DAG.getContext(), VT, C: N1)) {
4757 // TODO: We could handle more general decomposition of any constant by
4758 // having the target set a limit on number of ops and making a
4759 // callback to determine that sequence (similar to sqrt expansion).
4760 unsigned MathOp = ISD::DELETED_NODE;
4761 APInt MulC = ConstValue1.abs();
4762 // The constant `2` should be treated as (2^0 + 1).
4763 unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4764 MulC.lshrInPlace(ShiftAmt: TZeros);
4765 if ((MulC - 1).isPowerOf2())
4766 MathOp = ISD::ADD;
4767 else if ((MulC + 1).isPowerOf2())
4768 MathOp = ISD::SUB;
4769
4770 if (MathOp != ISD::DELETED_NODE) {
4771 unsigned ShAmt =
4772 MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4773 ShAmt += TZeros;
4774 assert(ShAmt < BitWidth &&
4775 "multiply-by-constant generated out of bounds shift");
4776 SDValue Shl =
4777 DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: DAG.getConstant(Val: ShAmt, DL, VT));
4778 SDValue R =
4779 TZeros ? DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl,
4780 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4781 N2: DAG.getConstant(Val: TZeros, DL, VT)))
4782 : DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl, N2: N0);
4783 if (ConstValue1.isNegative())
4784 R = DAG.getNegative(Val: R, DL, VT);
4785 return R;
4786 }
4787 }
4788
4789 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4790 if (sd_context_match(N0, Matcher, m_Opc(Opcode: ISD::SHL))) {
4791 SDValue N01 = N0.getOperand(i: 1);
4792 if (SDValue C3 = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N1, N01}))
4793 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: C3);
4794 }
4795
4796 // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4797 // use.
4798 {
4799 SDValue Sh, Y;
4800
4801 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4802 if (sd_context_match(N0, Matcher, m_OneUse(P: m_Opc(Opcode: ISD::SHL))) &&
4803 isConstantOrConstantVector(N: N0.getOperand(i: 1))) {
4804 Sh = N0; Y = N1;
4805 } else if (sd_context_match(N1, Matcher, m_OneUse(P: m_Opc(Opcode: ISD::SHL))) &&
4806 isConstantOrConstantVector(N: N1.getOperand(i: 1))) {
4807 Sh = N1; Y = N0;
4808 }
4809
4810 if (Sh.getNode()) {
4811 SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(i: 0), Y);
4812 return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(i: 1));
4813 }
4814 }
4815
4816 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4817 if (sd_context_match(N0, Matcher, m_Opc(Opcode: ISD::ADD)) &&
4818 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
4819 DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1)) &&
4820 isMulAddWithConstProfitable(MulNode: N, AddNode: N0, ConstNode: N1))
4821 return Matcher.getNode(
4822 ISD::ADD, DL, VT,
4823 Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(i: 0), N1),
4824 Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(i: 1), N1));
4825
4826 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4827 ConstantSDNode *NC1 = isConstOrConstSplat(N: N1);
4828 if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4829 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4830 const APInt &C1 = NC1->getAPIntValue();
4831 return DAG.getVScale(DL, VT, MulImm: C0 * C1);
4832 }
4833
4834 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4835 APInt MulVal;
4836 if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4837 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: MulVal)) {
4838 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4839 APInt NewStep = C0 * MulVal;
4840 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
4841 }
4842
4843 // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
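  // ((or Y, 1) is +1 when X is non-negative and -1 when X is negative, so the
  //  multiply computes X * sign(X) == abs(X).)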
4844 SDValue X;
4845 if (!UseVP && (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) &&
4846 sd_context_match(
4847 N, Matcher,
4848 m_Mul(L: m_Or(L: m_Sra(L: m_Value(N&: X), R: m_SpecificInt(V: BitWidth - 1)), R: m_One()),
4849 R: m_Deferred(V&: X)))) {
4850 return Matcher.getNode(ISD::ABS, DL, VT, X);
4851 }
4852
4853   // Fold (mul x, 0/undef) -> 0 and
4854   //      (mul x, 1) -> x
4855   // into and(x, mask).
4856 // We can replace vectors with '0' and '1' factors with a clearing mask.
4857 if (VT.isFixedLengthVector()) {
4858 unsigned NumElts = VT.getVectorNumElements();
4859 SmallBitVector ClearMask;
4860 ClearMask.reserve(N: NumElts);
4861 auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4862 if (!V || V->isZero()) {
4863 ClearMask.push_back(Val: true);
4864 return true;
4865 }
4866 ClearMask.push_back(Val: false);
4867 return V->isOne();
4868 };
4869 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::AND, VT)) &&
4870 ISD::matchUnaryPredicate(Op: N1, Match: IsClearMask, /*AllowUndefs*/ true)) {
4871 assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4872 EVT LegalSVT = N1.getOperand(i: 0).getValueType();
4873 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: LegalSVT);
4874 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: LegalSVT);
4875 SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4876 for (unsigned I = 0; I != NumElts; ++I)
4877 if (ClearMask[I])
4878 Mask[I] = Zero;
4879 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getBuildVector(VT, DL, Ops: Mask));
4880 }
4881 }
4882
4883 // reassociate mul
4884 // TODO: Change reassociateOps to support vp ops.
4885 if (!UseVP)
4886 if (SDValue RMUL = reassociateOps(Opc: ISD::MUL, DL, N0, N1, Flags: N->getFlags()))
4887 return RMUL;
4888
4889 // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4890 // TODO: Change reassociateReduction to support vp ops.
4891 if (!UseVP)
4892 if (SDValue SD =
4893 reassociateReduction(RedOpc: ISD::VECREDUCE_MUL, Opc: ISD::MUL, DL, VT, N0, N1))
4894 return SD;
4895
4896 // Simplify the operands using demanded-bits information.
4897 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
4898 return SDValue(N, 0);
4899
4900 return SDValue();
4901}
4902
4903/// Return true if divmod libcall is available.
4904static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4905 const TargetLowering &TLI) {
4906 RTLIB::Libcall LC;
4907 EVT NodeType = Node->getValueType(ResNo: 0);
4908 if (!NodeType.isSimple())
4909 return false;
4910 switch (NodeType.getSimpleVT().SimpleTy) {
4911 default: return false; // No libcall for vector types.
4912 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4913 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4914 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4915 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4916 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4917 }
4918
4919 return TLI.getLibcallName(Call: LC) != nullptr;
4920}
4921
4922/// Issue divrem if both quotient and remainder are needed.
4923SDValue DAGCombiner::useDivRem(SDNode *Node) {
4924 if (Node->use_empty())
4925 return SDValue(); // This is a dead node, leave it alone.
4926
4927 unsigned Opcode = Node->getOpcode();
4928 bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4929 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4930
4931   // DivMod lib calls can still work on non-legal types when lowered to lib-calls.
4932 EVT VT = Node->getValueType(ResNo: 0);
4933 if (VT.isVector() || !VT.isInteger())
4934 return SDValue();
4935
4936 if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(Op: DivRemOpc, VT))
4937 return SDValue();
4938
4939 // If DIVREM is going to get expanded into a libcall,
4940 // but there is no libcall available, then don't combine.
4941 if (!TLI.isOperationLegalOrCustom(Op: DivRemOpc, VT) &&
4942 !isDivRemLibcallAvailable(Node, isSigned, TLI))
4943 return SDValue();
4944
4945   // If div is legal, it's better to do the normal expansion.
4946 unsigned OtherOpcode = 0;
4947 if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4948 OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4949 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT))
4950 return SDValue();
4951 } else {
4952 OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4953 if (TLI.isOperationLegalOrCustom(Op: OtherOpcode, VT))
4954 return SDValue();
4955 }
4956
4957 SDValue Op0 = Node->getOperand(Num: 0);
4958 SDValue Op1 = Node->getOperand(Num: 1);
4959 SDValue combined;
4960 for (SDNode *User : Op0->users()) {
4961 if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4962 User->use_empty())
4963 continue;
4964 // Convert the other matching node(s), too;
4965 // otherwise, the DIVREM may get target-legalized into something
4966 // target-specific that we won't be able to recognize.
4967 unsigned UserOpc = User->getOpcode();
4968 if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4969 User->getOperand(Num: 0) == Op0 &&
4970 User->getOperand(Num: 1) == Op1) {
4971 if (!combined) {
4972 if (UserOpc == OtherOpcode) {
4973 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT);
4974 combined = DAG.getNode(Opcode: DivRemOpc, DL: SDLoc(Node), VTList: VTs, N1: Op0, N2: Op1);
4975 } else if (UserOpc == DivRemOpc) {
4976 combined = SDValue(User, 0);
4977 } else {
4978 assert(UserOpc == Opcode);
4979 continue;
4980 }
4981 }
4982 if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4983 CombineTo(N: User, Res: combined);
4984 else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4985 CombineTo(N: User, Res: combined.getValue(R: 1));
4986 }
4987 }
4988 return combined;
4989}
4990
4991static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4992 SDValue N0 = N->getOperand(Num: 0);
4993 SDValue N1 = N->getOperand(Num: 1);
4994 EVT VT = N->getValueType(ResNo: 0);
4995 SDLoc DL(N);
4996
4997 unsigned Opc = N->getOpcode();
4998 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4999 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5000
5001 // X / undef -> undef
5002 // X % undef -> undef
5003 // X / 0 -> undef
5004 // X % 0 -> undef
5005 // NOTE: This includes vectors where any divisor element is zero/undef.
5006 if (DAG.isUndef(Opcode: Opc, Ops: {N0, N1}))
5007 return DAG.getUNDEF(VT);
5008
5009 // undef / X -> 0
5010 // undef % X -> 0
5011 if (N0.isUndef())
5012 return DAG.getConstant(Val: 0, DL, VT);
5013
5014 // 0 / X -> 0
5015 // 0 % X -> 0
5016 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
5017 if (N0C && N0C->isZero())
5018 return N0;
5019
5020 // X / X -> 1
5021 // X % X -> 0
5022 if (N0 == N1)
5023 return DAG.getConstant(Val: IsDiv ? 1 : 0, DL, VT);
5024
5025 // X / 1 -> X
5026 // X % 1 -> 0
5027 // If this is a boolean op (single-bit element type), we can't have
5028 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
5029 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
5030 // it's a 1.
5031 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
5032 return IsDiv ? N0 : DAG.getConstant(Val: 0, DL, VT);
5033
5034 return SDValue();
5035}
5036
5037SDValue DAGCombiner::visitSDIV(SDNode *N) {
5038 SDValue N0 = N->getOperand(Num: 0);
5039 SDValue N1 = N->getOperand(Num: 1);
5040 EVT VT = N->getValueType(ResNo: 0);
5041 EVT CCVT = getSetCCResultType(VT);
5042 SDLoc DL(N);
5043
5044 // fold (sdiv c1, c2) -> c1/c2
5045 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SDIV, DL, VT, Ops: {N0, N1}))
5046 return C;
5047
5048 // fold vector ops
5049 if (VT.isVector())
5050 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5051 return FoldedVOp;
5052
5053 // fold (sdiv X, -1) -> 0-X
5054 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5055 if (N1C && N1C->isAllOnes())
5056 return DAG.getNegative(Val: N0, DL, VT);
5057
5058 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
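  // (Only X == MIN_SIGNED has magnitude at least |MIN_SIGNED|, so the quotient
  //  is 1 in that case and 0 for every other X.)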
5059 if (N1C && N1C->isMinSignedValue())
5060 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
5061 LHS: DAG.getConstant(Val: 1, DL, VT),
5062 RHS: DAG.getConstant(Val: 0, DL, VT));
5063
5064 if (SDValue V = simplifyDivRem(N, DAG))
5065 return V;
5066
5067 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5068 return NewSel;
5069
5070 // If we know the sign bits of both operands are zero, strength reduce to a
5071 // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
5072 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
5073 return DAG.getNode(Opcode: ISD::UDIV, DL, VT: N1.getValueType(), N1: N0, N2: N1);
5074
5075 if (SDValue V = visitSDIVLike(N0, N1, N)) {
5076 // If the corresponding remainder node exists, update its users with
5077     // (Dividend - (Quotient * Divisor)).
5078 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::SREM, VTList: N->getVTList(),
5079 Ops: { N0, N1 })) {
5080 // If the sdiv has the exact flag we shouldn't propagate it to the
5081 // remainder node.
5082 if (!N->getFlags().hasExact()) {
5083 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
5084 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5085 AddToWorklist(N: Mul.getNode());
5086 AddToWorklist(N: Sub.getNode());
5087 CombineTo(N: RemNode, Res: Sub);
5088 }
5089 }
5090 return V;
5091 }
5092
5093 // sdiv, srem -> sdivrem
5094 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5095 // true. Otherwise, we break the simplification logic in visitREM().
5096 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5097 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5098 if (SDValue DivRem = useDivRem(Node: N))
5099 return DivRem;
5100
5101 return SDValue();
5102}
5103
5104static bool isDivisorPowerOfTwo(SDValue Divisor) {
5105   // Helper for determining whether a value is a power-of-2 constant scalar or a
5106 // vector of such elements.
5107 auto IsPowerOfTwo = [](ConstantSDNode *C) {
5108 if (C->isZero() || C->isOpaque())
5109 return false;
5110 if (C->getAPIntValue().isPowerOf2())
5111 return true;
5112 if (C->getAPIntValue().isNegatedPowerOf2())
5113 return true;
5114 return false;
5115 };
5116
5117 return ISD::matchUnaryPredicate(Op: Divisor, Match: IsPowerOfTwo);
5118}
5119
5120SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5121 SDLoc DL(N);
5122 EVT VT = N->getValueType(ResNo: 0);
5123 EVT CCVT = getSetCCResultType(VT);
5124 unsigned BitWidth = VT.getScalarSizeInBits();
5125
5126 // fold (sdiv X, pow2) -> simple ops after legalize
5127 // FIXME: We check for the exact bit here because the generic lowering gives
5128 // better results in that case. The target-specific lowering should learn how
5129 // to handle exact sdivs efficiently.
5130 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1)) {
5131 // Target-specific implementation of sdiv x, pow2.
5132 if (SDValue Res = BuildSDIVPow2(N))
5133 return Res;
5134
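    // In outline, for a divisor whose magnitude is 2^k, the expansion below
    // computes:
    //   Sign = N0 >>s (BW - 1)           ; 0 or -1
    //   Add  = N0 + (Sign >>u (BW - k))  ; N0, or N0 + 2^k - 1 when N0 < 0
    //   Res  = Add >>s k
    // i.e. negative dividends are biased so the arithmetic shift rounds toward
    // zero; the divide-by-1/-1 and negative-divisor cases are fixed up below.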
5135 // Create constants that are functions of the shift amount value.
5136 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
5137 SDValue Bits = DAG.getConstant(Val: BitWidth, DL, VT: ShiftAmtTy);
5138 SDValue C1 = DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N1);
5139 C1 = DAG.getZExtOrTrunc(Op: C1, DL, VT: ShiftAmtTy);
5140 SDValue Inexact = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftAmtTy, N1: Bits, N2: C1);
5141 if (!isConstantOrConstantVector(N: Inexact))
5142 return SDValue();
5143
5144 // Splat the sign bit into the register
5145 SDValue Sign = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0,
5146 N2: DAG.getConstant(Val: BitWidth - 1, DL, VT: ShiftAmtTy));
5147 AddToWorklist(N: Sign.getNode());
5148
5149 // Add (N0 < 0) ? abs2 - 1 : 0;
5150 SDValue Srl = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Sign, N2: Inexact);
5151 AddToWorklist(N: Srl.getNode());
5152 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Srl);
5153 AddToWorklist(N: Add.getNode());
5154 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Add, N2: C1);
5155 AddToWorklist(N: Sra.getNode());
5156
5157 // Special case: (sdiv X, 1) -> X
5158     // Special case: (sdiv X, -1) -> 0-X
5159 SDValue One = DAG.getConstant(Val: 1, DL, VT);
5160 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
5161 SDValue IsOne = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: One, Cond: ISD::SETEQ);
5162 SDValue IsAllOnes = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: AllOnes, Cond: ISD::SETEQ);
5163 SDValue IsOneOrAllOnes = DAG.getNode(Opcode: ISD::OR, DL, VT: CCVT, N1: IsOne, N2: IsAllOnes);
5164 Sra = DAG.getSelect(DL, VT, Cond: IsOneOrAllOnes, LHS: N0, RHS: Sra);
5165
5166 // If dividing by a positive value, we're done. Otherwise, the result must
5167 // be negated.
5168 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5169 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: Sra);
5170
5171 // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
5172 SDValue IsNeg = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: Zero, Cond: ISD::SETLT);
5173 SDValue Res = DAG.getSelect(DL, VT, Cond: IsNeg, LHS: Sub, RHS: Sra);
5174 return Res;
5175 }
5176
5177 // If integer divide is expensive and we satisfy the requirements, emit an
5178 // alternate sequence. Targets may check function attributes for size/speed
5179 // trade-offs.
5180 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5181 if (isConstantOrConstantVector(N: N1) &&
5182 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5183 if (SDValue Op = BuildSDIV(N))
5184 return Op;
5185
5186 return SDValue();
5187}
5188
5189SDValue DAGCombiner::visitUDIV(SDNode *N) {
5190 SDValue N0 = N->getOperand(Num: 0);
5191 SDValue N1 = N->getOperand(Num: 1);
5192 EVT VT = N->getValueType(ResNo: 0);
5193 EVT CCVT = getSetCCResultType(VT);
5194 SDLoc DL(N);
5195
5196 // fold (udiv c1, c2) -> c1/c2
5197 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UDIV, DL, VT, Ops: {N0, N1}))
5198 return C;
5199
5200 // fold vector ops
5201 if (VT.isVector())
5202 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5203 return FoldedVOp;
5204
5205 // fold (udiv X, -1) -> select(X == -1, 1, 0)
5206 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5207 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
5208 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
5209 LHS: DAG.getConstant(Val: 1, DL, VT),
5210 RHS: DAG.getConstant(Val: 0, DL, VT));
5211 }
5212
5213 if (SDValue V = simplifyDivRem(N, DAG))
5214 return V;
5215
5216 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5217 return NewSel;
5218
5219 if (SDValue V = visitUDIVLike(N0, N1, N)) {
5220 // If the corresponding remainder node exists, update its users with
5221     // (Dividend - (Quotient * Divisor)).
5222 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::UREM, VTList: N->getVTList(),
5223 Ops: { N0, N1 })) {
5224 // If the udiv has the exact flag we shouldn't propagate it to the
5225 // remainder node.
5226 if (!N->getFlags().hasExact()) {
5227 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
5228 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5229 AddToWorklist(N: Mul.getNode());
5230 AddToWorklist(N: Sub.getNode());
5231 CombineTo(N: RemNode, Res: Sub);
5232 }
5233 }
5234 return V;
5235 }
5236
5237   // udiv, urem -> udivrem
5238 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5239 // true. Otherwise, we break the simplification logic in visitREM().
5240 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5241 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5242 if (SDValue DivRem = useDivRem(Node: N))
5243 return DivRem;
5244
5245 // Simplify the operands using demanded-bits information.
5246 // We don't have demanded bits support for UDIV so this just enables constant
5247 // folding based on known bits.
5248 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5249 return SDValue(N, 0);
5250
5251 return SDValue();
5252}
5253
5254SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5255 SDLoc DL(N);
5256 EVT VT = N->getValueType(ResNo: 0);
5257
5258 // fold (udiv x, (1 << c)) -> x >>u c
5259 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true)) {
5260 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5261 AddToWorklist(N: LogBase2.getNode());
5262
5263 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5264 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
5265 AddToWorklist(N: Trunc.getNode());
5266 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5267 }
5268 }
5269
5270 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
5271 if (N1.getOpcode() == ISD::SHL) {
5272 SDValue N10 = N1.getOperand(i: 0);
5273 if (isConstantOrConstantVector(N: N10, /*NoOpaques*/ true)) {
5274 if (SDValue LogBase2 = BuildLogBase2(V: N10, DL)) {
5275 AddToWorklist(N: LogBase2.getNode());
5276
5277 EVT ADDVT = N1.getOperand(i: 1).getValueType();
5278 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ADDVT);
5279 AddToWorklist(N: Trunc.getNode());
5280 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: ADDVT, N1: N1.getOperand(i: 1), N2: Trunc);
5281 AddToWorklist(N: Add.getNode());
5282 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Add);
5283 }
5284 }
5285 }
5286
5287 // fold (udiv x, c) -> alternate
5288 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5289 if (isConstantOrConstantVector(N: N1) &&
5290 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5291 if (SDValue Op = BuildUDIV(N))
5292 return Op;
5293
5294 return SDValue();
5295}
5296
5297SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
5298 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1) &&
5299 !DAG.doesNodeExist(Opcode: ISD::SDIV, VTList: N->getVTList(), Ops: {N0, N1})) {
5300 // Target-specific implementation of srem x, pow2.
5301 if (SDValue Res = BuildSREMPow2(N))
5302 return Res;
5303 }
5304 return SDValue();
5305}
5306
5307// Handles ISD::SREM and ISD::UREM.
5308SDValue DAGCombiner::visitREM(SDNode *N) {
5309 unsigned Opcode = N->getOpcode();
5310 SDValue N0 = N->getOperand(Num: 0);
5311 SDValue N1 = N->getOperand(Num: 1);
5312 EVT VT = N->getValueType(ResNo: 0);
5313 EVT CCVT = getSetCCResultType(VT);
5314
5315 bool isSigned = (Opcode == ISD::SREM);
5316 SDLoc DL(N);
5317
5318 // fold (rem c1, c2) -> c1%c2
5319 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5320 return C;
5321
5322 // fold (urem X, -1) -> select(FX == -1, 0, FX)
5323 // Freeze the numerator to avoid a miscompile with an undefined value.
5324 if (!isSigned && llvm::isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs*/ false) &&
5325 CCVT.isVector() == VT.isVector()) {
5326 SDValue F0 = DAG.getFreeze(V: N0);
5327 SDValue EqualsNeg1 = DAG.getSetCC(DL, VT: CCVT, LHS: F0, RHS: N1, Cond: ISD::SETEQ);
5328 return DAG.getSelect(DL, VT, Cond: EqualsNeg1, LHS: DAG.getConstant(Val: 0, DL, VT), RHS: F0);
5329 }
5330
5331 if (SDValue V = simplifyDivRem(N, DAG))
5332 return V;
5333
5334 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5335 return NewSel;
5336
5337 if (isSigned) {
5338 // If we know the sign bits of both operands are zero, strength reduce to a
5339 // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
5340 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
5341 return DAG.getNode(Opcode: ISD::UREM, DL, VT, N1: N0, N2: N1);
5342 } else {
5343 if (DAG.isKnownToBeAPowerOfTwo(Val: N1)) {
5344 // fold (urem x, pow2) -> (and x, pow2-1)
5345 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5346 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5347 AddToWorklist(N: Add.getNode());
5348 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5349 }
5350 // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
5351 // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
5352 // TODO: We should sink the following into isKnownToBePowerOfTwo
5353     // using an OrZero parameter analogous to our handling in ValueTracking.
5354 if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
5355 DAG.isKnownToBeAPowerOfTwo(Val: N1.getOperand(i: 0))) {
5356 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5357 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5358 AddToWorklist(N: Add.getNode());
5359 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5360 }
5361 }
5362
5363 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5364
5365 // If X/C can be simplified by the division-by-constant logic, lower
5366 // X%C to the equivalent of X-X/C*C.
5367 // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
5368 // speculative DIV must not cause a DIVREM conversion. We guard against this
5369 // by skipping the simplification if isIntDivCheap(). When div is not cheap,
5370 // combine will not return a DIVREM. Regardless, checking cheapness here
5371 // makes sense since the simplification results in fatter code.
5372 if (DAG.isKnownNeverZero(Op: N1) && !TLI.isIntDivCheap(VT, Attr)) {
5373 if (isSigned) {
5374       // Check if we can build a faster implementation for srem.
5375 if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
5376 return OptimizedRem;
5377 }
5378
5379 SDValue OptimizedDiv =
5380 isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
5381 if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
5382 // If the equivalent Div node also exists, update its users.
5383 unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5384 if (SDNode *DivNode = DAG.getNodeIfExists(Opcode: DivOpcode, VTList: N->getVTList(),
5385 Ops: { N0, N1 }))
5386 CombineTo(N: DivNode, Res: OptimizedDiv);
5387 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: OptimizedDiv, N2: N1);
5388 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5389 AddToWorklist(N: OptimizedDiv.getNode());
5390 AddToWorklist(N: Mul.getNode());
5391 return Sub;
5392 }
5393 }
5394
5395   // srem -> sdivrem, urem -> udivrem
5396 if (SDValue DivRem = useDivRem(Node: N))
5397 return DivRem.getValue(R: 1);
5398
5399 return SDValue();
5400}
5401
5402SDValue DAGCombiner::visitMULHS(SDNode *N) {
5403 SDValue N0 = N->getOperand(Num: 0);
5404 SDValue N1 = N->getOperand(Num: 1);
5405 EVT VT = N->getValueType(ResNo: 0);
5406 SDLoc DL(N);
5407
5408 // fold (mulhs c1, c2)
5409 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHS, DL, VT, Ops: {N0, N1}))
5410 return C;
5411
5412 // canonicalize constant to RHS.
5413 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5414 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5415 return DAG.getNode(Opcode: ISD::MULHS, DL, VTList: N->getVTList(), N1, N2: N0);
5416
5417 if (VT.isVector()) {
5418 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5419 return FoldedVOp;
5420
5421 // fold (mulhs x, 0) -> 0
5422     // do not return N1, because it may contain undef elements.
5423 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5424 return DAG.getConstant(Val: 0, DL, VT);
5425 }
5426
5427 // fold (mulhs x, 0) -> 0
5428 if (isNullConstant(V: N1))
5429 return N1;
5430
5431 // fold (mulhs x, 1) -> (sra x, size(x)-1)
5432 if (isOneConstant(V: N1))
5433 return DAG.getNode(
5434 Opcode: ISD::SRA, DL, VT, N1: N0,
5435 N2: DAG.getShiftAmountConstant(Val: N0.getScalarValueSizeInBits() - 1, VT, DL));
5436
5437 // fold (mulhs x, undef) -> 0
5438 if (N0.isUndef() || N1.isUndef())
5439 return DAG.getConstant(Val: 0, DL, VT);
5440
5441 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5442 // plus a shift.
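  //   mulhs x, y (iN) --> trunc ((sext_i2N(x) * sext_i2N(y)) >> N)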
5443 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHS, VT) && VT.isSimple() &&
5444 !VT.isVector()) {
5445 MVT Simple = VT.getSimpleVT();
5446 unsigned SimpleSize = Simple.getSizeInBits();
5447 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5448 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5449 N0 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5450 N1 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5451 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5452 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5453 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5454 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5455 }
5456 }
5457
5458 return SDValue();
5459}
5460
5461SDValue DAGCombiner::visitMULHU(SDNode *N) {
5462 SDValue N0 = N->getOperand(Num: 0);
5463 SDValue N1 = N->getOperand(Num: 1);
5464 EVT VT = N->getValueType(ResNo: 0);
5465 SDLoc DL(N);
5466
5467 // fold (mulhu c1, c2)
5468 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHU, DL, VT, Ops: {N0, N1}))
5469 return C;
5470
5471 // canonicalize constant to RHS.
5472 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5473 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5474 return DAG.getNode(Opcode: ISD::MULHU, DL, VTList: N->getVTList(), N1, N2: N0);
5475
5476 if (VT.isVector()) {
5477 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5478 return FoldedVOp;
5479
5480 // fold (mulhu x, 0) -> 0
5481     // do not return N1, because it may contain undef elements.
5482 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5483 return DAG.getConstant(Val: 0, DL, VT);
5484 }
5485
5486 // fold (mulhu x, 0) -> 0
5487 if (isNullConstant(V: N1))
5488 return N1;
5489
5490 // fold (mulhu x, 1) -> 0
5491 if (isOneConstant(V: N1))
5492 return DAG.getConstant(Val: 0, DL, VT);
5493
5494 // fold (mulhu x, undef) -> 0
5495 if (N0.isUndef() || N1.isUndef())
5496 return DAG.getConstant(Val: 0, DL, VT);
5497
5498 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
5499 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
5500 hasOperation(Opcode: ISD::SRL, VT)) {
5501 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5502 unsigned NumEltBits = VT.getScalarSizeInBits();
5503 SDValue SRLAmt = DAG.getNode(
5504 Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: NumEltBits, DL, VT), N2: LogBase2);
5505 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5506 SDValue Trunc = DAG.getZExtOrTrunc(Op: SRLAmt, DL, VT: ShiftVT);
5507 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5508 }
5509 }
5510
5511 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5512 // plus a shift.
5513 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHU, VT) && VT.isSimple() &&
5514 !VT.isVector()) {
5515 MVT Simple = VT.getSimpleVT();
5516 unsigned SimpleSize = Simple.getSizeInBits();
5517 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5518 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5519 N0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5520 N1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5521 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5522 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5523 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5524 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5525 }
5526 }
5527
5528 // Simplify the operands using demanded-bits information.
5529 // We don't have demanded bits support for MULHU so this just enables constant
5530 // folding based on known bits.
5531 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5532 return SDValue(N, 0);
5533
5534 return SDValue();
5535}
5536
5537SDValue DAGCombiner::visitAVG(SDNode *N) {
5538 unsigned Opcode = N->getOpcode();
5539 SDValue N0 = N->getOperand(Num: 0);
5540 SDValue N1 = N->getOperand(Num: 1);
5541 EVT VT = N->getValueType(ResNo: 0);
5542 SDLoc DL(N);
5543 bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;
5544
5545 // fold (avg c1, c2)
5546 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5547 return C;
5548
5549 // canonicalize constant to RHS.
5550 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5551 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5552 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5553
5554 if (VT.isVector())
5555 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5556 return FoldedVOp;
5557
5558 // fold (avg x, undef) -> x
5559 if (N0.isUndef())
5560 return N1;
5561 if (N1.isUndef())
5562 return N0;
5563
5564 // fold (avg x, x) --> x
5565 if (N0 == N1 && Level >= AfterLegalizeTypes)
5566 return N0;
5567
5568 // fold (avgfloor x, 0) -> x >> 1
5569 SDValue X, Y;
5570 if (sd_match(N, P: m_c_BinOp(Opc: ISD::AVGFLOORS, L: m_Value(N&: X), R: m_Zero())))
5571 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X,
5572 N2: DAG.getShiftAmountConstant(Val: 1, VT, DL));
5573 if (sd_match(N, P: m_c_BinOp(Opc: ISD::AVGFLOORU, L: m_Value(N&: X), R: m_Zero())))
5574 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X,
5575 N2: DAG.getShiftAmountConstant(Val: 1, VT, DL));
5576
5577 // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
5578 // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
5579 if (!IsSigned &&
5580 sd_match(N, P: m_BinOp(Opc: Opcode, L: m_ZExt(Op: m_Value(N&: X)), R: m_ZExt(Op: m_Value(N&: Y)))) &&
5581 X.getValueType() == Y.getValueType() &&
5582 hasOperation(Opcode, VT: X.getValueType())) {
5583 SDValue AvgU = DAG.getNode(Opcode, DL, VT: X.getValueType(), N1: X, N2: Y);
5584 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: AvgU);
5585 }
5586 if (IsSigned &&
5587 sd_match(N, P: m_BinOp(Opc: Opcode, L: m_SExt(Op: m_Value(N&: X)), R: m_SExt(Op: m_Value(N&: Y)))) &&
5588 X.getValueType() == Y.getValueType() &&
5589 hasOperation(Opcode, VT: X.getValueType())) {
5590 SDValue AvgS = DAG.getNode(Opcode, DL, VT: X.getValueType(), N1: X, N2: Y);
5591 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: AvgS);
5592 }
5593
5594 // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5595 // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5596 // Check if avgflooru isn't legal/custom but avgceilu is.
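  // (avgceilu(a, b) computes (a + b + 1) >> 1 without overflow, so decrementing
  //  whichever operand is known non-zero gives (x + y) >> 1 == avgflooru(x, y);
  //  the non-zero check guards the unsigned wrap of the decrement.)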
5597 if (Opcode == ISD::AVGFLOORU && !hasOperation(Opcode: ISD::AVGFLOORU, VT) &&
5598 (!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT))) {
5599 if (DAG.isKnownNeverZero(Op: N1))
5600 return DAG.getNode(
5601 Opcode: ISD::AVGCEILU, DL, VT, N1: N0,
5602 N2: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: DAG.getAllOnesConstant(DL, VT)));
5603 if (DAG.isKnownNeverZero(Op: N0))
5604 return DAG.getNode(
5605 Opcode: ISD::AVGCEILU, DL, VT, N1,
5606 N2: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getAllOnesConstant(DL, VT)));
5607 }
5608
5609 // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
5610 // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
5611 if ((Opcode == ISD::AVGFLOORU && hasOperation(Opcode: ISD::AVGCEILU, VT)) ||
5612 (Opcode == ISD::AVGFLOORS && hasOperation(Opcode: ISD::AVGCEILS, VT))) {
5613 SDValue Add;
5614 if (sd_match(N,
5615 P: m_c_BinOp(Opc: Opcode,
5616 L: m_AllOf(preds: m_Value(N&: Add), preds: m_Add(L: m_Value(N&: X), R: m_Value(N&: Y))),
5617 R: m_One())) ||
5618 sd_match(N, P: m_c_BinOp(Opc: Opcode,
5619 L: m_AllOf(preds: m_Value(N&: Add), preds: m_Add(L: m_Value(N&: X), R: m_One())),
5620 R: m_Value(N&: Y)))) {
5621
5622 if (IsSigned && Add->getFlags().hasNoSignedWrap())
5623 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: X, N2: Y);
5624
5625 if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
5626 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: X, N2: Y);
5627 }
5628 }
5629
5630 // Fold avgfloors(x,y) -> avgflooru(x,y) if both x and y are non-negative
5631 if (Opcode == ISD::AVGFLOORS && hasOperation(Opcode: ISD::AVGFLOORU, VT)) {
5632 if (DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5633 return DAG.getNode(Opcode: ISD::AVGFLOORU, DL, VT, N1: N0, N2: N1);
5634 }
5635
5636 return SDValue();
5637}
5638
5639SDValue DAGCombiner::visitABD(SDNode *N) {
5640 unsigned Opcode = N->getOpcode();
5641 SDValue N0 = N->getOperand(Num: 0);
5642 SDValue N1 = N->getOperand(Num: 1);
5643 EVT VT = N->getValueType(ResNo: 0);
5644 SDLoc DL(N);
5645
5646 // fold (abd c1, c2)
5647 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5648 return C;
5649
5650 // canonicalize constant to RHS.
5651 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5652 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5653 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5654
5655 if (VT.isVector())
5656 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5657 return FoldedVOp;
5658
5659 // fold (abd x, undef) -> 0
5660 if (N0.isUndef() || N1.isUndef())
5661 return DAG.getConstant(Val: 0, DL, VT);
5662
5663 // fold (abd x, x) -> 0
5664 if (N0 == N1)
5665 return DAG.getConstant(Val: 0, DL, VT);
5666
5667 SDValue X;
5668
5669 // fold (abds x, 0) -> abs x
5670 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDS, L: m_Value(N&: X), R: m_Zero())) &&
5671 (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)))
5672 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: X);
5673
5674 // fold (abdu x, 0) -> x
5675 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDU, L: m_Value(N&: X), R: m_Zero())))
5676 return X;
5677
5678   // fold (abds x, y) -> (abdu x, y) iff both args are known non-negative
5679 if (Opcode == ISD::ABDS && hasOperation(Opcode: ISD::ABDU, VT) &&
5680 DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5681 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1, N2: N0);
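  // With both sign bits known zero the operands are non-negative, so the signed
  // and unsigned absolute differences coincide, e.g. |5 - 2| = 3 either way.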
5682
5683 return SDValue();
5684}
5685
5686 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5687 /// give the opcodes for the two computations that are being performed. Return
5688 /// the simplified value if a simplification was made, otherwise an empty SDValue.
5689SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5690 unsigned HiOp) {
5691 // If the high half is not needed, just compute the low half.
5692 bool HiExists = N->hasAnyUseOfValue(Value: 1);
5693 if (!HiExists && (!LegalOperations ||
5694 TLI.isOperationLegalOrCustom(Op: LoOp, VT: N->getValueType(ResNo: 0)))) {
5695 SDValue Res = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5696 return CombineTo(N, Res0: Res, Res1: Res);
5697 }
5698
5699 // If the low half is not needed, just compute the high half.
5700 bool LoExists = N->hasAnyUseOfValue(Value: 0);
5701 if (!LoExists && (!LegalOperations ||
5702 TLI.isOperationLegalOrCustom(Op: HiOp, VT: N->getValueType(ResNo: 1)))) {
5703 SDValue Res = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5704 return CombineTo(N, Res0: Res, Res1: Res);
5705 }
5706
5707 // If both halves are used, return as it is.
5708 if (LoExists && HiExists)
5709 return SDValue();
5710
5711 // If the two computed results can be simplified separately, separate them.
5712 if (LoExists) {
5713 SDValue Lo = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5714 AddToWorklist(N: Lo.getNode());
5715 SDValue LoOpt = combine(N: Lo.getNode());
5716 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5717 (!LegalOperations ||
5718 TLI.isOperationLegalOrCustom(Op: LoOpt.getOpcode(), VT: LoOpt.getValueType())))
5719 return CombineTo(N, Res0: LoOpt, Res1: LoOpt);
5720 }
5721
5722 if (HiExists) {
5723 SDValue Hi = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5724 AddToWorklist(N: Hi.getNode());
5725 SDValue HiOpt = combine(N: Hi.getNode());
5726 if (HiOpt.getNode() && HiOpt != Hi &&
5727 (!LegalOperations ||
5728 TLI.isOperationLegalOrCustom(Op: HiOpt.getOpcode(), VT: HiOpt.getValueType())))
5729 return CombineTo(N, Res0: HiOpt, Res1: HiOpt);
5730 }
5731
5732 return SDValue();
5733}
5734
5735SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5736 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHS))
5737 return Res;
5738
5739 SDValue N0 = N->getOperand(Num: 0);
5740 SDValue N1 = N->getOperand(Num: 1);
5741 EVT VT = N->getValueType(ResNo: 0);
5742 SDLoc DL(N);
5743
5744 // Constant fold.
5745 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5746 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5747
5748 // canonicalize constant to RHS (vector doesn't have to splat)
5749 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5750 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5751 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5752
5753   // If the type twice as wide is legal, transform the multiply into a wider
5754   // multiply plus a shift.
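  // For example (illustrative notation), for VT == i32 both halves come from a
  // single i64 multiply:
  //   Wide = sext_i64(N0) * sext_i64(N1)
  //   Hi   = trunc_i32(Wide >> 32),  Lo = trunc_i32(Wide)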
5755 if (VT.isSimple() && !VT.isVector()) {
5756 MVT Simple = VT.getSimpleVT();
5757 unsigned SimpleSize = Simple.getSizeInBits();
5758 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5759 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5760 SDValue Lo = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5761 SDValue Hi = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5762 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5763       // Compute the high part by shifting the wide product down.
5764 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5765 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5766 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5767       // Compute the low part by truncating the wide product.
5768 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5769 return CombineTo(N, Res0: Lo, Res1: Hi);
5770 }
5771 }
5772
5773 return SDValue();
5774}
5775
5776SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5777 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHU))
5778 return Res;
5779
5780 SDValue N0 = N->getOperand(Num: 0);
5781 SDValue N1 = N->getOperand(Num: 1);
5782 EVT VT = N->getValueType(ResNo: 0);
5783 SDLoc DL(N);
5784
5785 // Constant fold.
5786 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5787 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5788
5789 // canonicalize constant to RHS (vector doesn't have to splat)
5790 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5791 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5792 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5793
5794 // (umul_lohi N0, 0) -> (0, 0)
5795 if (isNullConstant(V: N1)) {
5796 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5797 return CombineTo(N, Res0: Zero, Res1: Zero);
5798 }
5799
5800 // (umul_lohi N0, 1) -> (N0, 0)
5801 if (isOneConstant(V: N1)) {
5802 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5803 return CombineTo(N, Res0: N0, Res1: Zero);
5804 }
5805
5806   // If the type twice as wide is legal, transform the multiply into a wider
5807   // multiply plus a shift.
5808 if (VT.isSimple() && !VT.isVector()) {
5809 MVT Simple = VT.getSimpleVT();
5810 unsigned SimpleSize = Simple.getSizeInBits();
5811 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5812 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5813 SDValue Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5814 SDValue Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5815 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5816       // Compute the high part by shifting the wide product down.
5817 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5818 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5819 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5820       // Compute the low part by truncating the wide product.
5821 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5822 return CombineTo(N, Res0: Lo, Res1: Hi);
5823 }
5824 }
5825
5826 return SDValue();
5827}
5828
5829SDValue DAGCombiner::visitMULO(SDNode *N) {
5830 SDValue N0 = N->getOperand(Num: 0);
5831 SDValue N1 = N->getOperand(Num: 1);
5832 EVT VT = N0.getValueType();
5833 bool IsSigned = (ISD::SMULO == N->getOpcode());
5834
5835 EVT CarryVT = N->getValueType(ResNo: 1);
5836 SDLoc DL(N);
5837
5838 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
5839 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5840
5841 // fold operation with constant operands.
5842 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5843 // multiple results.
5844 if (N0C && N1C) {
5845 bool Overflow;
5846 APInt Result =
5847 IsSigned ? N0C->getAPIntValue().smul_ov(RHS: N1C->getAPIntValue(), Overflow)
5848 : N0C->getAPIntValue().umul_ov(RHS: N1C->getAPIntValue(), Overflow);
5849 return CombineTo(N, Res0: DAG.getConstant(Val: Result, DL, VT),
5850 Res1: DAG.getBoolConstant(V: Overflow, DL, VT: CarryVT, OpVT: CarryVT));
5851 }
5852
5853 // canonicalize constant to RHS.
5854 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5855 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5856 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
5857
5858 // fold (mulo x, 0) -> 0 + no carry out
5859 if (isNullOrNullSplat(V: N1))
5860 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
5861 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5862
5863 // (mulo x, 2) -> (addo x, x)
5864 // FIXME: This needs a freeze.
5865 if (N1C && N1C->getAPIntValue() == 2 &&
5866 (!IsSigned || VT.getScalarSizeInBits() > 2))
5867 return DAG.getNode(Opcode: IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5868 VTList: N->getVTList(), N1: N0, N2: N0);
5869
5870 // A 1 bit SMULO overflows if both inputs are 1.
5871 if (IsSigned && VT.getScalarSizeInBits() == 1) {
5872 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: N1);
5873 SDValue Cmp = DAG.getSetCC(DL, VT: CarryVT, LHS: And,
5874 RHS: DAG.getConstant(Val: 0, DL, VT), Cond: ISD::SETNE);
5875 return CombineTo(N, Res0: And, Res1: Cmp);
5876 }
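  // Reasoning: a 1-bit signed type only holds 0 and -1, and the only product
  // that overflows is -1 * -1 = +1, so overflow happens exactly when both
  // inputs are -1, i.e. when (N0 & N1) is non-zero.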
5877
5878 // If it cannot overflow, transform into a mul.
5879 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5880 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0, N2: N1),
5881 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5882 return SDValue();
5883}
5884
5885 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5886 // swapped around) makes a signed saturate pattern, clamping to between a signed
5887 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW-1.
5888 // Returns the node being clamped and the bitwidth of the clamp in BW. Should
5889 // work with both SMIN/SMAX nodes and setcc/select combos. The operands are the
5890 // same as SimplifySelectCC: N0 < N1 ? N2 : N3.
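// For example, smin(smax(X, -128), 127) clamps X to the signed i8 range and
// returns X with BW = 8, Unsigned = false; smax(smin(X, 255), 0) clamps X to
// the unsigned i8 range and returns X with BW = 8, Unsigned = true.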
5891static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5892 SDValue N3, ISD::CondCode CC, unsigned &BW,
5893 bool &Unsigned, SelectionDAG &DAG) {
5894 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5895 ISD::CondCode CC) {
5896 // The compare and select operand should be the same or the select operands
5897 // should be truncated versions of the comparison.
5898 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0)))
5899 return 0;
5900 // The constants need to be the same or a truncated version of each other.
5901 ConstantSDNode *N1C = isConstOrConstSplat(N: peekThroughTruncates(V: N1));
5902 ConstantSDNode *N3C = isConstOrConstSplat(N: peekThroughTruncates(V: N3));
5903 if (!N1C || !N3C)
5904 return 0;
5905 const APInt &C1 = N1C->getAPIntValue().trunc(width: N1.getScalarValueSizeInBits());
5906 const APInt &C2 = N3C->getAPIntValue().trunc(width: N3.getScalarValueSizeInBits());
5907 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(width: C1.getBitWidth()))
5908 return 0;
5909 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5910 };
5911
5912 // Check the initial value is a SMIN/SMAX equivalent.
5913 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5914 if (!Opcode0)
5915 return SDValue();
5916
5917   // We may need only one range check if the fptosi can never produce
5918   // the upper value.
5919 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5920 if (isNullOrNullSplat(V: N3)) {
5921 EVT IntVT = N0.getValueType().getScalarType();
5922 EVT FPVT = N0.getOperand(i: 0).getValueType().getScalarType();
5923 if (FPVT.isSimple()) {
5924 Type *InputTy = FPVT.getTypeForEVT(Context&: *DAG.getContext());
5925 const fltSemantics &Semantics = InputTy->getFltSemantics();
5926 uint32_t MinBitWidth =
5927 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5928 if (IntVT.getSizeInBits() >= MinBitWidth) {
5929 Unsigned = true;
5930 BW = PowerOf2Ceil(A: MinBitWidth);
5931 return N0;
5932 }
5933 }
5934 }
5935 }
5936
5937 SDValue N00, N01, N02, N03;
5938 ISD::CondCode N0CC;
5939 switch (N0.getOpcode()) {
5940 case ISD::SMIN:
5941 case ISD::SMAX:
5942 N00 = N02 = N0.getOperand(i: 0);
5943 N01 = N03 = N0.getOperand(i: 1);
5944 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5945 break;
5946 case ISD::SELECT_CC:
5947 N00 = N0.getOperand(i: 0);
5948 N01 = N0.getOperand(i: 1);
5949 N02 = N0.getOperand(i: 2);
5950 N03 = N0.getOperand(i: 3);
5951 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 4))->get();
5952 break;
5953 case ISD::SELECT:
5954 case ISD::VSELECT:
5955 if (N0.getOperand(i: 0).getOpcode() != ISD::SETCC)
5956 return SDValue();
5957 N00 = N0.getOperand(i: 0).getOperand(i: 0);
5958 N01 = N0.getOperand(i: 0).getOperand(i: 1);
5959 N02 = N0.getOperand(i: 1);
5960 N03 = N0.getOperand(i: 2);
5961 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 0).getOperand(i: 2))->get();
5962 break;
5963 default:
5964 return SDValue();
5965 }
5966
5967 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5968 if (!Opcode1 || Opcode0 == Opcode1)
5969 return SDValue();
5970
5971 ConstantSDNode *MinCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N1 : N01);
5972 ConstantSDNode *MaxCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N01 : N1);
5973 if (!MinCOp || !MaxCOp || MinCOp->getValueType(ResNo: 0) != MaxCOp->getValueType(ResNo: 0))
5974 return SDValue();
5975
5976 const APInt &MinC = MinCOp->getAPIntValue();
5977 const APInt &MaxC = MaxCOp->getAPIntValue();
5978 APInt MinCPlus1 = MinC + 1;
5979 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5980 BW = MinCPlus1.exactLogBase2() + 1;
5981 Unsigned = false;
5982 return N02;
5983 }
5984
5985 if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5986 BW = MinCPlus1.exactLogBase2();
5987 Unsigned = true;
5988 return N02;
5989 }
5990
5991 return SDValue();
5992}
5993
5994static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5995 SDValue N3, ISD::CondCode CC,
5996 SelectionDAG &DAG) {
5997 unsigned BW;
5998 bool Unsigned;
5999 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
6000 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
6001 return SDValue();
6002 EVT FPVT = Fp.getOperand(i: 0).getValueType();
6003 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
6004 if (FPVT.isVector())
6005 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
6006 EC: FPVT.getVectorElementCount());
6007 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
6008 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: NewOpc, FPVT, VT: NewVT))
6009 return SDValue();
6010 SDLoc DL(Fp);
6011 SDValue Sat = DAG.getNode(Opcode: NewOpc, DL, VT: NewVT, N1: Fp.getOperand(i: 0),
6012 N2: DAG.getValueType(NewVT.getScalarType()));
6013 return DAG.getExtOrTrunc(IsSigned: !Unsigned, Op: Sat, DL, VT: N2->getValueType(ResNo: 0));
6014}
6015
6016static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
6017 SDValue N3, ISD::CondCode CC,
6018 SelectionDAG &DAG) {
6019 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
6020   // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
6021 // be truncated versions of the setcc (N0/N1).
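  // For example, umin(fptoui(X) to i32, 255) becomes fp_to_uint_sat(X) to i8,
  // zero-extended back to i32, when the target prefers the saturating form.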
6022 if ((N0 != N2 &&
6023 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0))) ||
6024 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
6025 return SDValue();
6026 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
6027 ConstantSDNode *N3C = isConstOrConstSplat(N: N3);
6028 if (!N1C || !N3C)
6029 return SDValue();
6030 const APInt &C1 = N1C->getAPIntValue();
6031 const APInt &C3 = N3C->getAPIntValue();
6032 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
6033 C1 != C3.zext(width: C1.getBitWidth()))
6034 return SDValue();
6035
6036 unsigned BW = (C1 + 1).exactLogBase2();
6037 EVT FPVT = N0.getOperand(i: 0).getValueType();
6038 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
6039 if (FPVT.isVector())
6040 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
6041 EC: FPVT.getVectorElementCount());
6042 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: ISD::FP_TO_UINT_SAT,
6043 FPVT, VT: NewVT))
6044 return SDValue();
6045
6046 SDValue Sat =
6047 DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT: NewVT, N1: N0.getOperand(i: 0),
6048 N2: DAG.getValueType(NewVT.getScalarType()));
6049 return DAG.getZExtOrTrunc(Op: Sat, DL: SDLoc(N0), VT: N3.getValueType());
6050}
6051
6052SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
6053 SDValue N0 = N->getOperand(Num: 0);
6054 SDValue N1 = N->getOperand(Num: 1);
6055 EVT VT = N0.getValueType();
6056 unsigned Opcode = N->getOpcode();
6057 SDLoc DL(N);
6058
6059 // fold operation with constant operands.
6060 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
6061 return C;
6062
6063 // If the operands are the same, this is a no-op.
6064 if (N0 == N1)
6065 return N0;
6066
6067 // canonicalize constant to RHS
6068 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
6069 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
6070 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
6071
6072 // fold vector ops
6073 if (VT.isVector())
6074 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6075 return FoldedVOp;
6076
6077 // reassociate minmax
6078 if (SDValue RMINMAX = reassociateOps(Opc: Opcode, DL, N0, N1, Flags: N->getFlags()))
6079 return RMINMAX;
6080
6081   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
6082   // Only do this if:
6083   // 1. The current op isn't legal but the flipped one is.
6084   // 2. The saturation pattern is broken by canonicalization in InstCombine.
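  // For example, when both sign bits are known zero, umin(X, Y) and smin(X, Y)
  // pick the same operand, so an illegal UMIN can be rewritten as a legal SMIN.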
6085 bool IsOpIllegal = !TLI.isOperationLegal(Op: Opcode, VT);
6086 bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
6087 if ((IsSatBroken || IsOpIllegal) && (N0.isUndef() || DAG.SignBitIsZero(Op: N0)) &&
6088 (N1.isUndef() || DAG.SignBitIsZero(Op: N1))) {
6089 unsigned AltOpcode;
6090 switch (Opcode) {
6091 case ISD::SMIN: AltOpcode = ISD::UMIN; break;
6092 case ISD::SMAX: AltOpcode = ISD::UMAX; break;
6093 case ISD::UMIN: AltOpcode = ISD::SMIN; break;
6094 case ISD::UMAX: AltOpcode = ISD::SMAX; break;
6095 default: llvm_unreachable("Unknown MINMAX opcode");
6096 }
6097 if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(Op: AltOpcode, VT))
6098 return DAG.getNode(Opcode: AltOpcode, DL, VT, N1: N0, N2: N1);
6099 }
6100
6101 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
6102 if (SDValue S = PerformMinMaxFpToSatCombine(
6103 N0, N1, N2: N0, N3: N1, CC: Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
6104 return S;
6105 if (Opcode == ISD::UMIN)
6106 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2: N0, N3: N1, CC: ISD::SETULT, DAG))
6107 return S;
6108
6109 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
6110 auto ReductionOpcode = [](unsigned Opcode) {
6111 switch (Opcode) {
6112 case ISD::SMIN:
6113 return ISD::VECREDUCE_SMIN;
6114 case ISD::SMAX:
6115 return ISD::VECREDUCE_SMAX;
6116 case ISD::UMIN:
6117 return ISD::VECREDUCE_UMIN;
6118 case ISD::UMAX:
6119 return ISD::VECREDUCE_UMAX;
6120 default:
6121 llvm_unreachable("Unexpected opcode");
6122 }
6123 };
6124 if (SDValue SD = reassociateReduction(RedOpc: ReductionOpcode(Opcode), Opc: Opcode,
6125 DL: SDLoc(N), VT, N0, N1))
6126 return SD;
6127
6128 // Simplify the operands using demanded-bits information.
6129 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
6130 return SDValue(N, 0);
6131
6132 return SDValue();
6133}
6134
6135/// If this is a bitwise logic instruction and both operands have the same
6136/// opcode, try to sink the other opcode after the logic instruction.
6137SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
6138 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
6139 EVT VT = N0.getValueType();
6140 unsigned LogicOpcode = N->getOpcode();
6141 unsigned HandOpcode = N0.getOpcode();
6142 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
6143 assert(HandOpcode == N1.getOpcode() && "Bad input!");
6144
6145 // Bail early if none of these transforms apply.
6146 if (N0.getNumOperands() == 0)
6147 return SDValue();
6148
6149 // FIXME: We should check number of uses of the operands to not increase
6150 // the instruction count for all transforms.
6151
6152 // Handle size-changing casts (or sign_extend_inreg).
6153 SDValue X = N0.getOperand(i: 0);
6154 SDValue Y = N1.getOperand(i: 0);
6155 EVT XVT = X.getValueType();
6156 SDLoc DL(N);
6157 if (ISD::isExtOpcode(Opcode: HandOpcode) || ISD::isExtVecInRegOpcode(Opcode: HandOpcode) ||
6158 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
6159 N0.getOperand(i: 1) == N1.getOperand(i: 1))) {
6160 // If both operands have other uses, this transform would create extra
6161 // instructions without eliminating anything.
6162 if (!N0.hasOneUse() && !N1.hasOneUse())
6163 return SDValue();
6164 // We need matching integer source types.
6165 if (XVT != Y.getValueType())
6166 return SDValue();
6167 // Don't create an illegal op during or after legalization. Don't ever
6168 // create an unsupported vector op.
6169 if ((VT.isVector() || LegalOperations) &&
6170 !TLI.isOperationLegalOrCustom(Op: LogicOpcode, VT: XVT))
6171 return SDValue();
6172 // Avoid infinite looping with PromoteIntBinOp.
6173 // TODO: Should we apply desirable/legal constraints to all opcodes?
6174 if ((HandOpcode == ISD::ANY_EXTEND ||
6175 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
6176 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, VT: XVT))
6177 return SDValue();
6178 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
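    // For example, (and (zext i8 X to i32), (zext i8 Y to i32)) becomes
    // (zext (and X, Y) to i32), performing the logic in the narrower type.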
6179 SDNodeFlags LogicFlags;
6180 LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
6181 ISD::isExtOpcode(Opcode: HandOpcode));
6182 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y, Flags: LogicFlags);
6183 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
6184 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
6185 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6186 }
6187
6188 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
6189 if (HandOpcode == ISD::TRUNCATE) {
6190 // If both operands have other uses, this transform would create extra
6191 // instructions without eliminating anything.
6192 if (!N0.hasOneUse() && !N1.hasOneUse())
6193 return SDValue();
6194 // We need matching source types.
6195 if (XVT != Y.getValueType())
6196 return SDValue();
6197 // Don't create an illegal op during or after legalization.
6198 if (LegalOperations && !TLI.isOperationLegal(Op: LogicOpcode, VT: XVT))
6199 return SDValue();
6200 // Be extra careful sinking truncate. If it's free, there's no benefit in
6201 // widening a binop. Also, don't create a logic op on an illegal type.
6202 if (TLI.isZExtFree(FromTy: VT, ToTy: XVT) && TLI.isTruncateFree(FromVT: XVT, ToVT: VT))
6203 return SDValue();
6204 if (!TLI.isTypeLegal(VT: XVT))
6205 return SDValue();
6206 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6207 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6208 }
6209
6210 // For binops SHL/SRL/SRA/AND:
6211 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
6212 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
6213 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
6214 N0.getOperand(i: 1) == N1.getOperand(i: 1)) {
6215 // If either operand has other uses, this transform is not an improvement.
6216 if (!N0.hasOneUse() || !N1.hasOneUse())
6217 return SDValue();
6218 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6219 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
6220 }
6221
6222 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
6223 if (HandOpcode == ISD::BSWAP) {
6224 // If either operand has other uses, this transform is not an improvement.
6225 if (!N0.hasOneUse() || !N1.hasOneUse())
6226 return SDValue();
6227 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6228 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6229 }
6230
6231 // For funnel shifts FSHL/FSHR:
6232   // logic_op (OP x, x1, s), (OP y, y1, s) -->
6233   //   OP (logic_op x, y), (logic_op x1, y1), s
6234 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
6235 N0.getOperand(i: 2) == N1.getOperand(i: 2)) {
6236 if (!N0.hasOneUse() || !N1.hasOneUse())
6237 return SDValue();
6238 SDValue X1 = N0.getOperand(i: 1);
6239 SDValue Y1 = N1.getOperand(i: 1);
6240 SDValue S = N0.getOperand(i: 2);
6241 SDValue Logic0 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X, N2: Y);
6242 SDValue Logic1 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X1, N2: Y1);
6243 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic0, N2: Logic1, N3: S);
6244 }
6245
6246 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
6247 // Only perform this optimization up until type legalization, before
6248   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
6249 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
6250 // we don't want to undo this promotion.
6251 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
6252 // on scalars.
6253 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
6254 Level <= AfterLegalizeTypes) {
6255 // Input types must be integer and the same.
6256 if (XVT.isInteger() && XVT == Y.getValueType() &&
6257 !(VT.isVector() && TLI.isTypeLegal(VT) &&
6258 !XVT.isVector() && !TLI.isTypeLegal(VT: XVT))) {
6259 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6260 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6261 }
6262 }
6263
6264 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
6265 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
6266 // If both shuffles use the same mask, and both shuffle within a single
6267 // vector, then it is worthwhile to move the swizzle after the operation.
6268 // The type-legalizer generates this pattern when loading illegal
6269 // vector types from memory. In many cases this allows additional shuffle
6270 // optimizations.
6271 // There are other cases where moving the shuffle after the xor/and/or
6272 // is profitable even if shuffles don't perform a swizzle.
6273 // If both shuffles use the same mask, and both shuffles have the same first
6274 // or second operand, then it might still be profitable to move the shuffle
6275 // after the xor/and/or operation.
6276 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
6277 auto *SVN0 = cast<ShuffleVectorSDNode>(Val&: N0);
6278 auto *SVN1 = cast<ShuffleVectorSDNode>(Val&: N1);
6279 assert(X.getValueType() == Y.getValueType() &&
6280 "Inputs to shuffles are not the same type");
6281
6282 // Check that both shuffles use the same mask. The masks are known to be of
6283 // the same length because the result vector type is the same.
6284 // Check also that shuffles have only one use to avoid introducing extra
6285 // instructions.
6286 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
6287 !SVN0->getMask().equals(RHS: SVN1->getMask()))
6288 return SDValue();
6289
6290 // Don't try to fold this node if it requires introducing a
6291 // build vector of all zeros that might be illegal at this stage.
6292 SDValue ShOp = N0.getOperand(i: 1);
6293 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6294 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6295
6296 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
6297 if (N0.getOperand(i: 1) == N1.getOperand(i: 1) && ShOp.getNode()) {
6298 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT,
6299 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
6300 return DAG.getVectorShuffle(VT, dl: DL, N1: Logic, N2: ShOp, Mask: SVN0->getMask());
6301 }
6302
6303 // Don't try to fold this node if it requires introducing a
6304 // build vector of all zeros that might be illegal at this stage.
6305 ShOp = N0.getOperand(i: 0);
6306 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6307 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6308
6309 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
6310 if (N0.getOperand(i: 0) == N1.getOperand(i: 0) && ShOp.getNode()) {
6311 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: N0.getOperand(i: 1),
6312 N2: N1.getOperand(i: 1));
6313 return DAG.getVectorShuffle(VT, dl: DL, N1: ShOp, N2: Logic, Mask: SVN0->getMask());
6314 }
6315 }
6316
6317 return SDValue();
6318}
6319
6320/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
6321SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
6322 const SDLoc &DL) {
6323 SDValue LL, LR, RL, RR, N0CC, N1CC;
6324 if (!isSetCCEquivalent(N: N0, LHS&: LL, RHS&: LR, CC&: N0CC) ||
6325 !isSetCCEquivalent(N: N1, LHS&: RL, RHS&: RR, CC&: N1CC))
6326 return SDValue();
6327
6328 assert(N0.getValueType() == N1.getValueType() &&
6329 "Unexpected operand types for bitwise logic op");
6330 assert(LL.getValueType() == LR.getValueType() &&
6331 RL.getValueType() == RR.getValueType() &&
6332 "Unexpected operand types for setcc");
6333
6334 // If we're here post-legalization or the logic op type is not i1, the logic
6335 // op type must match a setcc result type. Also, all folds require new
6336 // operations on the left and right operands, so those types must match.
6337 EVT VT = N0.getValueType();
6338 EVT OpVT = LL.getValueType();
6339 if (LegalOperations || VT.getScalarType() != MVT::i1)
6340 if (VT != getSetCCResultType(VT: OpVT))
6341 return SDValue();
6342 if (OpVT != RL.getValueType())
6343 return SDValue();
6344
6345 ISD::CondCode CC0 = cast<CondCodeSDNode>(Val&: N0CC)->get();
6346 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val&: N1CC)->get();
6347 bool IsInteger = OpVT.isInteger();
6348 if (LR == RR && CC0 == CC1 && IsInteger) {
6349 bool IsZero = isNullOrNullSplat(V: LR);
6350 bool IsNeg1 = isAllOnesOrAllOnesSplat(V: LR);
6351
6352 // All bits clear?
6353 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
6354 // All sign bits clear?
6355 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
6356 // Any bits set?
6357 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
6358 // Any sign bits set?
6359 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
6360
6361 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
6362 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
6363 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
6364 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
6365 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
6366 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
6367 AddToWorklist(N: Or.getNode());
6368 return DAG.getSetCC(DL, VT, LHS: Or, RHS: LR, Cond: CC1);
6369 }
6370
6371 // All bits set?
6372 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
6373 // All sign bits set?
6374 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
6375 // Any bits clear?
6376 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
6377 // Any sign bits clear?
6378 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
6379
6380 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6381 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
6382 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6383 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
6384 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6385 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
6386 AddToWorklist(N: And.getNode());
6387 return DAG.getSetCC(DL, VT, LHS: And, RHS: LR, Cond: CC1);
6388 }
6389 }
6390
6391 // TODO: What is the 'or' equivalent of this fold?
6392 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
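  // This works because X + 1 maps 0 to 1 and -1 to 0, so the unsigned result
  // is >= 2 exactly when X was neither 0 nor -1.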
6393 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6394 IsInteger && CC0 == ISD::SETNE &&
6395 ((isNullConstant(V: LR) && isAllOnesConstant(V: RR)) ||
6396 (isAllOnesConstant(V: LR) && isNullConstant(V: RR)))) {
6397 SDValue One = DAG.getConstant(Val: 1, DL, VT: OpVT);
6398 SDValue Two = DAG.getConstant(Val: 2, DL, VT: OpVT);
6399 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: One);
6400 AddToWorklist(N: Add.getNode());
6401 return DAG.getSetCC(DL, VT, LHS: Add, RHS: Two, Cond: ISD::SETUGE);
6402 }
6403
6404 // Try more general transforms if the predicates match and the only user of
6405 // the compares is the 'and' or 'or'.
6406 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(VT: OpVT) && CC0 == CC1 &&
6407 N0.hasOneUse() && N1.hasOneUse()) {
6408 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6409 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6410 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6411 SDValue XorL = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: LR);
6412 SDValue XorR = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N1), VT: OpVT, N1: RL, N2: RR);
6413 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: OpVT, N1: XorL, N2: XorR);
6414 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6415 return DAG.getSetCC(DL, VT, LHS: Or, RHS: Zero, Cond: CC1);
6416 }
6417
6418 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6419 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6420 // Match a shared variable operand and 2 non-opaque constant operands.
6421 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6422 // The difference of the constants must be a single bit.
6423 const APInt &CMax =
6424 APIntOps::umax(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6425 const APInt &CMin =
6426 APIntOps::umin(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6427 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6428 };
6429 if (LL == RL && ISD::matchBinaryPredicate(LHS: LR, RHS: RR, Match: MatchDiffPow2)) {
6430         // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6431         // setcc (and (sub X, CMin), ~(CMax - CMin)), 0, ne/eq
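        // For example, X == 4 | X == 6 (CMin = 4, CMax = 6) becomes
        // ((X - 4) & ~2) == 0, which holds exactly for X in {4, 6}.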
6432 SDValue Max = DAG.getNode(Opcode: ISD::UMAX, DL, VT: OpVT, N1: LR, N2: RR);
6433 SDValue Min = DAG.getNode(Opcode: ISD::UMIN, DL, VT: OpVT, N1: LR, N2: RR);
6434 SDValue Offset = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: LL, N2: Min);
6435 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: Max, N2: Min);
6436 SDValue Mask = DAG.getNOT(DL, Val: Diff, VT: OpVT);
6437 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: Offset, N2: Mask);
6438 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6439 return DAG.getSetCC(DL, VT, LHS: And, RHS: Zero, Cond: CC0);
6440 }
6441 }
6442 }
6443
6444 // Canonicalize equivalent operands to LL == RL.
6445 if (LL == RR && LR == RL) {
6446 CC1 = ISD::getSetCCSwappedOperands(Operation: CC1);
6447 std::swap(a&: RL, b&: RR);
6448 }
6449
6450 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6451 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6452 if (LL == RL && LR == RR) {
6453 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(Op1: CC0, Op2: CC1, Type: OpVT)
6454 : ISD::getSetCCOrOperation(Op1: CC0, Op2: CC1, Type: OpVT);
6455 if (NewCC != ISD::SETCC_INVALID &&
6456 (!LegalOperations ||
6457 (TLI.isCondCodeLegal(CC: NewCC, VT: LL.getSimpleValueType()) &&
6458 TLI.isOperationLegal(Op: ISD::SETCC, VT: OpVT))))
6459 return DAG.getSetCC(DL, VT, LHS: LL, RHS: LR, Cond: NewCC);
6460 }
6461
6462 return SDValue();
6463}
6464
6465static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6466 SelectionDAG &DAG) {
6467 return DAG.isKnownNeverSNaN(Op: Operand2) && DAG.isKnownNeverSNaN(Op: Operand1);
6468}
6469
6470static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6471 SelectionDAG &DAG) {
6472 return DAG.isKnownNeverNaN(Op: Operand2) && DAG.isKnownNeverNaN(Op: Operand1);
6473}
6474
6475// FIXME: use FMINIMUMNUM if possible, such as for RISC-V.
6476static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
6477 ISD::CondCode CC, unsigned OrAndOpcode,
6478 SelectionDAG &DAG,
6479 bool isFMAXNUMFMINNUM_IEEE,
6480 bool isFMAXNUMFMINNUM) {
6481 // The optimization cannot be applied for all the predicates because
6482 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6483 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6484 // applied at all if one of the operands is a signaling NaN.
6485
6486 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6487   // are non-NaN values.
6488 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6489 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
6490 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6491 isFMAXNUMFMINNUM_IEEE
6492 ? ISD::FMINNUM_IEEE
6493 : ISD::DELETED_NODE;
6494 else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
6495 (OrAndOpcode == ISD::OR)) ||
6496 ((CC == ISD::SETLT || CC == ISD::SETLE) &&
6497 (OrAndOpcode == ISD::AND)))
6498 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6499 isFMAXNUMFMINNUM_IEEE
6500 ? ISD::FMAXNUM_IEEE
6501 : ISD::DELETED_NODE;
6502 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6503 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6504 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6505 // that there are not any sNaNs, then the optimization is not valid
6506 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6507 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6508 // we can prove that we do not have any sNaNs, then we can do the
6509 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6510 // cases.
6511 else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6512 (OrAndOpcode == ISD::OR)) ||
6513 ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6514 (OrAndOpcode == ISD::AND)))
6515 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6516 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6517 isFMAXNUMFMINNUM_IEEE
6518 ? ISD::FMINNUM_IEEE
6519 : ISD::DELETED_NODE;
6520 else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6521 (OrAndOpcode == ISD::OR)) ||
6522 ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6523 (OrAndOpcode == ISD::AND)))
6524 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6525 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6526 isFMAXNUMFMINNUM_IEEE
6527 ? ISD::FMAXNUM_IEEE
6528 : ISD::DELETED_NODE;
6529 return ISD::DELETED_NODE;
6530}
6531
6532static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6533 using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6534 assert(
6535 (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6536 "Invalid Op to combine SETCC with");
6537
6538 // TODO: Search past casts/truncates.
6539 SDValue LHS = LogicOp->getOperand(Num: 0);
6540 SDValue RHS = LogicOp->getOperand(Num: 1);
6541 if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6542 !LHS->hasOneUse() || !RHS->hasOneUse())
6543 return SDValue();
6544
6545 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6546 AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6547 LogicOp, SETCC0: LHS.getNode(), SETCC1: RHS.getNode());
6548
6549 SDValue LHS0 = LHS->getOperand(Num: 0);
6550 SDValue RHS0 = RHS->getOperand(Num: 0);
6551 SDValue LHS1 = LHS->getOperand(Num: 1);
6552 SDValue RHS1 = RHS->getOperand(Num: 1);
6553   // TODO: We don't actually need a splat here; for vectors we just need the
6554 // invariants to hold for each element.
6555 auto *LHS1C = isConstOrConstSplat(N: LHS1);
6556 auto *RHS1C = isConstOrConstSplat(N: RHS1);
6557 ISD::CondCode CCL = cast<CondCodeSDNode>(Val: LHS.getOperand(i: 2))->get();
6558 ISD::CondCode CCR = cast<CondCodeSDNode>(Val: RHS.getOperand(i: 2))->get();
6559 EVT VT = LogicOp->getValueType(ResNo: 0);
6560 EVT OpVT = LHS0.getValueType();
6561 SDLoc DL(LogicOp);
6562
6563 // Check if the operands of an and/or operation are comparisons and if they
6564   // compare against the same value. Replace the and/or-cmp-cmp sequence with a
6565   // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6566   // sequence will be replaced with a min-cmp sequence:
6567   // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6568   // and the and-cmp-cmp sequence will be replaced with a max-cmp sequence:
6569   // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6570   // The optimization does not work for `==` or `!=`.
6571 // The two comparisons should have either the same predicate or the
6572 // predicate of one of the comparisons is the opposite of the other one.
6573 bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(Op: ISD::FMAXNUM_IEEE, VT: OpVT) &&
6574 TLI.isOperationLegal(Op: ISD::FMINNUM_IEEE, VT: OpVT);
6575 bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(Op: ISD::FMAXNUM, VT: OpVT) &&
6576 TLI.isOperationLegalOrCustom(Op: ISD::FMINNUM, VT: OpVT);
6577 if (((OpVT.isInteger() && TLI.isOperationLegal(Op: ISD::UMAX, VT: OpVT) &&
6578 TLI.isOperationLegal(Op: ISD::SMAX, VT: OpVT) &&
6579 TLI.isOperationLegal(Op: ISD::UMIN, VT: OpVT) &&
6580 TLI.isOperationLegal(Op: ISD::SMIN, VT: OpVT)) ||
6581 (OpVT.isFloatingPoint() &&
6582 (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6583 !ISD::isIntEqualitySetCC(Code: CCL) && !ISD::isFPEqualitySetCC(Code: CCL) &&
6584 CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6585 CCL != ISD::SETTRUE &&
6586 (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(Operation: CCR))) {
6587
6588 SDValue CommonValue, Operand1, Operand2;
6589 ISD::CondCode CC = ISD::SETCC_INVALID;
6590 if (CCL == CCR) {
6591 if (LHS0 == RHS0) {
6592 CommonValue = LHS0;
6593 Operand1 = LHS1;
6594 Operand2 = RHS1;
6595 CC = ISD::getSetCCSwappedOperands(Operation: CCL);
6596 } else if (LHS1 == RHS1) {
6597 CommonValue = LHS1;
6598 Operand1 = LHS0;
6599 Operand2 = RHS0;
6600 CC = CCL;
6601 }
6602 } else {
6603 assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6604 if (LHS0 == RHS1) {
6605 CommonValue = LHS0;
6606 Operand1 = LHS1;
6607 Operand2 = RHS0;
6608 CC = CCR;
6609 } else if (RHS0 == LHS1) {
6610 CommonValue = LHS1;
6611 Operand1 = LHS0;
6612 Operand2 = RHS1;
6613 CC = CCL;
6614 }
6615 }
6616
6617 // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6618 // handle it using OR/AND.
6619 if (CC == ISD::SETLT && isNullOrNullSplat(V: CommonValue))
6620 CC = ISD::SETCC_INVALID;
6621 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CommonValue))
6622 CC = ISD::SETCC_INVALID;
6623
6624 if (CC != ISD::SETCC_INVALID) {
6625 unsigned NewOpcode = ISD::DELETED_NODE;
6626 bool IsSigned = isSignedIntSetCC(Code: CC);
6627 if (OpVT.isInteger()) {
6628 bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6629 CC == ISD::SETLT || CC == ISD::SETULT);
6630 bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6631 if (IsLess == IsOr)
6632 NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6633 else
6634 NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6635 } else if (OpVT.isFloatingPoint())
6636 NewOpcode =
6637 getMinMaxOpcodeForFP(Operand1, Operand2, CC, OrAndOpcode: LogicOp->getOpcode(),
6638 DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6639
6640 if (NewOpcode != ISD::DELETED_NODE) {
6641 SDValue MinMaxValue =
6642 DAG.getNode(Opcode: NewOpcode, DL, VT: OpVT, N1: Operand1, N2: Operand2);
6643 return DAG.getSetCC(DL, VT, LHS: MinMaxValue, RHS: CommonValue, Cond: CC);
6644 }
6645 }
6646 }
6647
6648 if (LHS0 == LHS1 && RHS0 == RHS1 && CCL == CCR &&
6649 LHS0.getValueType() == RHS0.getValueType() &&
6650 ((LogicOp->getOpcode() == ISD::AND && CCL == ISD::SETO) ||
6651 (LogicOp->getOpcode() == ISD::OR && CCL == ISD::SETUO)))
6652 return DAG.getSetCC(DL, VT, LHS: LHS0, RHS: RHS0, Cond: CCL);
6653
6654 if (TargetPreference == AndOrSETCCFoldKind::None)
6655 return SDValue();
6656
6657 if (CCL == CCR &&
6658 CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6659 LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6660 const APInt &APLhs = LHS1C->getAPIntValue();
6661 const APInt &APRhs = RHS1C->getAPIntValue();
6662
6663 // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
6664 // case this is just a compare).
6665 if (APLhs == (-APRhs) &&
6666 ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6667 DAG.doesNodeExist(Opcode: ISD::ABS, VTList: DAG.getVTList(VT: OpVT), Ops: {LHS0}))) {
6668 const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6669 // (icmp eq A, C) | (icmp eq A, -C)
6670 // -> (icmp eq Abs(A), C)
6671 // (icmp ne A, C) & (icmp ne A, -C)
6672 // -> (icmp ne Abs(A), C)
6673 SDValue AbsOp = DAG.getNode(Opcode: ISD::ABS, DL, VT: OpVT, Operand: LHS0);
6674 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AbsOp,
6675 N2: DAG.getConstant(Val: C, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6676 } else if (TargetPreference &
6677 (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6678
6679 // AndOrSETCCFoldKind::AddAnd:
6680 // A == C0 | A == C1
6681 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6682 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6683 // A != C0 & A != C1
6684 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6685 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6686
6687 // AndOrSETCCFoldKind::NotAnd:
6688 // A == C0 | A == C1
6689 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6690 // -> ~A & smin(C0, C1) == 0
6691 // A != C0 & A != C1
6692 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6693 // -> ~A & smin(C0, C1) != 0
6694
6695 const APInt &MaxC = APIntOps::smax(A: APRhs, B: APLhs);
6696 const APInt &MinC = APIntOps::smin(A: APRhs, B: APLhs);
6697 APInt Dif = MaxC - MinC;
6698 if (!Dif.isZero() && Dif.isPowerOf2()) {
6699 if (MaxC.isAllOnes() &&
6700 (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6701 SDValue NotOp = DAG.getNOT(DL, Val: LHS0, VT: OpVT);
6702 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: NotOp,
6703 N2: DAG.getConstant(Val: MinC, DL, VT: OpVT));
6704 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6705 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6706 } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6707
6708 SDValue AddOp = DAG.getNode(Opcode: ISD::ADD, DL, VT: OpVT, N1: LHS0,
6709 N2: DAG.getConstant(Val: -MinC, DL, VT: OpVT));
6710 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: AddOp,
6711 N2: DAG.getConstant(Val: ~Dif, DL, VT: OpVT));
6712 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6713 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6714 }
6715 }
6716 }
6717 }
6718
6719 return SDValue();
6720}
6721
6722// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6723 // We canonicalize to the `select` form in the middle end, but the `and` form
6724 // gets better codegen on all tested targets (arm, x86, riscv).
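// For example, with a zero-or-one boolean c, (select c, (X & 1), 0) evaluates
// to X & 1 when c is 1 and to 0 when c is 0, which is exactly (zext c) & X.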
6725static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6726 const SDLoc &DL, SelectionDAG &DAG) {
6727 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6728 if (!isNullConstant(V: F))
6729 return SDValue();
6730
6731 EVT CondVT = Cond.getValueType();
6732 if (TLI.getBooleanContents(Type: CondVT) !=
6733 TargetLoweringBase::ZeroOrOneBooleanContent)
6734 return SDValue();
6735
6736 if (T.getOpcode() != ISD::AND)
6737 return SDValue();
6738
6739 if (!isOneConstant(V: T.getOperand(i: 1)))
6740 return SDValue();
6741
6742 EVT OpVT = T.getValueType();
6743
6744 SDValue CondMask =
6745 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Op: Cond, SL: DL, VT: OpVT, OpVT: CondVT);
6746 return DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: CondMask, N2: T.getOperand(i: 0));
6747}
6748
6749/// This contains all DAGCombine rules which reduce two values combined by
6750/// an And operation to a single value. This makes them reusable in the context
6751/// of visitSELECT(). Rules involving constants are not included as
6752/// visitSELECT() already handles those cases.
6753SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6754 EVT VT = N1.getValueType();
6755 SDLoc DL(N);
6756
6757 // fold (and x, undef) -> 0
6758 if (N0.isUndef() || N1.isUndef())
6759 return DAG.getConstant(Val: 0, DL, VT);
6760
6761 if (SDValue V = foldLogicOfSetCCs(IsAnd: true, N0, N1, DL))
6762 return V;
6763
6764 // Canonicalize:
6765 // and(x, add) -> and(add, x)
6766 if (N1.getOpcode() == ISD::ADD)
6767 std::swap(a&: N0, b&: N1);
6768
6769 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6770 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6771 VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6772 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
6773 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1))) {
6774 // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
6775 // immediate for an add, but it is legal if its top c2 bits are set,
6776 // transform the ADD so the immediate doesn't need to be materialized
6777 // in a register.
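          // For example, with VT = i64 and c2 = 32, the (lshr y, 32) operand has
          // its top 32 bits clear, so the AND discards the top 32 bits of the
          // sum and the top 32 bits of c1 can be set freely to form a cheaper
          // immediate.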
6778 APInt ADDC = ADDI->getAPIntValue();
6779 APInt SRLC = SRLI->getAPIntValue();
6780 if (ADDC.getSignificantBits() <= 64 && SRLC.ult(RHS: VT.getSizeInBits()) &&
6781 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6782 APInt Mask = APInt::getHighBitsSet(numBits: VT.getSizeInBits(),
6783 hiBitsSet: SRLC.getZExtValue());
6784 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 1), Mask)) {
6785 ADDC |= Mask;
6786 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6787 SDLoc DL0(N0);
6788 SDValue NewAdd =
6789 DAG.getNode(Opcode: ISD::ADD, DL: DL0, VT,
6790 N1: N0.getOperand(i: 0), N2: DAG.getConstant(Val: ADDC, DL, VT));
6791 CombineTo(N: N0.getNode(), Res: NewAdd);
6792 // Return N so it doesn't get rechecked!
6793 return SDValue(N, 0);
6794 }
6795 }
6796 }
6797 }
6798 }
6799 }
6800
6801 return SDValue();
6802}
6803
6804bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6805 EVT LoadResultTy, EVT &ExtVT) {
6806 if (!AndC->getAPIntValue().isMask())
6807 return false;
6808
6809 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6810
6811 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6812 EVT LoadedVT = LoadN->getMemoryVT();
6813
6814 if (ExtVT == LoadedVT &&
6815 (!LegalOperations ||
6816 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))) {
6817 // ZEXTLOAD will match without needing to change the size of the value being
6818 // loaded.
6819 return true;
6820 }
6821
6822   // Do not change the width of volatile or atomic loads.
6823 if (!LoadN->isSimple())
6824 return false;
6825
6826 // Do not generate loads of non-round integer types since these can
6827 // be expensive (and would be wrong if the type is not byte sized).
6828 if (!LoadedVT.bitsGT(VT: ExtVT) || !ExtVT.isRound())
6829 return false;
6830
6831 if (LegalOperations &&
6832 !TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))
6833 return false;
6834
6835 if (!TLI.shouldReduceLoadWidth(Load: LoadN, ExtTy: ISD::ZEXTLOAD, NewVT: ExtVT, /*ByteOffset=*/0))
6836 return false;
6837
6838 return true;
6839}
6840
6841bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6842 ISD::LoadExtType ExtType, EVT &MemVT,
6843 unsigned ShAmt) {
6844 if (!LDST)
6845 return false;
6846
6847 // Only allow byte offsets.
6848 if (ShAmt % 8)
6849 return false;
6850 const unsigned ByteShAmt = ShAmt / 8;
6851
6852 // Do not generate loads of non-round integer types since these can
6853 // be expensive (and would be wrong if the type is not byte sized).
6854 if (!MemVT.isRound())
6855 return false;
6856
6857   // Don't change the width of volatile or atomic loads.
6858 if (!LDST->isSimple())
6859 return false;
6860
6861 EVT LdStMemVT = LDST->getMemoryVT();
6862
6863 // Bail out when changing the scalable property, since we can't be sure that
6864 // we're actually narrowing here.
6865 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6866 return false;
6867
6868 // Verify that we are actually reducing a load width here.
6869 if (LdStMemVT.bitsLT(VT: MemVT))
6870 return false;
6871
6872 // Ensure that this isn't going to produce an unsupported memory access.
6873 if (ShAmt) {
6874 const Align LDSTAlign = LDST->getAlign();
6875 const Align NarrowAlign = commonAlignment(A: LDSTAlign, Offset: ByteShAmt);
6876 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
6877 AddrSpace: LDST->getAddressSpace(), Alignment: NarrowAlign,
6878 Flags: LDST->getMemOperand()->getFlags()))
6879 return false;
6880 }
6881
6882 // It's not possible to generate a constant of extended or untyped type.
6883 EVT PtrType = LDST->getBasePtr().getValueType();
6884 if (PtrType == MVT::Untyped || PtrType.isExtended())
6885 return false;
6886
6887 if (isa<LoadSDNode>(Val: LDST)) {
6888 LoadSDNode *Load = cast<LoadSDNode>(Val: LDST);
6889     // Don't transform a load with multiple uses; this would require adding a
6890     // new load.
6891 if (!SDValue(Load, 0).hasOneUse())
6892 return false;
6893
6894 if (LegalOperations &&
6895 !TLI.isLoadExtLegal(ExtType, ValVT: Load->getValueType(ResNo: 0), MemVT))
6896 return false;
6897
6898 // For the transform to be legal, the load must produce only two values
6899 // (the value loaded and the chain). Don't transform a pre-increment
6900 // load, for example, which produces an extra value. Otherwise the
6901 // transformation is not equivalent, and the downstream logic to replace
6902 // uses gets things wrong.
6903 if (Load->getNumValues() > 2)
6904 return false;
6905
6906     // If the load that we're shrinking is an extload and we're not just
6907     // discarding the extension, we can't simply shrink the load. Bail.
6908 // TODO: It would be possible to merge the extensions in some cases.
6909 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6910 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6911 return false;
6912
6913 if (!TLI.shouldReduceLoadWidth(Load, ExtTy: ExtType, NewVT: MemVT, ByteOffset: ByteShAmt))
6914 return false;
6915 } else {
6916 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6917 StoreSDNode *Store = cast<StoreSDNode>(Val: LDST);
6918 // Can't write outside the original store
6919 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6920 return false;
6921
6922 if (LegalOperations &&
6923 !TLI.isTruncStoreLegal(ValVT: Store->getValue().getValueType(), MemVT))
6924 return false;
6925 }
6926 return true;
6927}
6928
6929bool DAGCombiner::SearchForAndLoads(SDNode *N,
6930 SmallVectorImpl<LoadSDNode*> &Loads,
6931 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6932 ConstantSDNode *Mask,
6933 SDNode *&NodeToMask) {
6934 // Recursively search for the operands, looking for loads which can be
6935 // narrowed.
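  // For example, for (and (or (load i32 p), (load i32 q)), 255) this collects
  // both loads, since masking with 255 allows each to be narrowed to a
  // zero-extending i8 load (when such an extending load is legal).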
6936 for (SDValue Op : N->op_values()) {
6937 if (Op.getValueType().isVector())
6938 return false;
6939
6940 // Some constants may need fixing up later if they are too large.
6941 if (auto *C = dyn_cast<ConstantSDNode>(Val&: Op)) {
6942 assert(ISD::isBitwiseLogicOp(N->getOpcode()) &&
6943 "Expected bitwise logic operation");
6944 if (!C->getAPIntValue().isSubsetOf(RHS: Mask->getAPIntValue()))
6945 NodesWithConsts.insert(Ptr: N);
6946 continue;
6947 }
6948
6949 if (!Op.hasOneUse())
6950 return false;
6951
6952 switch(Op.getOpcode()) {
6953 case ISD::LOAD: {
6954 auto *Load = cast<LoadSDNode>(Val&: Op);
6955 EVT ExtVT;
6956 if (isAndLoadExtLoad(AndC: Mask, LoadN: Load, LoadResultTy: Load->getValueType(ResNo: 0), ExtVT) &&
6957 isLegalNarrowLdSt(LDST: Load, ExtType: ISD::ZEXTLOAD, MemVT&: ExtVT)) {
6958
6959 // ZEXTLOAD is already small enough.
6960 if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6961 ExtVT.bitsGE(VT: Load->getMemoryVT()))
6962 continue;
6963
6964 // Use LE to convert equal sized loads to zext.
6965 if (ExtVT.bitsLE(VT: Load->getMemoryVT()))
6966 Loads.push_back(Elt: Load);
6967
6968 continue;
6969 }
6970 return false;
6971 }
6972 case ISD::ZERO_EXTEND:
6973 case ISD::AssertZext: {
6974 unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6975 EVT ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6976 EVT VT = Op.getOpcode() == ISD::AssertZext ?
6977 cast<VTSDNode>(Val: Op.getOperand(i: 1))->getVT() :
6978 Op.getOperand(i: 0).getValueType();
6979
6980 // We can accept extending nodes if the mask is wider than, or equal
6981 // in width to, the original type.
6982 if (ExtVT.bitsGE(VT))
6983 continue;
6984 break;
6985 }
6986 case ISD::OR:
6987 case ISD::XOR:
6988 case ISD::AND:
6989 if (!SearchForAndLoads(N: Op.getNode(), Loads, NodesWithConsts, Mask,
6990 NodeToMask))
6991 return false;
6992 continue;
6993 }
6994
6995 // Allow one node which will be masked along with any loads found.
6996 if (NodeToMask)
6997 return false;
6998
6999 // Also ensure that the node to be masked only produces one data result.
7000 NodeToMask = Op.getNode();
7001 if (NodeToMask->getNumValues() > 1) {
7002 bool HasValue = false;
7003 for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
7004 MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
7005 if (VT != MVT::Glue && VT != MVT::Other) {
7006 if (HasValue) {
7007 NodeToMask = nullptr;
7008 return false;
7009 }
7010 HasValue = true;
7011 }
7012 }
7013 assert(HasValue && "Node to be masked has no data result?");
7014 }
7015 }
7016 return true;
7017}
7018
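/// If N is an AND whose second operand is a constant low-bit mask, try to
/// propagate the mask back up through a tree of OR/XOR/AND nodes to the loads
/// feeding it. Any loads found are narrowed via reduceLoadWidth, oversized
/// constants are re-masked, and the AND itself becomes redundant and is
/// removed. Returns true if the DAG was changed.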
7019bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
7020 auto *Mask = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
7021 if (!Mask)
7022 return false;
7023
7024 if (!Mask->getAPIntValue().isMask())
7025 return false;
7026
7027 // No need to do anything if the and directly uses a load.
7028 if (isa<LoadSDNode>(Val: N->getOperand(Num: 0)))
7029 return false;
7030
7031 SmallVector<LoadSDNode*, 8> Loads;
7032 SmallPtrSet<SDNode*, 2> NodesWithConsts;
7033 SDNode *FixupNode = nullptr;
7034 if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, NodeToMask&: FixupNode)) {
7035 if (Loads.empty())
7036 return false;
7037
7038 LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
7039 SDValue MaskOp = N->getOperand(Num: 1);
7040
7041 // If it exists, fixup the single node we allow in the tree that needs
7042 // masking.
7043 if (FixupNode) {
7044 LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
7045 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(FixupNode),
7046 VT: FixupNode->getValueType(ResNo: 0),
7047 N1: SDValue(FixupNode, 0), N2: MaskOp);
7048 DAG.ReplaceAllUsesOfValueWith(From: SDValue(FixupNode, 0), To: And);
7049 if (And.getOpcode() == ISD::AND)
7050 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(FixupNode, 0), Op2: MaskOp);
7051 }
7052
7053 // Narrow any constants that need it.
7054 for (auto *LogicN : NodesWithConsts) {
7055 SDValue Op0 = LogicN->getOperand(Num: 0);
7056 SDValue Op1 = LogicN->getOperand(Num: 1);
7057
7058 // We only need to fix AND if both inputs are constants. And we only need
7059 // to fix one of the constants.
7060 if (LogicN->getOpcode() == ISD::AND &&
7061 (!isa<ConstantSDNode>(Val: Op0) || !isa<ConstantSDNode>(Val: Op1)))
7062 continue;
7063
7064 if (isa<ConstantSDNode>(Val: Op0) && LogicN->getOpcode() != ISD::AND)
7065 Op0 =
7066 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op0), VT: Op0.getValueType(), N1: Op0, N2: MaskOp);
7067
7068 if (isa<ConstantSDNode>(Val: Op1))
7069 Op1 =
7070 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op1), VT: Op1.getValueType(), N1: Op1, N2: MaskOp);
7071
7072 if (isa<ConstantSDNode>(Val: Op0) && !isa<ConstantSDNode>(Val: Op1))
7073 std::swap(a&: Op0, b&: Op1);
7074
7075 DAG.UpdateNodeOperands(N: LogicN, Op1: Op0, Op2: Op1);
7076 }
7077
7078 // Create narrow loads.
7079 for (auto *Load : Loads) {
7080 LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
7081 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Load), VT: Load->getValueType(ResNo: 0),
7082 N1: SDValue(Load, 0), N2: MaskOp);
7083 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: And);
7084 if (And.getOpcode() == ISD::AND)
7085 And = SDValue(
7086 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(Load, 0), Op2: MaskOp), 0);
7087 SDValue NewLoad = reduceLoadWidth(N: And.getNode());
7088 assert(NewLoad &&
7089 "Shouldn't be masking the load if it can't be narrowed");
7090 CombineTo(N: Load, Res0: NewLoad, Res1: NewLoad.getValue(R: 1));
7091 }
7092 DAG.ReplaceAllUsesWith(From: N, To: N->getOperand(Num: 0).getNode());
7093 return true;
7094 }
7095 return false;
7096}
7097
7098// Unfold
7099// x & (-1 'logical shift' y)
7100// To
7101// (x 'opposite logical shift' y) 'logical shift' y
7102// if it is better for performance.
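// For example, with 32-bit operands (all shifts logical):
//   x & (-1 << y)  -->  (x >> y) << y   (clears the low y bits)
//   x & (-1 >> y)  -->  (x << y) >> y   (clears the high y bits)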
7103SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
7104 assert(N->getOpcode() == ISD::AND);
7105
7106 SDValue N0 = N->getOperand(Num: 0);
7107 SDValue N1 = N->getOperand(Num: 1);
7108
7109 // Do we actually prefer shifts over mask?
7110 if (!TLI.shouldFoldMaskToVariableShiftPair(X: N0))
7111 return SDValue();
7112
7113 // Try to match (-1 '[outer] logical shift' y)
7114 unsigned OuterShift;
7115 unsigned InnerShift; // The opposite direction to the OuterShift.
7116 SDValue Y; // Shift amount.
7117 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
7118 if (!M.hasOneUse())
7119 return false;
7120 OuterShift = M->getOpcode();
7121 if (OuterShift == ISD::SHL)
7122 InnerShift = ISD::SRL;
7123 else if (OuterShift == ISD::SRL)
7124 InnerShift = ISD::SHL;
7125 else
7126 return false;
7127 if (!isAllOnesConstant(V: M->getOperand(Num: 0)))
7128 return false;
7129 Y = M->getOperand(Num: 1);
7130 return true;
7131 };
7132
7133 SDValue X;
7134 if (matchMask(N1))
7135 X = N0;
7136 else if (matchMask(N0))
7137 X = N1;
7138 else
7139 return SDValue();
7140
7141 SDLoc DL(N);
7142 EVT VT = N->getValueType(ResNo: 0);
7143
7144 // tmp = x 'opposite logical shift' y
7145 SDValue T0 = DAG.getNode(Opcode: InnerShift, DL, VT, N1: X, N2: Y);
7146 // ret = tmp 'logical shift' y
7147 SDValue T1 = DAG.getNode(Opcode: OuterShift, DL, VT, N1: T0, N2: Y);
7148
7149 return T1;
7150}
7151
7152/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
7153/// For a target with a bit test, this is expected to become test + set and save
7154/// at least 1 instruction.
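/// For example, (and (not (srl X, 5)), 1), which tests whether bit 5 of X is
/// clear, becomes (setcc (and X, 32), 0, seteq) zero-extended to the original
/// result type.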
7155static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
7156 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
7157
7158 // Look through an optional extension.
7159 SDValue And0 = And->getOperand(Num: 0), And1 = And->getOperand(Num: 1);
7160 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
7161 And0 = And0.getOperand(i: 0);
7162 if (!isOneConstant(V: And1) || !And0.hasOneUse())
7163 return SDValue();
7164
7165 SDValue Src = And0;
7166
7167 // Attempt to find a 'not' op.
7168 // TODO: Should we favor test+set even without the 'not' op?
7169 bool FoundNot = false;
7170 if (isBitwiseNot(V: Src)) {
7171 FoundNot = true;
7172 Src = Src.getOperand(i: 0);
7173
7174 // Look through an optional truncation. The source operand may not be the
7175 // same type as the original 'and', but that is ok because we are masking
7176 // off everything but the low bit.
7177 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
7178 Src = Src.getOperand(i: 0);
7179 }
7180
7181 // Match a shift-right by constant.
7182 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
7183 return SDValue();
7184
7185 // This is probably not worthwhile without a supported type.
7186 EVT SrcVT = Src.getValueType();
7187 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7188 if (!TLI.isTypeLegal(VT: SrcVT))
7189 return SDValue();
7190
7191 // We might have looked through casts that make this transform invalid.
7192 unsigned BitWidth = SrcVT.getScalarSizeInBits();
7193 SDValue ShiftAmt = Src.getOperand(i: 1);
7194 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(Val&: ShiftAmt);
7195 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(RHS: BitWidth))
7196 return SDValue();
7197
7198 // Set source to shift source.
7199 Src = Src.getOperand(i: 0);
7200
7201 // Try again to find a 'not' op.
7202 // TODO: Should we favor test+set even with two 'not' ops?
7203 if (!FoundNot) {
7204 if (!isBitwiseNot(V: Src))
7205 return SDValue();
7206 Src = Src.getOperand(i: 0);
7207 }
7208
7209 if (!TLI.hasBitTest(X: Src, Y: ShiftAmt))
7210 return SDValue();
7211
7212 // Turn this into a bit-test pattern using mask op + setcc:
7213 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
7214 // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
7215 SDLoc DL(And);
7216 SDValue X = DAG.getZExtOrTrunc(Op: Src, DL, VT: SrcVT);
7217 EVT CCVT =
7218 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT: SrcVT);
7219 SDValue Mask = DAG.getConstant(
7220 Val: APInt::getOneBitSet(numBits: BitWidth, BitNo: ShiftAmtC->getZExtValue()), DL, VT: SrcVT);
7221 SDValue NewAnd = DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: X, N2: Mask);
7222 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: SrcVT);
7223 SDValue Setcc = DAG.getSetCC(DL, VT: CCVT, LHS: NewAnd, RHS: Zero, Cond: ISD::SETEQ);
7224 return DAG.getZExtOrTrunc(Op: Setcc, DL, VT: And->getValueType(ResNo: 0));
7225}
7226
7227/// For targets that support usubsat, match a bit-hack form of that operation
7228/// that ends in 'and' and convert it.
7229static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
7230 EVT VT = N->getValueType(ResNo: 0);
7231 unsigned BitWidth = VT.getScalarSizeInBits();
7232 APInt SignMask = APInt::getSignMask(BitWidth);
7233
7234 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
7235 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
7236 // xor/add with SMIN (signmask) are logically equivalent.
7237 SDValue X;
7238 if (!sd_match(N, P: m_And(L: m_OneUse(P: m_Xor(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
7239 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
7240 R: m_SpecificInt(V: BitWidth - 1))))) &&
7241 !sd_match(N, P: m_And(L: m_OneUse(P: m_Add(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
7242 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
7243 R: m_SpecificInt(V: BitWidth - 1))))))
7244 return SDValue();
7245
7246 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: X,
7247 N2: DAG.getConstant(Val: SignMask, DL, VT));
7248}
7249
7250/// Given a bitwise logic operation N with a matching bitwise logic operand,
7251/// fold a pattern where 2 of the source operands are identically shifted
7252/// values. For example:
7253/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
7254static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
7255 SelectionDAG &DAG) {
7256 unsigned LogicOpcode = N->getOpcode();
7257 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7258 "Expected bitwise logic operation");
7259
7260 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
7261 return SDValue();
7262
7263 // Match another bitwise logic op and a shift.
7264 unsigned ShiftOpcode = ShiftOp.getOpcode();
7265 if (LogicOp.getOpcode() != LogicOpcode ||
7266 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
7267 ShiftOpcode == ISD::SRA))
7268 return SDValue();
7269
7270 // Match another shift op inside the first logic operand. Handle both commuted
7271 // possibilities.
7272 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7273 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7274 SDValue X1 = ShiftOp.getOperand(i: 0);
7275 SDValue Y = ShiftOp.getOperand(i: 1);
7276 SDValue X0, Z;
7277 if (LogicOp.getOperand(i: 0).getOpcode() == ShiftOpcode &&
7278 LogicOp.getOperand(i: 0).getOperand(i: 1) == Y) {
7279 X0 = LogicOp.getOperand(i: 0).getOperand(i: 0);
7280 Z = LogicOp.getOperand(i: 1);
7281 } else if (LogicOp.getOperand(i: 1).getOpcode() == ShiftOpcode &&
7282 LogicOp.getOperand(i: 1).getOperand(i: 1) == Y) {
7283 X0 = LogicOp.getOperand(i: 1).getOperand(i: 0);
7284 Z = LogicOp.getOperand(i: 0);
7285 } else {
7286 return SDValue();
7287 }
7288
7289 EVT VT = N->getValueType(ResNo: 0);
7290 SDLoc DL(N);
7291 SDValue LogicX = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X0, N2: X1);
7292 SDValue NewShift = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: LogicX, N2: Y);
7293 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift, N2: Z);
7294}
7295
7296/// Given a tree of logic operations with shape like
7297/// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
7298/// try to match and fold shift operations with the same shift amount.
7299/// For example:
7300/// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
7301/// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
7302static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
7303 SDValue RightHand, SelectionDAG &DAG) {
7304 unsigned LogicOpcode = N->getOpcode();
7305 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7306 "Expected bitwise logic operation");
7307 if (LeftHand.getOpcode() != LogicOpcode ||
7308 RightHand.getOpcode() != LogicOpcode)
7309 return SDValue();
7310 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
7311 return SDValue();
7312
7313 // Try to match one of following patterns:
7314 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
7315 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
7316 // Note that foldLogicOfShifts will handle commuted versions of the left hand
7317 // itself.
7318 SDValue CombinedShifts, W;
7319 SDValue R0 = RightHand.getOperand(i: 0);
7320 SDValue R1 = RightHand.getOperand(i: 1);
7321 if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R0, DAG)))
7322 W = R1;
7323 else if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R1, DAG)))
7324 W = R0;
7325 else
7326 return SDValue();
7327
7328 EVT VT = N->getValueType(ResNo: 0);
7329 SDLoc DL(N);
7330 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: CombinedShifts, N2: W);
7331}
7332
7333/// Fold "masked merge" expressions like `(m & x) | (~m & y)` and its DeMorgan
7334/// variant `(~m | x) & (m | y)` into the equivalent `((x ^ y) & m) ^ y`
7335/// pattern. This is typically a better representation for targets without a
7336/// fused "and-not" operation.
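/// For example, with m = 0x0F the merge selects the low nibble from x and the
/// high nibble from y; ((x ^ y) & m) ^ y yields the same result because the
/// final xor with y flips exactly those bits inside the mask where x and y
/// differ.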
7337static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG,
7338 const TargetLowering &TLI, const SDLoc &DL) {
7339 // Note that masked-merge variants using XOR or ADD expressions are
7340 // normalized to OR by InstCombine so we only check for OR or AND.
7341 assert((Node->getOpcode() == ISD::OR || Node->getOpcode() == ISD::AND) &&
7342 "Must be called with ISD::OR or ISD::AND node");
7343
7344 // If the target supports and-not, don't fold this.
7345 if (TLI.hasAndNot(X: SDValue(Node, 0)))
7346 return SDValue();
7347
7348 SDValue M, X, Y;
7349
7350 if (sd_match(N: Node,
7351 P: m_Or(L: m_OneUse(P: m_And(L: m_OneUse(P: m_Not(V: m_Value(N&: M))), R: m_Value(N&: Y))),
7352 R: m_OneUse(P: m_And(L: m_Deferred(V&: M), R: m_Value(N&: X))))) ||
7353 sd_match(N: Node,
7354 P: m_And(L: m_OneUse(P: m_Or(L: m_OneUse(P: m_Not(V: m_Value(N&: M))), R: m_Value(N&: X))),
7355 R: m_OneUse(P: m_Or(L: m_Deferred(V&: M), R: m_Value(N&: Y)))))) {
7356 EVT VT = M.getValueType();
7357 SDValue Xor = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: X, N2: Y);
7358 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Xor, N2: M);
7359 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: And, N2: Y);
7360 }
7361 return SDValue();
7362}
7363
7364SDValue DAGCombiner::visitAND(SDNode *N) {
7365 SDValue N0 = N->getOperand(Num: 0);
7366 SDValue N1 = N->getOperand(Num: 1);
7367 EVT VT = N1.getValueType();
7368 SDLoc DL(N);
7369
7370 // x & x --> x
7371 if (N0 == N1)
7372 return N0;
7373
7374 // fold (and c1, c2) -> c1&c2
7375 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL, VT, Ops: {N0, N1}))
7376 return C;
7377
7378 // canonicalize constant to RHS
7379 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
7380 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
7381 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1, N2: N0);
7382
7383 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
7384 return DAG.getConstant(Val: APInt::getZero(numBits: VT.getScalarSizeInBits()), DL, VT);
7385
7386 // fold vector ops
7387 if (VT.isVector()) {
7388 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7389 return FoldedVOp;
7390
7391 // fold (and x, 0) -> 0, vector edition
7392 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
7393 // do not return N1, because an undef node may exist in N1
7394 return DAG.getConstant(Val: APInt::getZero(numBits: N1.getScalarValueSizeInBits()), DL,
7395 VT: N1.getValueType());
7396
7397 // fold (and x, -1) -> x, vector edition
7398 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
7399 return N0;
7400
7401 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
7402 auto *MLoad = dyn_cast<MaskedLoadSDNode>(Val&: N0);
7403 ConstantSDNode *Splat = isConstOrConstSplat(N: N1, AllowUndefs: true, AllowTruncation: true);
7404 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
7405 EVT LoadVT = MLoad->getMemoryVT();
7406 EVT ExtVT = VT;
7407 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ExtVT, MemVT: LoadVT)) {
7408 // For this AND to be a zero extension of the masked load, the elements
7409 // of the BuildVec must mask the bottom bits of the extended element
7410 // type.
7411 uint64_t ElementSize =
7412 LoadVT.getVectorElementType().getScalarSizeInBits();
7413 if (Splat->getAPIntValue().isMask(numBits: ElementSize)) {
7414 SDValue NewLoad = DAG.getMaskedLoad(
7415 VT: ExtVT, dl: DL, Chain: MLoad->getChain(), Base: MLoad->getBasePtr(),
7416 Offset: MLoad->getOffset(), Mask: MLoad->getMask(), Src0: MLoad->getPassThru(),
7417 MemVT: LoadVT, MMO: MLoad->getMemOperand(), AM: MLoad->getAddressingMode(),
7418 ISD::ZEXTLOAD, IsExpanding: MLoad->isExpandingLoad());
7419 bool LoadHasOtherUsers = !N0.hasOneUse();
7420 CombineTo(N, Res: NewLoad);
7421 if (LoadHasOtherUsers)
7422 CombineTo(N: MLoad, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
7423 return SDValue(N, 0);
7424 }
7425 }
7426 }
7427 }
7428
7429 // fold (and x, -1) -> x
7430 if (isAllOnesConstant(V: N1))
7431 return N0;
7432
7433 // if (and x, c) is known to be zero, return 0
7434 unsigned BitWidth = VT.getScalarSizeInBits();
7435 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
7436 if (N1C && DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: BitWidth)))
7437 return DAG.getConstant(Val: 0, DL, VT);
7438
7439 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7440 return R;
7441
7442 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7443 return NewSel;
7444
7445 // reassociate and
7446 if (SDValue RAND = reassociateOps(Opc: ISD::AND, DL, N0, N1, Flags: N->getFlags()))
7447 return RAND;
7448
7449 // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7450 if (SDValue SD =
7451 reassociateReduction(RedOpc: ISD::VECREDUCE_AND, Opc: ISD::AND, DL, VT, N0, N1))
7452 return SD;
7453
7454 // fold (and (or x, C), D) -> D if (C & D) == D
7455 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7456 return RHS->getAPIntValue().isSubsetOf(RHS: LHS->getAPIntValue());
7457 };
7458 if (N0.getOpcode() == ISD::OR &&
7459 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchSubset))
7460 return N1;
7461
7462 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7463 SDValue N0Op0 = N0.getOperand(i: 0);
7464 EVT SrcVT = N0Op0.getValueType();
7465 unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7466 APInt Mask = ~N1C->getAPIntValue();
7467 Mask = Mask.trunc(width: SrcBitWidth);
7468
7469 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
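    // e.g. (and (any_extend i8 V to i32), 0xFF) -> (zero_extend i8 V to i32),
    // since the mask clears exactly the bits the any_extend leaves undefined.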
7470 if (DAG.MaskedValueIsZero(Op: N0Op0, Mask))
7471 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0Op0);
7472
7473 // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7474 if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7475 TLI.isTruncateFree(FromVT: VT, ToVT: SrcVT) && TLI.isZExtFree(FromTy: SrcVT, ToTy: VT) &&
7476 TLI.isTypeDesirableForOp(ISD::AND, VT: SrcVT) &&
7477 TLI.isNarrowingProfitable(N, SrcVT: VT, DestVT: SrcVT))
7478 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT,
7479 Operand: DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: N0Op0,
7480 N2: DAG.getZExtOrTrunc(Op: N1, DL, VT: SrcVT)));
7481 }
7482
7483 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7484 if (ISD::isExtOpcode(Opcode: N0.getOpcode())) {
7485 unsigned ExtOpc = N0.getOpcode();
7486 SDValue N0Op0 = N0.getOperand(i: 0);
7487 if (N0Op0.getOpcode() == ISD::AND &&
7488 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(Val: N0Op0, VT2: VT)) &&
7489 N0->hasOneUse() && N0Op0->hasOneUse()) {
7490 if (SDValue NewExt = DAG.FoldConstantArithmetic(Opcode: ExtOpc, DL, VT,
7491 Ops: {N0Op0.getOperand(i: 1)})) {
7492 if (SDValue NewMask =
7493 DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL, VT, Ops: {N1, NewExt})) {
7494 return DAG.getNode(Opcode: ISD::AND, DL, VT,
7495 N1: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 0)),
7496 N2: NewMask);
7497 }
7498 }
7499 }
7500 }
7501
7502 // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7503 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7504 // already be zero by virtue of the width of the base type of the load.
7505 //
7506 // the 'X' node here can either be nothing or an extract_vector_elt to catch
7507 // more cases.
7508 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7509 N0.getValueSizeInBits() == N0.getOperand(i: 0).getScalarValueSizeInBits() &&
7510 N0.getOperand(i: 0).getOpcode() == ISD::LOAD &&
7511 N0.getOperand(i: 0).getResNo() == 0) ||
7512 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7513 auto *Load =
7514 cast<LoadSDNode>(Val: (N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(i: 0));
7515
7516 // Get the constant (if applicable) the zero'th operand is being ANDed with.
7517 // This can be a pure constant or a vector splat, in which case we treat the
7518 // vector as a scalar and use the splat value.
7519 APInt Constant = APInt::getZero(numBits: 1);
7520 if (const ConstantSDNode *C = isConstOrConstSplat(
7521 N: N1, /*AllowUndef=*/AllowUndefs: false, /*AllowTruncation=*/true)) {
7522 Constant = C->getAPIntValue();
7523 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(Val&: N1)) {
7524 unsigned EltBitWidth = Vector->getValueType(ResNo: 0).getScalarSizeInBits();
7525 APInt SplatValue, SplatUndef;
7526 unsigned SplatBitSize;
7527 bool HasAnyUndefs;
7528 // Endianness should not matter here. Code below makes sure that we only
7529 // use the result if the SplatBitSize is a multiple of the vector element
7530 // size. And after that we AND all element sized parts of the splat
7531 // together. So the end result should be the same regardless of the order
7532 // in which we do those operations.
7533 const bool IsBigEndian = false;
7534 bool IsSplat =
7535 Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7536 HasAnyUndefs, MinSplatBits: EltBitWidth, isBigEndian: IsBigEndian);
7537
7538 // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7539 // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
7540 if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7541 // Undef bits can contribute to a possible optimisation if set, so
7542 // set them.
7543 SplatValue |= SplatUndef;
7544
7545 // The splat value may be something like "0x00FFFFFF", which means 0 for
7546 // the first vector value and FF for the rest, repeating. We need a mask
7547 // that will apply equally to all members of the vector, so AND all the
7548 // lanes of the constant together.
7549 Constant = APInt::getAllOnes(numBits: EltBitWidth);
7550 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7551 Constant &= SplatValue.extractBits(numBits: EltBitWidth, bitPosition: i * EltBitWidth);
7552 }
7553 }
7554
7555 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7556 // actually legal and isn't going to get expanded, else this is a false
7557 // optimisation.
7558 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD,
7559 ValVT: Load->getValueType(ResNo: 0),
7560 MemVT: Load->getMemoryVT());
7561
7562 // Resize the constant to the same size as the original memory access before
7563 // extension. If it is still the AllOnesValue then this AND is completely
7564 // unneeded.
7565 Constant = Constant.zextOrTrunc(width: Load->getMemoryVT().getScalarSizeInBits());
7566
7567 bool B;
7568 switch (Load->getExtensionType()) {
7569 default: B = false; break;
7570 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7571 case ISD::ZEXTLOAD:
7572 case ISD::NON_EXTLOAD: B = true; break;
7573 }
7574
7575 if (B && Constant.isAllOnes()) {
7576 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7577 // preserve semantics once we get rid of the AND.
7578 SDValue NewLoad(Load, 0);
7579
7580 // Fold the AND away. NewLoad may get replaced immediately.
7581 CombineTo(N, Res: (N0.getNode() == Load) ? NewLoad : N0);
7582
7583 if (Load->getExtensionType() == ISD::EXTLOAD) {
7584 NewLoad = DAG.getLoad(AM: Load->getAddressingMode(), ExtType: ISD::ZEXTLOAD,
7585 VT: Load->getValueType(ResNo: 0), dl: SDLoc(Load),
7586 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
7587 Offset: Load->getOffset(), MemVT: Load->getMemoryVT(),
7588 MMO: Load->getMemOperand());
7589 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7590 if (Load->getNumValues() == 3) {
7591 // PRE/POST_INC loads have 3 values.
7592 SDValue To[] = { NewLoad.getValue(R: 0), NewLoad.getValue(R: 1),
7593 NewLoad.getValue(R: 2) };
7594 CombineTo(N: Load, To, NumTo: 3, AddTo: true);
7595 } else {
7596 CombineTo(N: Load, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
7597 }
7598 }
7599
7600 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7601 }
7602 }
7603
7604 // Try to convert a constant mask AND into a shuffle clear mask.
7605 if (VT.isVector())
7606 if (SDValue Shuffle = XformToShuffleWithZero(N))
7607 return Shuffle;
7608
7609 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7610 return Combined;
7611
7612 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7613 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
7614 SDValue Ext = N0.getOperand(i: 0);
7615 EVT ExtVT = Ext->getValueType(ResNo: 0);
7616 SDValue Extendee = Ext->getOperand(Num: 0);
7617
7618 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7619 if (N1C->getAPIntValue().isMask(numBits: ScalarWidth) &&
7620 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: ExtVT))) {
7621 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7622 // => (extract_subvector (iN_zeroext v))
7623 SDValue ZeroExtExtendee =
7624 DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: ExtVT, Operand: Extendee);
7625
7626 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: ZeroExtExtendee,
7627 N2: N0.getOperand(i: 1));
7628 }
7629 }
7630
7631 // fold (and (masked_gather x)) -> (zext_masked_gather x)
7632 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
7633 EVT MemVT = GN0->getMemoryVT();
7634 EVT ScalarVT = MemVT.getScalarType();
7635
7636 if (SDValue(GN0, 0).hasOneUse() &&
7637 isConstantSplatVectorMaskForType(N: N1.getNode(), ScalarTy: ScalarVT) &&
7638 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(SDValue(GN0, 0)))) {
7639 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
7640 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
7641
7642 SDValue ZExtLoad = DAG.getMaskedGather(
7643 VTs: DAG.getVTList(VT1: VT, VT2: MVT::Other), MemVT, dl: DL, Ops, MMO: GN0->getMemOperand(),
7644 IndexType: GN0->getIndexType(), ExtTy: ISD::ZEXTLOAD);
7645
7646 CombineTo(N, Res: ZExtLoad);
7647 AddToWorklist(N: ZExtLoad.getNode());
7648 // Avoid recheck of N.
7649 return SDValue(N, 0);
7650 }
7651 }
7652
7653 // fold (and (load x), 255) -> (zextload x, i8)
7654 // fold (and (extload x, i16), 255) -> (zextload x, i8)
7655 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7656 if (SDValue Res = reduceLoadWidth(N))
7657 return Res;
7658
7659 if (LegalTypes) {
7660 // Attempt to propagate the AND back up to the leaves which, if they're
7661 // loads, can be combined to narrow loads and the AND node can be removed.
7662 // Perform after legalization so that extend nodes will already be
7663 // combined into the loads.
7664 if (BackwardsPropagateMask(N))
7665 return SDValue(N, 0);
7666 }
7667
7668 if (SDValue Combined = visitANDLike(N0, N1, N))
7669 return Combined;
7670
7671 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
7672 if (N0.getOpcode() == N1.getOpcode())
7673 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7674 return V;
7675
7676 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7677 return R;
7678 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
7679 return R;
7680
7681 // Fold (and X, (bswap (not Y))) -> (and X, (not (bswap Y)))
7682 // Fold (and X, (bitreverse (not Y))) -> (and X, (not (bitreverse Y)))
7683 SDValue X, Y, Z, NotY;
7684 for (unsigned Opc : {ISD::BSWAP, ISD::BITREVERSE})
7685 if (sd_match(N,
7686 P: m_And(L: m_Value(N&: X), R: m_OneUse(P: m_UnaryOp(Opc, Op: m_Value(N&: NotY))))) &&
7687 sd_match(N: NotY, P: m_Not(V: m_Value(N&: Y))) &&
7688 (TLI.hasAndNot(X: SDValue(N, 0)) || NotY->hasOneUse()))
7689 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7690 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: Opc, DL, VT, Operand: Y), VT));
7691
7692 // Fold (and X, (rot (not Y), Z)) -> (and X, (not (rot Y, Z)))
7693 for (unsigned Opc : {ISD::ROTL, ISD::ROTR})
7694 if (sd_match(N, P: m_And(L: m_Value(N&: X),
7695 R: m_OneUse(P: m_BinOp(Opc, L: m_Value(N&: NotY), R: m_Value(N&: Z))))) &&
7696 sd_match(N: NotY, P: m_Not(V: m_Value(N&: Y))) &&
7697 (TLI.hasAndNot(X: SDValue(N, 0)) || NotY->hasOneUse()))
7698 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7699 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: Opc, DL, VT, N1: Y, N2: Z), VT));
7700
7701 // Fold (and X, (add (not Y), Z)) -> (and X, (not (sub Y, Z)))
7702 // Fold (and X, (sub (not Y), Z)) -> (and X, (not (add Y, Z)))
7703 if (TLI.hasAndNot(X: SDValue(N, 0)))
7704 if (SDValue Folded = foldBitwiseOpWithNeg(N, DL, VT))
7705 return Folded;
7706
7707 // Fold (and (srl X, C), 1) -> (srl X, BW-1) for signbit extraction
7708 // If we are shifting down an extended sign bit, see if we can simplify
7709 // this to shifting the MSB directly to expose further simplifications.
7710 // This pattern often appears after sext_inreg legalization.
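  // e.g. if X is a sign-extended boolean (every bit is a copy of the sign
  // bit), then for i32: (and (srl X, 7), 1) -> (srl X, 31).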
7711 APInt Amt;
7712 if (sd_match(N, P: m_And(L: m_Srl(L: m_Value(N&: X), R: m_ConstInt(V&: Amt)), R: m_One())) &&
7713 Amt.ult(RHS: BitWidth - 1) && Amt.uge(RHS: BitWidth - DAG.ComputeNumSignBits(Op: X)))
7714 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X,
7715 N2: DAG.getShiftAmountConstant(Val: BitWidth - 1, VT, DL));
7716
7717 // Masking the negated extension of a boolean is just the zero-extended
7718 // boolean:
7719 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7720 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7721 //
7722 // Note: the SimplifyDemandedBits fold below can make an information-losing
7723 // transform, and then we have no way to find this better fold.
7724 if (sd_match(N, P: m_And(L: m_Sub(L: m_Zero(), R: m_Value(N&: X)), R: m_One()))) {
7725 if (X.getOpcode() == ISD::ZERO_EXTEND &&
7726 X.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7727 return X;
7728 if (X.getOpcode() == ISD::SIGN_EXTEND &&
7729 X.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7730 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: X.getOperand(i: 0));
7731 }
7732
7733 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7734 // fold (and (sra)) -> (and (srl)) when possible.
7735 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7736 return SDValue(N, 0);
7737
7738 // fold (zext_inreg (extload x)) -> (zextload x)
7739 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7740 if (ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
7741 (ISD::isEXTLoad(N: N0.getNode()) ||
7742 (ISD::isSEXTLoad(N: N0.getNode()) && N0.hasOneUse()))) {
7743 auto *LN0 = cast<LoadSDNode>(Val&: N0);
7744 EVT MemVT = LN0->getMemoryVT();
7745 // If we zero all the possible extended bits, then we can turn this into
7746 // a zextload if we are running before legalize or the operation is legal.
7747 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7748 unsigned MemBitSize = MemVT.getScalarSizeInBits();
7749 APInt ExtBits = APInt::getHighBitsSet(numBits: ExtBitSize, hiBitsSet: ExtBitSize - MemBitSize);
7750 if (DAG.MaskedValueIsZero(Op: N1, Mask: ExtBits) &&
7751 ((!LegalOperations && LN0->isSimple()) ||
7752 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT))) {
7753 SDValue ExtLoad =
7754 DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(N0), VT, Chain: LN0->getChain(),
7755 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
7756 AddToWorklist(N);
7757 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
7758 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7759 }
7760 }
7761
7762 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7763 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7764 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
7765 N1: N0.getOperand(i: 1), DemandHighBits: false))
7766 return BSwap;
7767 }
7768
7769 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7770 return Shifts;
7771
7772 if (SDValue V = combineShiftAnd1ToBitTest(And: N, DAG))
7773 return V;
7774
7775 // Recognize the following pattern:
7776 //
7777 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7778 //
7779 // where bitmask is a low-bit mask that keeps exactly the bits of NarrowVT
7780 // and clears all of the upper bits of AndVT.
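  // e.g. (and (sign_extend i16 X to i32), 0xFFFF) -> (zero_extend i16 X to i32).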
7781 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7782 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7783 return false;
7784
7785 auto *C = dyn_cast<ConstantSDNode>(Val&: RHS);
7786 if (!C)
7787 return false;
7788
7789 if (!C->getAPIntValue().isMask(
7790 numBits: LHS.getOperand(i: 0).getValueType().getFixedSizeInBits()))
7791 return false;
7792
7793 return true;
7794 };
7795
7796 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7797 if (IsAndZeroExtMask(N0, N1))
7798 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
7799
7800 if (hasOperation(Opcode: ISD::USUBSAT, VT))
7801 if (SDValue V = foldAndToUsubsat(N, DAG, DL))
7802 return V;
7803
7804 // Postpone until legalization completed to avoid interference with bswap
7805 // folding
7806 if (LegalOperations || VT.isVector())
7807 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
7808 return R;
7809
7810 if (VT.isScalarInteger() && VT != MVT::i1)
7811 if (SDValue R = foldMaskedMerge(Node: N, DAG, TLI, DL))
7812 return R;
7813
7814 return SDValue();
7815}
7816
7817/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
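/// For i32 this also matches masked forms of the pattern, e.g.
/// ((a & 0xff) << 8) | ((a >> 8) & 0xff) --> (bswap a) >> 16.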
7818SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7819 bool DemandHighBits) {
7820 if (!LegalOperations)
7821 return SDValue();
7822
7823 EVT VT = N->getValueType(ResNo: 0);
7824 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7825 return SDValue();
7826 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7827 return SDValue();
7828
7829 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7830 bool LookPassAnd0 = false;
7831 bool LookPassAnd1 = false;
7832 if (N0.getOpcode() == ISD::AND && N0.getOperand(i: 0).getOpcode() == ISD::SRL)
7833 std::swap(a&: N0, b&: N1);
7834 if (N1.getOpcode() == ISD::AND && N1.getOperand(i: 0).getOpcode() == ISD::SHL)
7835 std::swap(a&: N0, b&: N1);
7836 if (N0.getOpcode() == ISD::AND) {
7837 if (!N0->hasOneUse())
7838 return SDValue();
7839 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7840 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7841 // This is needed for X86.
7842 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7843 N01C->getZExtValue() != 0xFFFF))
7844 return SDValue();
7845 N0 = N0.getOperand(i: 0);
7846 LookPassAnd0 = true;
7847 }
7848
7849 if (N1.getOpcode() == ISD::AND) {
7850 if (!N1->hasOneUse())
7851 return SDValue();
7852 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7853 if (!N11C || N11C->getZExtValue() != 0xFF)
7854 return SDValue();
7855 N1 = N1.getOperand(i: 0);
7856 LookPassAnd1 = true;
7857 }
7858
7859 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7860 std::swap(a&: N0, b&: N1);
7861 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7862 return SDValue();
7863 if (!N0->hasOneUse() || !N1->hasOneUse())
7864 return SDValue();
7865
7866 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7867 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7868 if (!N01C || !N11C)
7869 return SDValue();
7870 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7871 return SDValue();
7872
7873 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7874 SDValue N00 = N0->getOperand(Num: 0);
7875 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7876 if (!N00->hasOneUse())
7877 return SDValue();
7878 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(Val: N00.getOperand(i: 1));
7879 if (!N001C || N001C->getZExtValue() != 0xFF)
7880 return SDValue();
7881 N00 = N00.getOperand(i: 0);
7882 LookPassAnd0 = true;
7883 }
7884
7885 SDValue N10 = N1->getOperand(Num: 0);
7886 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7887 if (!N10->hasOneUse())
7888 return SDValue();
7889 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(Val: N10.getOperand(i: 1));
7890 // Also allow 0xFFFF since the bits will be shifted out. This is needed
7891 // for X86.
7892 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7893 N101C->getZExtValue() != 0xFFFF))
7894 return SDValue();
7895 N10 = N10.getOperand(i: 0);
7896 LookPassAnd1 = true;
7897 }
7898
7899 if (N00 != N10)
7900 return SDValue();
7901
7902 // Make sure everything beyond the low halfword gets set to zero since the SRL
7903 // 16 will clear the top bits.
7904 unsigned OpSizeInBits = VT.getSizeInBits();
7905 if (OpSizeInBits > 16) {
7906 // If the left-shift isn't masked out then the only way this is a bswap is
7907 // if all bits beyond the low 8 are 0. In that case the entire pattern
7908 // reduces to a left shift anyway: leave it for other parts of the combiner.
7909 if (DemandHighBits && !LookPassAnd0)
7910 return SDValue();
7911
7912 // However, if the right shift isn't masked out then it might be because
7913 // it's not needed. See if we can spot that too. If the high bits aren't
7914 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7915 // upper bits to be zero.
7916 if (!LookPassAnd1) {
7917 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7918 if (!DAG.MaskedValueIsZero(Op: N10,
7919 Mask: APInt::getBitsSet(numBits: OpSizeInBits, loBit: 16, hiBit: HighBit)))
7920 return SDValue();
7921 }
7922 }
7923
7924 SDValue Res = DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: N00);
7925 if (OpSizeInBits > 16) {
7926 SDLoc DL(N);
7927 Res = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Res,
7928 N2: DAG.getShiftAmountConstant(Val: OpSizeInBits - 16, VT, DL));
7929 }
7930 return Res;
7931}
7932
7933/// Return true if the specified node is an element that makes up a 32-bit
7934/// packed halfword byteswap.
7935/// ((x & 0x000000ff) << 8) |
7936/// ((x & 0x0000ff00) >> 8) |
7937/// ((x & 0x00ff0000) << 8) |
7938/// ((x & 0xff000000) >> 8)
7939static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7940 if (!N->hasOneUse())
7941 return false;
7942
7943 unsigned Opc = N.getOpcode();
7944 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7945 return false;
7946
7947 SDValue N0 = N.getOperand(i: 0);
7948 unsigned Opc0 = N0.getOpcode();
7949 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7950 return false;
7951
7952 ConstantSDNode *N1C = nullptr;
7953 // SHL or SRL: look upstream for AND mask operand
7954 if (Opc == ISD::AND)
7955 N1C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7956 else if (Opc0 == ISD::AND)
7957 N1C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7958 if (!N1C)
7959 return false;
7960
7961 unsigned MaskByteOffset;
7962 switch (N1C->getZExtValue()) {
7963 default:
7964 return false;
7965 case 0xFF: MaskByteOffset = 0; break;
7966 case 0xFF00: MaskByteOffset = 1; break;
7967 case 0xFFFF:
7968 // In case demanded bits didn't clear the bits that will be shifted out.
7969 // This is needed for X86.
7970 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7971 MaskByteOffset = 1;
7972 break;
7973 }
7974 return false;
7975 case 0xFF0000: MaskByteOffset = 2; break;
7976 case 0xFF000000: MaskByteOffset = 3; break;
7977 }
7978
7979 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7980 if (Opc == ISD::AND) {
7981 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7982 // (x >> 8) & 0xff
7983 // (x >> 8) & 0xff0000
7984 if (Opc0 != ISD::SRL)
7985 return false;
7986 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7987 if (!C || C->getZExtValue() != 8)
7988 return false;
7989 } else {
7990 // (x << 8) & 0xff00
7991 // (x << 8) & 0xff000000
7992 if (Opc0 != ISD::SHL)
7993 return false;
7994 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7995 if (!C || C->getZExtValue() != 8)
7996 return false;
7997 }
7998 } else if (Opc == ISD::SHL) {
7999 // (x & 0xff) << 8
8000 // (x & 0xff0000) << 8
8001 if (MaskByteOffset != 0 && MaskByteOffset != 2)
8002 return false;
8003 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
8004 if (!C || C->getZExtValue() != 8)
8005 return false;
8006 } else { // Opc == ISD::SRL
8007 // (x & 0xff00) >> 8
8008 // (x & 0xff000000) >> 8
8009 if (MaskByteOffset != 1 && MaskByteOffset != 3)
8010 return false;
8011 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
8012 if (!C || C->getZExtValue() != 8)
8013 return false;
8014 }
8015
8016 if (Parts[MaskByteOffset])
8017 return false;
8018
8019 Parts[MaskByteOffset] = N0.getOperand(i: 0).getNode();
8020 return true;
8021}
8022
8023// Match 2 elements of a packed halfword bswap.
8024static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
8025 if (N.getOpcode() == ISD::OR)
8026 return isBSwapHWordElement(N: N.getOperand(i: 0), Parts) &&
8027 isBSwapHWordElement(N: N.getOperand(i: 1), Parts);
8028
8029 if (N.getOpcode() == ISD::SRL && N.getOperand(i: 0).getOpcode() == ISD::BSWAP) {
8030 ConstantSDNode *C = isConstOrConstSplat(N: N.getOperand(i: 1));
8031 if (!C || C->getAPIntValue() != 16)
8032 return false;
8033 Parts[0] = Parts[1] = N.getOperand(i: 0).getOperand(i: 0).getNode();
8034 return true;
8035 }
8036
8037 return false;
8038}
8039
8040// Match this pattern:
8041// (or (and (shl A, 8), 0xff00ff00), (and (srl A, 8), 0x00ff00ff))
8042// And rewrite this to:
8043// (rotr (bswap A), 16)
8044static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
8045 SelectionDAG &DAG, SDNode *N, SDValue N0,
8046 SDValue N1, EVT VT) {
8047 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
8048 "MatchBSwapHWordOrAndAnd: expecting i32");
8049 if (!TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
8050 return SDValue();
8051 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
8052 return SDValue();
8053 // TODO: this is too restrictive; lifting this restriction requires more tests
8054 if (!N0->hasOneUse() || !N1->hasOneUse())
8055 return SDValue();
8056 ConstantSDNode *Mask0 = isConstOrConstSplat(N: N0.getOperand(i: 1));
8057 ConstantSDNode *Mask1 = isConstOrConstSplat(N: N1.getOperand(i: 1));
8058 if (!Mask0 || !Mask1)
8059 return SDValue();
8060 if (Mask0->getAPIntValue() != 0xff00ff00 ||
8061 Mask1->getAPIntValue() != 0x00ff00ff)
8062 return SDValue();
8063 SDValue Shift0 = N0.getOperand(i: 0);
8064 SDValue Shift1 = N1.getOperand(i: 0);
8065 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
8066 return SDValue();
8067 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(N: Shift0.getOperand(i: 1));
8068 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(N: Shift1.getOperand(i: 1));
8069 if (!ShiftAmt0 || !ShiftAmt1)
8070 return SDValue();
8071 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
8072 return SDValue();
8073 if (Shift0.getOperand(i: 0) != Shift1.getOperand(i: 0))
8074 return SDValue();
8075
8076 SDLoc DL(N);
8077 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: Shift0.getOperand(i: 0));
8078 SDValue ShAmt = DAG.getShiftAmountConstant(Val: 16, VT, DL);
8079 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
8080}
8081
8082/// Match a 32-bit packed halfword bswap. That is
8083/// ((x & 0x000000ff) << 8) |
8084/// ((x & 0x0000ff00) >> 8) |
8085/// ((x & 0x00ff0000) << 8) |
8086/// ((x & 0xff000000) >> 8)
8087/// => (rotl (bswap x), 16)
8088SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
8089 if (!LegalOperations)
8090 return SDValue();
8091
8092 EVT VT = N->getValueType(ResNo: 0);
8093 if (VT != MVT::i32)
8094 return SDValue();
8095 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
8096 return SDValue();
8097
8098 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
8099 return BSwap;
8100
8101 // Try again with commuted operands.
8102 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0: N1, N1: N0, VT))
8103 return BSwap;
8104
8105
8106 // Look for either
8107 // (or (bswaphpair), (bswaphpair))
8108 // (or (or (bswaphpair), (and)), (and))
8109 // (or (or (and), (bswaphpair)), (and))
8110 SDNode *Parts[4] = {};
8111
8112 if (isBSwapHWordPair(N: N0, Parts)) {
8113 // (or (or (and), (and)), (or (and), (and)))
8114 if (!isBSwapHWordPair(N: N1, Parts))
8115 return SDValue();
8116 } else if (N0.getOpcode() == ISD::OR) {
8117 // (or (or (or (and), (and)), (and)), (and))
8118 if (!isBSwapHWordElement(N: N1, Parts))
8119 return SDValue();
8120 SDValue N00 = N0.getOperand(i: 0);
8121 SDValue N01 = N0.getOperand(i: 1);
8122 if (!(isBSwapHWordElement(N: N01, Parts) && isBSwapHWordPair(N: N00, Parts)) &&
8123 !(isBSwapHWordElement(N: N00, Parts) && isBSwapHWordPair(N: N01, Parts)))
8124 return SDValue();
8125 } else {
8126 return SDValue();
8127 }
8128
8129 // Make sure the parts are all coming from the same node.
8130 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
8131 return SDValue();
8132
8133 SDLoc DL(N);
8134 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT,
8135 Operand: SDValue(Parts[0], 0));
8136
8137 // Result of the bswap should be rotated by 16. If it's not legal, then
8138 // do (x << 16) | (x >> 16).
8139 SDValue ShAmt = DAG.getShiftAmountConstant(Val: 16, VT, DL);
8140 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT))
8141 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: BSwap, N2: ShAmt);
8142 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
8143 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
8144 return DAG.getNode(Opcode: ISD::OR, DL, VT,
8145 N1: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: BSwap, N2: ShAmt),
8146 N2: DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: BSwap, N2: ShAmt));
8147}
8148
8149/// This contains all DAGCombine rules which reduce two values combined by
8150/// an Or operation to a single value \see visitANDLike().
8151SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
8152 EVT VT = N1.getValueType();
8153
8154 // fold (or x, undef) -> -1
8155 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
8156 return DAG.getAllOnesConstant(DL, VT);
8157
8158 if (SDValue V = foldLogicOfSetCCs(IsAnd: false, N0, N1, DL))
8159 return V;
8160
8161 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
8162 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
8163 // Don't increase # computations.
8164 (N0->hasOneUse() || N1->hasOneUse())) {
8165 // We can only do this xform if we know that bits from X that are set in C2
8166 // but not in C1 are already zero. Likewise for Y.
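    // e.g. (or (and X, 0xFF00), (and Y, 0x00FF)) -> (and (or X, Y), 0xFFFF),
    // provided bits 0-7 of X and bits 8-15 of Y are known to be zero.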
8167 if (const ConstantSDNode *N0O1C =
8168 getAsNonOpaqueConstant(N: N0.getOperand(i: 1))) {
8169 if (const ConstantSDNode *N1O1C =
8170 getAsNonOpaqueConstant(N: N1.getOperand(i: 1))) {
8173 const APInt &LHSMask = N0O1C->getAPIntValue();
8174 const APInt &RHSMask = N1O1C->getAPIntValue();
8175
8176 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 0), Mask: RHSMask&~LHSMask) &&
8177 DAG.MaskedValueIsZero(Op: N1.getOperand(i: 0), Mask: LHSMask&~RHSMask)) {
8178 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
8179 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
8180 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
8181 N2: DAG.getConstant(Val: LHSMask | RHSMask, DL, VT));
8182 }
8183 }
8184 }
8185 }
8186
8187 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
8188 if (N0.getOpcode() == ISD::AND &&
8189 N1.getOpcode() == ISD::AND &&
8190 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
8191 // Don't increase # computations.
8192 (N0->hasOneUse() || N1->hasOneUse())) {
8193 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
8194 N1: N0.getOperand(i: 1), N2: N1.getOperand(i: 1));
8195 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: X);
8196 }
8197
8198 return SDValue();
8199}
8200
8201/// OR combines for which the commuted variant will be tried as well.
8202static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
8203 SDNode *N) {
8204 EVT VT = N0.getValueType();
8205 unsigned BW = VT.getScalarSizeInBits();
8206 SDLoc DL(N);
8207
8208 auto peekThroughResize = [](SDValue V) {
8209 if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
8210 return V->getOperand(Num: 0);
8211 return V;
8212 };
8213
8214 SDValue N0Resized = peekThroughResize(N0);
8215 if (N0Resized.getOpcode() == ISD::AND) {
8216 SDValue N1Resized = peekThroughResize(N1);
8217 SDValue N00 = N0Resized.getOperand(i: 0);
8218 SDValue N01 = N0Resized.getOperand(i: 1);
8219
8220 // fold or (and x, y), x --> x
8221 if (N00 == N1Resized || N01 == N1Resized)
8222 return N1;
8223
8224 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
8225 // TODO: Set AllowUndefs = true.
8226 if (SDValue NotOperand = getBitwiseNotOperand(V: N01, Mask: N00,
8227 /* AllowUndefs */ false)) {
8228 if (peekThroughResize(NotOperand) == N1Resized)
8229 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getZExtOrTrunc(Op: N00, DL, VT),
8230 N2: N1);
8231 }
8232
8233 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
8234 if (SDValue NotOperand = getBitwiseNotOperand(V: N00, Mask: N01,
8235 /* AllowUndefs */ false)) {
8236 if (peekThroughResize(NotOperand) == N1Resized)
8237 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getZExtOrTrunc(Op: N01, DL, VT),
8238 N2: N1);
8239 }
8240 }
8241
8242 SDValue X, Y;
8243
8244 // fold or (xor X, N1), N1 --> or X, N1
8245 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Specific(N: N1))))
8246 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: N1);
8247
8248 // fold or (xor x, y), (x and/or y) --> or x, y
8249 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Value(N&: Y))) &&
8250 (sd_match(N: N1, P: m_And(L: m_Specific(N: X), R: m_Specific(N: Y))) ||
8251 sd_match(N: N1, P: m_Or(L: m_Specific(N: X), R: m_Specific(N: Y)))))
8252 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: Y);
8253
8254 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
8255 return R;
8256
8257 auto peekThroughZext = [](SDValue V) {
8258 if (V->getOpcode() == ISD::ZERO_EXTEND)
8259 return V->getOperand(Num: 0);
8260 return V;
8261 };
8262
8263 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
8264 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
8265 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
8266 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
8267 return N0;
8268
8269 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
8270 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
8271 N0.getOperand(i: 1) == N1.getOperand(i: 0) &&
8272 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
8273 return N0;
8274
8275 // Attempt to match a legalized build_pair-esque pattern:
8276 // or(shl(aext(Hi),BW/2),zext(Lo))
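  // e.g. for i64: (or (shl (anyext i32 Hi), 32), (zext i32 Lo)). If both Lo
  // and Hi are bitwise-not values, the not is pulled out of the recombined
  // pair below.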
8277 SDValue Lo, Hi;
8278 if (sd_match(N: N0,
8279 P: m_OneUse(P: m_Shl(L: m_AnyExt(Op: m_Value(N&: Hi)), R: m_SpecificInt(V: BW / 2)))) &&
8280 sd_match(N: N1, P: m_ZExt(Op: m_Value(N&: Lo))) &&
8281 Lo.getScalarValueSizeInBits() == (BW / 2) &&
8282 Lo.getValueType() == Hi.getValueType()) {
8283 // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
8284 SDValue NotLo, NotHi;
8285 if (sd_match(N: Lo, P: m_OneUse(P: m_Not(V: m_Value(N&: NotLo)))) &&
8286 sd_match(N: Hi, P: m_OneUse(P: m_Not(V: m_Value(N&: NotHi))))) {
8287 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: NotLo);
8288 Hi = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: NotHi);
8289 Hi = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Hi,
8290 N2: DAG.getShiftAmountConstant(Val: BW / 2, VT, DL));
8291 return DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Lo, N2: Hi), VT);
8292 }
8293 }
8294
8295 return SDValue();
8296}
8297
8298SDValue DAGCombiner::visitOR(SDNode *N) {
8299 SDValue N0 = N->getOperand(Num: 0);
8300 SDValue N1 = N->getOperand(Num: 1);
8301 EVT VT = N1.getValueType();
8302 SDLoc DL(N);
8303
8304 // x | x --> x
8305 if (N0 == N1)
8306 return N0;
8307
8308 // fold (or c1, c2) -> c1|c2
8309 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL, VT, Ops: {N0, N1}))
8310 return C;
8311
8312 // canonicalize constant to RHS
8313 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
8314 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
8315 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1, N2: N0);
8316
8317 // fold vector ops
8318 if (VT.isVector()) {
8319 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8320 return FoldedVOp;
8321
8322 // fold (or x, 0) -> x, vector edition
8323 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
8324 return N0;
8325
8326 // fold (or x, -1) -> -1, vector edition
8327 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
8328 // do not return N1, because an undef node may exist in N1
8329 return DAG.getAllOnesConstant(DL, VT: N1.getValueType());
8330
8331 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
8332 // Do this only if the resulting type / shuffle is legal.
8333 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
8334 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(Val&: N1);
8335 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
8336 bool ZeroN00 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 0).getNode());
8337 bool ZeroN01 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 1).getNode());
8338 bool ZeroN10 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
8339 bool ZeroN11 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 1).getNode());
8340 // Ensure both shuffles have a zero input.
8341 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
8342 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
8343 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
8344 bool CanFold = true;
8345 int NumElts = VT.getVectorNumElements();
8346 SmallVector<int, 4> Mask(NumElts, -1);
8347
8348 for (int i = 0; i != NumElts; ++i) {
8349 int M0 = SV0->getMaskElt(Idx: i);
8350 int M1 = SV1->getMaskElt(Idx: i);
8351
8352 // Determine if either index is pointing to a zero vector.
8353 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
8354 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
8355
8356 // If one element is zero and the other side is undef, keep undef.
8357 // This also handles the case that both are undef.
8358 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
8359 continue;
8360
8361 // Make sure only one of the elements is zero.
8362 if (M0Zero == M1Zero) {
8363 CanFold = false;
8364 break;
8365 }
8366
8367 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
8368
8369 // We have a zero and non-zero element. If the non-zero came from
8370 // SV0 make the index a LHS index. If it came from SV1, make it
8371 // a RHS index. We need to mod by NumElts because we don't care
8372 // which operand it came from in the original shuffles.
8373 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
8374 }
8375
8376 if (CanFold) {
8377 SDValue NewLHS = ZeroN00 ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
8378 SDValue NewRHS = ZeroN10 ? N1.getOperand(i: 1) : N1.getOperand(i: 0);
8379 SDValue LegalShuffle =
8380 TLI.buildLegalVectorShuffle(VT, DL, N0: NewLHS, N1: NewRHS, Mask, DAG);
8381 if (LegalShuffle)
8382 return LegalShuffle;
8383 }
8384 }
8385 }
8386 }
8387
8388 // fold (or x, 0) -> x
8389 if (isNullConstant(V: N1))
8390 return N0;
8391
8392 // fold (or x, -1) -> -1
8393 if (isAllOnesConstant(V: N1))
8394 return N1;
8395
8396 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
8397 return NewSel;
8398
8399 // fold (or x, c) -> c iff (x & ~c) == 0
8400 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
8401 if (N1C && DAG.MaskedValueIsZero(Op: N0, Mask: ~N1C->getAPIntValue()))
8402 return N1;
8403
8404 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
8405 return R;
8406
8407 if (SDValue Combined = visitORLike(N0, N1, DL))
8408 return Combined;
8409
8410 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8411 return Combined;
8412
8413 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
8414 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
8415 return BSwap;
8416 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
8417 return BSwap;
8418
8419 // reassociate or
8420 if (SDValue ROR = reassociateOps(Opc: ISD::OR, DL, N0, N1, Flags: N->getFlags()))
8421 return ROR;
8422
8423 // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
8424 if (SDValue SD =
8425 reassociateReduction(RedOpc: ISD::VECREDUCE_OR, Opc: ISD::OR, DL, VT, N0, N1))
8426 return SD;
8427
8428 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
8429 // iff (c1 & c2) != 0 or c1/c2 are undef.
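  // For example (illustrative): (or (and X, 0x0F), 0x03)
  //   -> (and (or X, 0x03), 0x0F), since (0x0F & 0x03) != 0 and both sides
  //   equal (X & 0x0F) | 0x03.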
8430 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
8431 return !C1 || !C2 || C1->getAPIntValue().intersects(RHS: C2->getAPIntValue());
8432 };
8433 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
8434 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchIntersect, AllowUndefs: true)) {
8435 if (SDValue COR = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL: SDLoc(N1), VT,
8436 Ops: {N1, N0.getOperand(i: 1)})) {
8437 SDValue IOR = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
8438 AddToWorklist(N: IOR.getNode());
8439 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: COR, N2: IOR);
8440 }
8441 }
8442
8443 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
8444 return Combined;
8445 if (SDValue Combined = visitORCommutative(DAG, N0: N1, N1: N0, N))
8446 return Combined;
8447
8448 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
8449 if (N0.getOpcode() == N1.getOpcode())
8450 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8451 return V;
8452
8453 // See if this is some rotate idiom.
8454 if (SDValue Rot = MatchRotate(LHS: N0, RHS: N1, DL, /*FromAdd=*/false))
8455 return Rot;
8456
8457 if (SDValue Load = MatchLoadCombine(N))
8458 return Load;
8459
8460 // Simplify the operands using demanded-bits information.
8461 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
8462 return SDValue(N, 0);
8463
8464 // If OR can be rewritten into ADD, try combines based on ADD.
8465 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
8466 DAG.isADDLike(Op: SDValue(N, 0)))
8467 if (SDValue Combined = visitADDLike(N))
8468 return Combined;
8469
8470  // Postpone until after legalization has completed to avoid interference with
8471  // bswap folding.
8472 if (LegalOperations || VT.isVector())
8473 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
8474 return R;
8475
8476 if (VT.isScalarInteger() && VT != MVT::i1)
8477 if (SDValue R = foldMaskedMerge(Node: N, DAG, TLI, DL))
8478 return R;
8479
8480 return SDValue();
8481}
8482
8483static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8484 SDValue &Mask) {
8485 if (Op.getOpcode() == ISD::AND &&
8486 DAG.isConstantIntBuildVectorOrConstantInt(N: Op.getOperand(i: 1))) {
8487 Mask = Op.getOperand(i: 1);
8488 return Op.getOperand(i: 0);
8489 }
8490 return Op;
8491}
8492
8493/// Match "(X shl/srl V1) & V2" where V2 may not be present.
8494static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8495 SDValue &Mask) {
8496 Op = stripConstantMask(DAG, Op, Mask);
8497 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8498 Shift = Op;
8499 return true;
8500 }
8501 return false;
8502}
8503
8504/// Helper function for visitOR to extract the needed side of a rotate idiom
8505/// from a shl/srl/mul/udiv. This is meant to handle cases where
8506/// InstCombine merged some outside op with one of the shifts from
8507/// the rotate pattern.
8508/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8509/// Otherwise, returns an expansion of \p ExtractFrom based on the following
8510/// patterns:
8511///
8512/// (or (add v v) (shrl v bitwidth-1)):
8513/// expands (add v v) -> (shl v 1)
8514///
8515/// (or (mul v c0) (shrl (mul v c1) c2)):
8516/// expands (mul v c0) -> (shl (mul v c1) c3)
8517///
8518/// (or (udiv v c0) (shl (udiv v c1) c2)):
8519/// expands (udiv v c0) -> (shrl (udiv v c1) c3)
8520///
8521/// (or (shl v c0) (shrl (shl v c1) c2)):
8522/// expands (shl v c0) -> (shl (shl v c1) c3)
8523///
8524/// (or (shrl v c0) (shl (shrl v c1) c2)):
8525/// expands (shrl v c0) -> (shrl (shrl v c1) c3)
8526///
8527/// Such that in all cases, c3+c2==bitwidth(op v c1).
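///
/// For illustration (assuming a 32-bit type):
/// (or (mul v 24) (shrl (mul v 3) 29)) expands (mul v 24) -> (shl (mul v 3) 3),
/// since 24 == 3 << 3 and 3 + 29 == 32.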
8528static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
8529 SDValue ExtractFrom, SDValue &Mask,
8530 const SDLoc &DL) {
8531 assert(OppShift && ExtractFrom && "Empty SDValue");
8532 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8533 return SDValue();
8534
8535 ExtractFrom = stripConstantMask(DAG, Op: ExtractFrom, Mask);
8536
8537 // Value and Type of the shift.
8538 SDValue OppShiftLHS = OppShift.getOperand(i: 0);
8539 EVT ShiftedVT = OppShiftLHS.getValueType();
8540
8541 // Amount of the existing shift.
8542 ConstantSDNode *OppShiftCst = isConstOrConstSplat(N: OppShift.getOperand(i: 1));
8543
8544 // (add v v) -> (shl v 1)
8545 // TODO: Should this be a general DAG canonicalization?
8546 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8547 ExtractFrom.getOpcode() == ISD::ADD &&
8548 ExtractFrom.getOperand(i: 0) == ExtractFrom.getOperand(i: 1) &&
8549 ExtractFrom.getOperand(i: 0) == OppShiftLHS &&
8550 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8551 return DAG.getNode(Opcode: ISD::SHL, DL, VT: ShiftedVT, N1: OppShiftLHS,
8552 N2: DAG.getShiftAmountConstant(Val: 1, VT: ShiftedVT, DL));
8553
8554 // Preconditions:
8555 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8556 //
8557 // Find opcode of the needed shift to be extracted from (op0 v c0).
8558 unsigned Opcode = ISD::DELETED_NODE;
8559 bool IsMulOrDiv = false;
8560 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8561 // opcode or its arithmetic (mul or udiv) variant.
8562 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8563 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8564 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8565 return false;
8566 Opcode = NeededShift;
8567 return true;
8568 };
8569 // op0 must be either the needed shift opcode or the mul/udiv equivalent
8570 // that the needed shift can be extracted from.
8571 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8572 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8573 return SDValue();
8574
8575 // op0 must be the same opcode on both sides, have the same LHS argument,
8576 // and produce the same value type.
8577 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8578 OppShiftLHS.getOperand(i: 0) != ExtractFrom.getOperand(i: 0) ||
8579 ShiftedVT != ExtractFrom.getValueType())
8580 return SDValue();
8581
8582 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8583 ConstantSDNode *OppLHSCst = isConstOrConstSplat(N: OppShiftLHS.getOperand(i: 1));
8584 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8585 ConstantSDNode *ExtractFromCst =
8586 isConstOrConstSplat(N: ExtractFrom.getOperand(i: 1));
8587 // TODO: We should be able to handle non-uniform constant vectors for these values
8588 // Check that we have constant values.
8589 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8590 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8591 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8592 return SDValue();
8593
8594 // Compute the shift amount we need to extract to complete the rotate.
8595 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8596 if (OppShiftCst->getAPIntValue().ugt(RHS: VTWidth))
8597 return SDValue();
8598 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8599 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8600 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8601 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8602 zeroExtendToMatch(LHS&: ExtractFromAmt, RHS&: OppLHSAmt);
8603
8604  // Now try to extract the needed shift from the ExtractFrom op and see if the
8605  // result matches up with the existing shift's LHS op.
8606 if (IsMulOrDiv) {
8607 // Op to extract from is a mul or udiv by a constant.
8608 // Check:
8609 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8610 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8611 const APInt ExtractDiv = APInt::getOneBitSet(numBits: ExtractFromAmt.getBitWidth(),
8612 BitNo: NeededShiftAmt.getZExtValue());
8613 APInt ResultAmt;
8614 APInt Rem;
8615 APInt::udivrem(LHS: ExtractFromAmt, RHS: ExtractDiv, Quotient&: ResultAmt, Remainder&: Rem);
8616 if (Rem != 0 || ResultAmt != OppLHSAmt)
8617 return SDValue();
8618 } else {
8619 // Op to extract from is a shift by a constant.
8620 // Check:
8621 // c2 - (bitwidth(op0 v c0) - c1) == c0
8622 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8623 width: ExtractFromAmt.getBitWidth()))
8624 return SDValue();
8625 }
8626
8627 // Return the expanded shift op that should allow a rotate to be formed.
8628 EVT ShiftVT = OppShift.getOperand(i: 1).getValueType();
8629 EVT ResVT = ExtractFrom.getValueType();
8630 SDValue NewShiftNode = DAG.getConstant(Val: NeededShiftAmt, DL, VT: ShiftVT);
8631 return DAG.getNode(Opcode, DL, VT: ResVT, N1: OppShiftLHS, N2: NewShiftNode);
8632}
8633
8634// Return true if we can prove that, whenever Neg and Pos are both in the
8635// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8636// for two opposing shifts shift1 and shift2 and a value X with EltSize bits:
8637//
8638// (or (shift1 X, Neg), (shift2 X, Pos))
8639//
8640// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8641// in direction shift1 by Neg. The range [0, EltSize) means that we only need
8642// to consider shift amounts with defined behavior.
8643//
8644// The IsRotate flag should be set when the LHS of both shifts is the same.
8645// Otherwise if matching a general funnel shift, it should be clear.
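//
// For illustration (EltSize == 32): both Neg == (sub 32, Pos) and
// Neg == (and (sub 32, Pos), 31) satisfy the conditions checked below, so
// (or (shl X, Pos), (srl X, Neg)) can be matched as (rotl X, Pos).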
8646static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8647 SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
8648 const auto &TLI = DAG.getTargetLoweringInfo();
8649 // If EltSize is a power of 2 then:
8650 //
8651 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8652 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8653 //
8654 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8655 // for the stronger condition:
8656 //
8657 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8658 //
8659 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8660 // we can just replace Neg with Neg' for the rest of the function.
8661 //
8662 // In other cases we check for the even stronger condition:
8663 //
8664 // Neg == EltSize - Pos [B]
8665 //
8666 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8667 // behavior if Pos == 0 (and consequently Neg == EltSize).
8668 //
8669 // We could actually use [A] whenever EltSize is a power of 2, but the
8670 // only extra cases that it would match are those uninteresting ones
8671 // where Neg and Pos are never in range at the same time. E.g. for
8672 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8673 // as well as (sub 32, Pos), but:
8674 //
8675 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8676 //
8677 // always invokes undefined behavior for 32-bit X.
8678 //
8679 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8680 // This allows us to peek through any operations that only affect Mask's
8681 // un-demanded bits.
8682 //
8683 // NOTE: We can only do this when matching operations which won't modify the
8684 // least Log2(EltSize) significant bits and not a general funnel shift.
8685 unsigned MaskLoBits = 0;
8686 if (IsRotate && !FromAdd && isPowerOf2_64(Value: EltSize)) {
8687 unsigned Bits = Log2_64(Value: EltSize);
8688 unsigned NegBits = Neg.getScalarValueSizeInBits();
8689 if (NegBits >= Bits) {
8690 APInt DemandedBits = APInt::getLowBitsSet(numBits: NegBits, loBitsSet: Bits);
8691 if (SDValue Inner =
8692 TLI.SimplifyMultipleUseDemandedBits(Op: Neg, DemandedBits, DAG)) {
8693 Neg = Inner;
8694 MaskLoBits = Bits;
8695 }
8696 }
8697 }
8698
8699 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8700 if (Neg.getOpcode() != ISD::SUB)
8701 return false;
8702 ConstantSDNode *NegC = isConstOrConstSplat(N: Neg.getOperand(i: 0));
8703 if (!NegC)
8704 return false;
8705 SDValue NegOp1 = Neg.getOperand(i: 1);
8706
8707  // On the RHS of [A], if Pos is the result of an operation on Pos' that won't
8708  // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8709  // are redundant for the purpose of the equality.
8710 if (MaskLoBits) {
8711 unsigned PosBits = Pos.getScalarValueSizeInBits();
8712 if (PosBits >= MaskLoBits) {
8713 APInt DemandedBits = APInt::getLowBitsSet(numBits: PosBits, loBitsSet: MaskLoBits);
8714 if (SDValue Inner =
8715 TLI.SimplifyMultipleUseDemandedBits(Op: Pos, DemandedBits, DAG)) {
8716 Pos = Inner;
8717 }
8718 }
8719 }
8720
8721 // The condition we need is now:
8722 //
8723 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8724 //
8725 // If NegOp1 == Pos then we need:
8726 //
8727 // EltSize & Mask == NegC & Mask
8728 //
8729 // (because "x & Mask" is a truncation and distributes through subtraction).
8730 //
8731 // We also need to account for a potential truncation of NegOp1 if the amount
8732 // has already been legalized to a shift amount type.
8733 APInt Width;
8734 if ((Pos == NegOp1) ||
8735 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(i: 0)))
8736 Width = NegC->getAPIntValue();
8737
8738 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8739 // Then the condition we want to prove becomes:
8740 //
8741 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8742 //
8743 // which, again because "x & Mask" is a truncation, becomes:
8744 //
8745 // NegC & Mask == (EltSize - PosC) & Mask
8746 // EltSize & Mask == (NegC + PosC) & Mask
8747 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(i: 0) == NegOp1) {
8748 if (ConstantSDNode *PosC = isConstOrConstSplat(N: Pos.getOperand(i: 1)))
8749 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8750 else
8751 return false;
8752 } else
8753 return false;
8754
8755 // Now we just need to check that EltSize & Mask == Width & Mask.
8756 if (MaskLoBits)
8757 // EltSize & Mask is 0 since Mask is EltSize - 1.
8758 return Width.getLoBits(numBits: MaskLoBits) == 0;
8759 return Width == EltSize;
8760}
8761
8762// A subroutine of MatchRotate used once we have found an OR of two opposite
8763// shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
8764// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8765// former being preferred if supported. InnerPos and InnerNeg are Pos and
8766// Neg with outer conversions stripped away.
8767SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8768 SDValue Neg, SDValue InnerPos,
8769 SDValue InnerNeg, bool FromAdd,
8770 bool HasPos, unsigned PosOpcode,
8771 unsigned NegOpcode, const SDLoc &DL) {
8772 // fold (or/add (shl x, (*ext y)),
8773 // (srl x, (*ext (sub 32, y)))) ->
8774 // (rotl x, y) or (rotr x, (sub 32, y))
8775 //
8776 // fold (or/add (shl x, (*ext (sub 32, y))),
8777 // (srl x, (*ext y))) ->
8778 // (rotr x, y) or (rotl x, (sub 32, y))
8779 EVT VT = Shifted.getValueType();
8780 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: VT.getScalarSizeInBits(), DAG,
8781 /*IsRotate*/ true, FromAdd))
8782 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: Shifted,
8783 N2: HasPos ? Pos : Neg);
8784
8785 return SDValue();
8786}
8787
8788// A subroutine of MatchRotate used once we have found an OR of two opposite
8789// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
8790// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8791// former being preferred if supported. InnerPos and InnerNeg are Pos and
8792// Neg with outer conversions stripped away.
8793// TODO: Merge with MatchRotatePosNeg.
8794SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8795 SDValue Neg, SDValue InnerPos,
8796 SDValue InnerNeg, bool FromAdd,
8797 bool HasPos, unsigned PosOpcode,
8798 unsigned NegOpcode, const SDLoc &DL) {
8799 EVT VT = N0.getValueType();
8800 unsigned EltBits = VT.getScalarSizeInBits();
8801
8802 // fold (or/add (shl x0, (*ext y)),
8803 // (srl x1, (*ext (sub 32, y)))) ->
8804 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8805 //
8806 // fold (or/add (shl x0, (*ext (sub 32, y))),
8807 // (srl x1, (*ext y))) ->
8808 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8809 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: EltBits, DAG, /*IsRotate*/ N0 == N1,
8810 FromAdd))
8811 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: N0, N2: N1,
8812 N3: HasPos ? Pos : Neg);
8813
8814  // Matching the shift+xor cases, we can't easily use the xor'd shift amount,
8815  // so for now just use the PosOpcode case if it's legal.
8816 // TODO: When can we use the NegOpcode case?
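  // For illustration (EltBits == 32, y in [0, 31]): (xor y, 31) == 31 - y, so
  // (srl (srl x1, 1), (xor y, 31)) shifts x1 right by 32 - y in total, matching
  // the fshl semantics even for y == 0 (the two shifts then produce 0).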
8817 if (PosOpcode == ISD::FSHL && isPowerOf2_32(Value: EltBits)) {
8818 SDValue X;
8819 // fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8820 // -> (fshl x0, x1, y)
8821 if (sd_match(N: N1, P: m_Srl(L: m_Value(N&: X), R: m_One())) &&
8822 sd_match(N: InnerNeg,
8823 P: m_Xor(L: m_Specific(N: InnerPos), R: m_SpecificInt(V: EltBits - 1))) &&
8824 TLI.isOperationLegalOrCustom(Op: ISD::FSHL, VT)) {
8825 return DAG.getNode(Opcode: ISD::FSHL, DL, VT, N1: N0, N2: X, N3: Pos);
8826 }
8827
8828 // fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8829 // -> (fshr x0, x1, y)
8830 if (sd_match(N: N0, P: m_Shl(L: m_Value(N&: X), R: m_One())) &&
8831 sd_match(N: InnerPos,
8832 P: m_Xor(L: m_Specific(N: InnerNeg), R: m_SpecificInt(V: EltBits - 1))) &&
8833 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8834 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: X, N2: N1, N3: Neg);
8835 }
8836
8837 // fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8838 // -> (fshr x0, x1, y)
8839 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8840 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: X), R: m_Deferred(V&: X))) &&
8841 sd_match(N: InnerPos,
8842 P: m_Xor(L: m_Specific(N: InnerNeg), R: m_SpecificInt(V: EltBits - 1))) &&
8843 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8844 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: X, N2: N1, N3: Neg);
8845 }
8846 }
8847
8848 return SDValue();
8849}
8850
8851// MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
8852// many idioms for rotate, and if the target supports rotation instructions,
8853// generate a rot[lr]. This also matches funnel shift patterns, similar to
8854// rotation but with different shifted sources.
8855SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
8856 bool FromAdd) {
8857 EVT VT = LHS.getValueType();
8858
8859 // The target must have at least one rotate/funnel flavor.
8860 // We still try to match rotate by constant pre-legalization.
8861 // TODO: Support pre-legalization funnel-shift by constant.
8862 bool HasROTL = hasOperation(Opcode: ISD::ROTL, VT);
8863 bool HasROTR = hasOperation(Opcode: ISD::ROTR, VT);
8864 bool HasFSHL = hasOperation(Opcode: ISD::FSHL, VT);
8865 bool HasFSHR = hasOperation(Opcode: ISD::FSHR, VT);
8866
8867 // If the type is going to be promoted and the target has enabled custom
8868 // lowering for rotate, allow matching rotate by non-constants. Only allow
8869 // this for scalar types.
8870 if (VT.isScalarInteger() && TLI.getTypeAction(Context&: *DAG.getContext(), VT) ==
8871 TargetLowering::TypePromoteInteger) {
8872 HasROTL |= TLI.getOperationAction(Op: ISD::ROTL, VT) == TargetLowering::Custom;
8873 HasROTR |= TLI.getOperationAction(Op: ISD::ROTR, VT) == TargetLowering::Custom;
8874 }
8875
8876 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8877 return SDValue();
8878
8879 // Check for truncated rotate.
8880 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8881 LHS.getOperand(i: 0).getValueType() == RHS.getOperand(i: 0).getValueType()) {
8882 assert(LHS.getValueType() == RHS.getValueType());
8883 if (SDValue Rot =
8884 MatchRotate(LHS: LHS.getOperand(i: 0), RHS: RHS.getOperand(i: 0), DL, FromAdd))
8885 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LHS), VT: LHS.getValueType(), Operand: Rot);
8886 }
8887
8888 // Match "(X shl/srl V1) & V2" where V2 may not be present.
8889 SDValue LHSShift; // The shift.
8890 SDValue LHSMask; // AND value if any.
8891 matchRotateHalf(DAG, Op: LHS, Shift&: LHSShift, Mask&: LHSMask);
8892
8893 SDValue RHSShift; // The shift.
8894 SDValue RHSMask; // AND value if any.
8895 matchRotateHalf(DAG, Op: RHS, Shift&: RHSShift, Mask&: RHSMask);
8896
8897 // If neither side matched a rotate half, bail
8898 if (!LHSShift && !RHSShift)
8899 return SDValue();
8900
8901 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8902 // side of the rotate, so try to handle that here. In all cases we need to
8903 // pass the matched shift from the opposite side to compute the opcode and
8904 // needed shift amount to extract. We still want to do this if both sides
8905 // matched a rotate half because one half may be a potential overshift that
8906  // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
8907 // single one).
8908
8909 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8910 if (LHSShift)
8911 if (SDValue NewRHSShift =
8912 extractShiftForRotate(DAG, OppShift: LHSShift, ExtractFrom: RHS, Mask&: RHSMask, DL))
8913 RHSShift = NewRHSShift;
8914 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8915 if (RHSShift)
8916 if (SDValue NewLHSShift =
8917 extractShiftForRotate(DAG, OppShift: RHSShift, ExtractFrom: LHS, Mask&: LHSMask, DL))
8918 LHSShift = NewLHSShift;
8919
8920  // If a side is still missing, there is nothing else we can do.
8921 if (!RHSShift || !LHSShift)
8922 return SDValue();
8923
8924 // At this point we've matched or extracted a shift op on each side.
8925
8926 if (LHSShift.getOpcode() == RHSShift.getOpcode())
8927 return SDValue(); // Shifts must disagree.
8928
8929 // Canonicalize shl to left side in a shl/srl pair.
8930 if (RHSShift.getOpcode() == ISD::SHL) {
8931 std::swap(a&: LHS, b&: RHS);
8932 std::swap(a&: LHSShift, b&: RHSShift);
8933 std::swap(a&: LHSMask, b&: RHSMask);
8934 }
8935
8936 // Something has gone wrong - we've lost the shl/srl pair - bail.
8937 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8938 return SDValue();
8939
8940 unsigned EltSizeInBits = VT.getScalarSizeInBits();
8941 SDValue LHSShiftArg = LHSShift.getOperand(i: 0);
8942 SDValue LHSShiftAmt = LHSShift.getOperand(i: 1);
8943 SDValue RHSShiftArg = RHSShift.getOperand(i: 0);
8944 SDValue RHSShiftAmt = RHSShift.getOperand(i: 1);
8945
8946 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8947 ConstantSDNode *RHS) {
8948 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8949 };
8950
8951 auto ApplyMasks = [&](SDValue Res) {
8952 // If there is an AND of either shifted operand, apply it to the result.
8953 if (LHSMask.getNode() || RHSMask.getNode()) {
8954 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8955 SDValue Mask = AllOnes;
8956
8957 if (LHSMask.getNode()) {
8958 SDValue RHSBits = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: AllOnes, N2: RHSShiftAmt);
8959 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8960 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHSMask, N2: RHSBits));
8961 }
8962 if (RHSMask.getNode()) {
8963 SDValue LHSBits = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllOnes, N2: LHSShiftAmt);
8964 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8965 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RHSMask, N2: LHSBits));
8966 }
8967
8968 Res = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Res, N2: Mask);
8969 }
8970
8971 return Res;
8972 };
8973
8974 // TODO: Support pre-legalization funnel-shift by constant.
8975 bool IsRotate = LHSShiftArg == RHSShiftArg;
8976 if (!IsRotate && !(HasFSHL || HasFSHR)) {
8977 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8978 ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8979 // Look for a disguised rotate by constant.
8980 // The common shifted operand X may be hidden inside another 'or'.
8981 SDValue X, Y;
8982 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8983 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8984 return false;
8985 if (CommonOp == Or.getOperand(i: 0)) {
8986 X = CommonOp;
8987 Y = Or.getOperand(i: 1);
8988 return true;
8989 }
8990 if (CommonOp == Or.getOperand(i: 1)) {
8991 X = CommonOp;
8992 Y = Or.getOperand(i: 0);
8993 return true;
8994 }
8995 return false;
8996 };
8997
8998 SDValue Res;
8999 if (matchOr(LHSShiftArg, RHSShiftArg)) {
9000 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
9001 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
9002 SDValue ShlY = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: LHSShiftAmt);
9003 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: ShlY);
9004 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
9005 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
9006 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
9007 SDValue SrlY = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Y, N2: RHSShiftAmt);
9008 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: SrlY);
9009 } else {
9010 return SDValue();
9011 }
9012
9013 return ApplyMasks(Res);
9014 }
9015
9016 return SDValue(); // Requires funnel shift support.
9017 }
9018
9019 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
9020 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
9021 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
9022 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
9023 // iff C1+C2 == EltSizeInBits
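  // For example, with i32 operands: (or (shl x, 8), (srl x, 24)) becomes
  // (rotl x, 8) (or, equivalently, (rotr x, 24)).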
9024 if (ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
9025 SDValue Res;
9026 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
9027 bool UseROTL = !LegalOperations || HasROTL;
9028 Res = DAG.getNode(Opcode: UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, N1: LHSShiftArg,
9029 N2: UseROTL ? LHSShiftAmt : RHSShiftAmt);
9030 } else {
9031 bool UseFSHL = !LegalOperations || HasFSHL;
9032 Res = DAG.getNode(Opcode: UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, N1: LHSShiftArg,
9033 N2: RHSShiftArg, N3: UseFSHL ? LHSShiftAmt : RHSShiftAmt);
9034 }
9035
9036 return ApplyMasks(Res);
9037 }
9038
9039 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
9040 // shift.
9041 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9042 return SDValue();
9043
9044 // If there is a mask here, and we have a variable shift, we can't be sure
9045 // that we're masking out the right stuff.
9046 if (LHSMask.getNode() || RHSMask.getNode())
9047 return SDValue();
9048
9049 // If the shift amount is sign/zext/any-extended just peel it off.
9050 SDValue LExtOp0 = LHSShiftAmt;
9051 SDValue RExtOp0 = RHSShiftAmt;
9052 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9053 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9054 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9055 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
9056 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9057 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9058 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9059 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
9060 LExtOp0 = LHSShiftAmt.getOperand(i: 0);
9061 RExtOp0 = RHSShiftAmt.getOperand(i: 0);
9062 }
9063
9064 if (IsRotate && (HasROTL || HasROTR)) {
9065 if (SDValue TryL = MatchRotatePosNeg(Shifted: LHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt,
9066 InnerPos: LExtOp0, InnerNeg: RExtOp0, FromAdd, HasPos: HasROTL,
9067 PosOpcode: ISD::ROTL, NegOpcode: ISD::ROTR, DL))
9068 return TryL;
9069
9070 if (SDValue TryR = MatchRotatePosNeg(Shifted: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt,
9071 InnerPos: RExtOp0, InnerNeg: LExtOp0, FromAdd, HasPos: HasROTR,
9072 PosOpcode: ISD::ROTR, NegOpcode: ISD::ROTL, DL))
9073 return TryR;
9074 }
9075
9076 if (SDValue TryL = MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: LHSShiftAmt,
9077 Neg: RHSShiftAmt, InnerPos: LExtOp0, InnerNeg: RExtOp0, FromAdd,
9078 HasPos: HasFSHL, PosOpcode: ISD::FSHL, NegOpcode: ISD::FSHR, DL))
9079 return TryL;
9080
9081 if (SDValue TryR = MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: RHSShiftAmt,
9082 Neg: LHSShiftAmt, InnerPos: RExtOp0, InnerNeg: LExtOp0, FromAdd,
9083 HasPos: HasFSHR, PosOpcode: ISD::FSHR, NegOpcode: ISD::FSHL, DL))
9084 return TryR;
9085
9086 return SDValue();
9087}
9088
9089/// Recursively traverses the expression calculating the origin of the requested
9090/// byte of the given value. Returns std::nullopt if the provider can't be
9091/// calculated.
9092///
9093/// For all the values except the root of the expression, we verify that the
9094/// value has exactly one use and if not then return std::nullopt. This way if
9095/// the origin of the byte is returned it's guaranteed that the values which
9096/// contribute to the byte are not used outside of this expression.
9097///
9098/// However, there is a special case when dealing with vector loads -- we allow
9099/// more than one use if the load is a vector type. Since the values that
9100/// contribute to the byte ultimately come from the ExtractVectorElements of the
9101/// Load, we don't care if the Load has uses other than ExtractVectorElements,
9102/// because those operations are independent from the pattern to be combined.
9103/// For vector loads, we simply care that the ByteProviders are adjacent
9104/// positions of the same vector, and their index matches the byte that is being
9105/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
9106/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
9107/// byte position we are trying to provide for the LoadCombine. If these do
9108/// not match, then we cannot combine the vector loads. \p Index uses the
9109/// byte position we are trying to provide for and is matched against the
9110/// shl and load size. The \p Index algorithm ensures the requested byte is
9111/// provided for by the pattern, and the pattern does not over-provide bytes.
9112///
9113///
9114/// The supported LoadCombine pattern for vector loads is as follows
9115/// or
9116/// / \
9117/// or shl
9118/// / \ |
9119/// or shl zext
9120/// / \ | |
9121/// shl zext zext EVE*
9122/// | | | |
9123/// zext EVE* EVE* LOAD
9124/// | | |
9125/// EVE* LOAD LOAD
9126/// |
9127/// LOAD
9128///
9129/// *ExtractVectorElement
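///
/// As an illustrative scalar example, for
///   (or (zext i8 L0 to i32), (shl (zext i8 L1 to i32), 8))
/// byte 0 is provided by load L0, byte 1 by load L1, and bytes 2 and 3 are
/// known to be zero.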
9130using SDByteProvider = ByteProvider<SDNode *>;
9131
9132static std::optional<SDByteProvider>
9133calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
9134 std::optional<uint64_t> VectorIndex,
9135 unsigned StartingIndex = 0) {
9136
9137  // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
9138 if (Depth == 10)
9139 return std::nullopt;
9140
9141 // Only allow multiple uses if the instruction is a vector load (in which
9142 // case we will use the load for every ExtractVectorElement)
9143 if (Depth && !Op.hasOneUse() &&
9144 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
9145 return std::nullopt;
9146
9147 // Fail to combine if we have encountered anything but a LOAD after handling
9148 // an ExtractVectorElement.
9149 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
9150 return std::nullopt;
9151
9152 unsigned BitWidth = Op.getValueSizeInBits();
9153 if (BitWidth % 8 != 0)
9154 return std::nullopt;
9155 unsigned ByteWidth = BitWidth / 8;
9156 assert(Index < ByteWidth && "invalid index requested");
9157 (void) ByteWidth;
9158
9159 switch (Op.getOpcode()) {
9160 case ISD::OR: {
9161 auto LHS =
9162 calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1, VectorIndex);
9163 if (!LHS)
9164 return std::nullopt;
9165 auto RHS =
9166 calculateByteProvider(Op: Op->getOperand(Num: 1), Index, Depth: Depth + 1, VectorIndex);
9167 if (!RHS)
9168 return std::nullopt;
9169
9170 if (LHS->isConstantZero())
9171 return RHS;
9172 if (RHS->isConstantZero())
9173 return LHS;
9174 return std::nullopt;
9175 }
9176 case ISD::SHL: {
9177 auto ShiftOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
9178 if (!ShiftOp)
9179 return std::nullopt;
9180
9181 uint64_t BitShift = ShiftOp->getZExtValue();
9182
9183 if (BitShift % 8 != 0)
9184 return std::nullopt;
9185 uint64_t ByteShift = BitShift / 8;
9186
9187 // If we are shifting by an amount greater than the index we are trying to
9188 // provide, then do not provide anything. Otherwise, subtract the index by
9189 // the amount we shifted by.
9190 return Index < ByteShift
9191 ? SDByteProvider::getConstantZero()
9192 : calculateByteProvider(Op: Op->getOperand(Num: 0), Index: Index - ByteShift,
9193 Depth: Depth + 1, VectorIndex, StartingIndex: Index);
9194 }
9195 case ISD::ANY_EXTEND:
9196 case ISD::SIGN_EXTEND:
9197 case ISD::ZERO_EXTEND: {
9198 SDValue NarrowOp = Op->getOperand(Num: 0);
9199 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9200 if (NarrowBitWidth % 8 != 0)
9201 return std::nullopt;
9202 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9203
9204 if (Index >= NarrowByteWidth)
9205 return Op.getOpcode() == ISD::ZERO_EXTEND
9206 ? std::optional<SDByteProvider>(
9207 SDByteProvider::getConstantZero())
9208 : std::nullopt;
9209 return calculateByteProvider(Op: NarrowOp, Index, Depth: Depth + 1, VectorIndex,
9210 StartingIndex);
9211 }
9212 case ISD::BSWAP:
9213 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index: ByteWidth - Index - 1,
9214 Depth: Depth + 1, VectorIndex, StartingIndex);
9215 case ISD::EXTRACT_VECTOR_ELT: {
9216 auto OffsetOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
9217 if (!OffsetOp)
9218 return std::nullopt;
9219
9220 VectorIndex = OffsetOp->getZExtValue();
9221
9222 SDValue NarrowOp = Op->getOperand(Num: 0);
9223 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9224 if (NarrowBitWidth % 8 != 0)
9225 return std::nullopt;
9226 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9227 // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
9228 // type, leaving the high bits undefined.
9229 if (Index >= NarrowByteWidth)
9230 return std::nullopt;
9231
9232 // Check to see if the position of the element in the vector corresponds
9233 // with the byte we are trying to provide for. In the case of a vector of
9234 // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
9235 // the element will provide a range of bytes. For example, if we have a
9236 // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
9237 // 3).
9238 if (*VectorIndex * NarrowByteWidth > StartingIndex)
9239 return std::nullopt;
9240 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
9241 return std::nullopt;
9242
9243 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1,
9244 VectorIndex, StartingIndex);
9245 }
9246 case ISD::LOAD: {
9247 auto L = cast<LoadSDNode>(Val: Op.getNode());
9248 if (!L->isSimple() || L->isIndexed())
9249 return std::nullopt;
9250
9251 unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
9252 if (NarrowBitWidth % 8 != 0)
9253 return std::nullopt;
9254 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9255
9256    // If the width of the load does not reach the byte we are trying to provide
9257    // for, and it is not a ZEXTLOAD, then the load does not provide for the byte
9258    // in question.
9259 if (Index >= NarrowByteWidth)
9260 return L->getExtensionType() == ISD::ZEXTLOAD
9261 ? std::optional<SDByteProvider>(
9262 SDByteProvider::getConstantZero())
9263 : std::nullopt;
9264
9265 unsigned BPVectorIndex = VectorIndex.value_or(u: 0U);
9266 return SDByteProvider::getSrc(Val: L, ByteOffset: Index, VectorOffset: BPVectorIndex);
9267 }
9268 }
9269
9270 return std::nullopt;
9271}
9272
9273static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
9274 return i;
9275}
9276
9277static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
9278 return BW - i - 1;
9279}
9280
9281// Check if the byte offsets we are looking at match either a big- or
9282// little-endian value load. Return true for big endian, false for little
9283// endian, and std::nullopt if the match failed.
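// For example, offsets {0, 1, 2, 3} (relative to FirstOffset) are little
// endian and {3, 2, 1, 0} are big endian.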
9284static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
9285 int64_t FirstOffset) {
9286  // The endianness can be decided only when there are at least 2 bytes.
9287 unsigned Width = ByteOffsets.size();
9288 if (Width < 2)
9289 return std::nullopt;
9290
9291 bool BigEndian = true, LittleEndian = true;
9292 for (unsigned i = 0; i < Width; i++) {
9293 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
9294 LittleEndian &= CurrentByteOffset == littleEndianByteAt(BW: Width, i);
9295 BigEndian &= CurrentByteOffset == bigEndianByteAt(BW: Width, i);
9296 if (!BigEndian && !LittleEndian)
9297 return std::nullopt;
9298 }
9299
9300  assert((BigEndian != LittleEndian) && "It should be either big endian or "
9301                                        "little endian");
9302 return BigEndian;
9303}
9304
9305// Look through one layer of truncate or extend.
9306static SDValue stripTruncAndExt(SDValue Value) {
9307 switch (Value.getOpcode()) {
9308 case ISD::TRUNCATE:
9309 case ISD::ZERO_EXTEND:
9310 case ISD::SIGN_EXTEND:
9311 case ISD::ANY_EXTEND:
9312 return Value.getOperand(i: 0);
9313 }
9314 return SDValue();
9315}
9316
9317/// Match a pattern where a wide type scalar value is stored by several narrow
9318/// stores. Fold it into a single store or a BSWAP and a store if the target
9319/// supports it.
9320///
9321/// Assuming little endian target:
9322/// i8 *p = ...
9323/// i32 val = ...
9324/// p[0] = (val >> 0) & 0xFF;
9325/// p[1] = (val >> 8) & 0xFF;
9326/// p[2] = (val >> 16) & 0xFF;
9327/// p[3] = (val >> 24) & 0xFF;
9328/// =>
9329/// *((i32)p) = val;
9330///
9331/// i8 *p = ...
9332/// i32 val = ...
9333/// p[0] = (val >> 24) & 0xFF;
9334/// p[1] = (val >> 16) & 0xFF;
9335/// p[2] = (val >> 8) & 0xFF;
9336/// p[3] = (val >> 0) & 0xFF;
9337/// =>
9338/// *((i32)p) = BSWAP(val);
9339SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
9340 // The matching looks for "store (trunc x)" patterns that appear early but are
9341 // likely to be replaced by truncating store nodes during combining.
9342 // TODO: If there is evidence that running this later would help, this
9343 // limitation could be removed. Legality checks may need to be added
9344 // for the created store and optional bswap/rotate.
9345 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
9346 return SDValue();
9347
9348 // We only handle merging simple stores of 1-4 bytes.
9349 // TODO: Allow unordered atomics when wider type is legal (see D66309)
9350 EVT MemVT = N->getMemoryVT();
9351 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
9352 !N->isSimple() || N->isIndexed())
9353 return SDValue();
9354
9355  // Collect all of the stores in the chain, up to the maximum store width (i64).
9356 SDValue Chain = N->getChain();
9357 SmallVector<StoreSDNode *, 8> Stores = {N};
9358 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
9359 unsigned MaxWideNumBits = 64;
9360 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
9361 while (auto *Store = dyn_cast<StoreSDNode>(Val&: Chain)) {
9362 // All stores must be the same size to ensure that we are writing all of the
9363 // bytes in the wide value.
9364 // This store should have exactly one use as a chain operand for another
9365 // store in the merging set. If there are other chain uses, then the
9366 // transform may not be safe because order of loads/stores outside of this
9367 // set may not be preserved.
9368 // TODO: We could allow multiple sizes by tracking each stored byte.
9369 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
9370 Store->isIndexed() || !Store->hasOneUse())
9371 return SDValue();
9372 Stores.push_back(Elt: Store);
9373 Chain = Store->getChain();
9374 if (MaxStores < Stores.size())
9375 return SDValue();
9376 }
9377 // There is no reason to continue if we do not have at least a pair of stores.
9378 if (Stores.size() < 2)
9379 return SDValue();
9380
9381 // Handle simple types only.
9382 LLVMContext &Context = *DAG.getContext();
9383 unsigned NumStores = Stores.size();
9384 unsigned WideNumBits = NumStores * NarrowNumBits;
9385 EVT WideVT = EVT::getIntegerVT(Context, BitWidth: WideNumBits);
9386 if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
9387 return SDValue();
9388
9389 // Check if all bytes of the source value that we are looking at are stored
9390 // to the same base address. Collect offsets from Base address into OffsetMap.
9391 SDValue SourceValue;
9392 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
9393 int64_t FirstOffset = INT64_MAX;
9394 StoreSDNode *FirstStore = nullptr;
9395 std::optional<BaseIndexOffset> Base;
9396 for (auto *Store : Stores) {
9397    // All the stores store different parts of the combined source value. A
9398    // truncate is required to get the partial value.
9399 SDValue Trunc = Store->getValue();
9400 if (Trunc.getOpcode() != ISD::TRUNCATE)
9401 return SDValue();
9402 // Other than the first/last part, a shift operation is required to get the
9403 // offset.
9404 int64_t Offset = 0;
9405 SDValue WideVal = Trunc.getOperand(i: 0);
9406 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
9407 isa<ConstantSDNode>(Val: WideVal.getOperand(i: 1))) {
9408 // The shift amount must be a constant multiple of the narrow type.
9409 // It is translated to the offset address in the wide source value "y".
9410 //
9411 // x = srl y, ShiftAmtC
9412 // i8 z = trunc x
9413 // store z, ...
9414 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(i: 1);
9415 if (ShiftAmtC % NarrowNumBits != 0)
9416 return SDValue();
9417
9418 // Make sure we aren't reading bits that are shifted in.
9419 if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
9420 return SDValue();
9421
9422 Offset = ShiftAmtC / NarrowNumBits;
9423 WideVal = WideVal.getOperand(i: 0);
9424 }
9425
9426 // Stores must share the same source value with different offsets.
9427 if (!SourceValue)
9428 SourceValue = WideVal;
9429 else if (SourceValue != WideVal) {
9430      // Truncates and extends can be stripped to see if the values are related.
9431 if (stripTruncAndExt(Value: SourceValue) != WideVal &&
9432 stripTruncAndExt(Value: WideVal) != SourceValue)
9433 return SDValue();
9434
9435 if (WideVal.getScalarValueSizeInBits() >
9436 SourceValue.getScalarValueSizeInBits())
9437 SourceValue = WideVal;
9438
9439 // Give up if the source value type is smaller than the store size.
9440 if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
9441 return SDValue();
9442 }
9443
9444 // Stores must share the same base address.
9445 BaseIndexOffset Ptr = BaseIndexOffset::match(N: Store, DAG);
9446 int64_t ByteOffsetFromBase = 0;
9447 if (!Base)
9448 Base = Ptr;
9449 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9450 return SDValue();
9451
9452 // Remember the first store.
9453 if (ByteOffsetFromBase < FirstOffset) {
9454 FirstStore = Store;
9455 FirstOffset = ByteOffsetFromBase;
9456 }
9457    // Map the offset in the store and the offset in the combined value, and
9458    // return early if it has been set before.
9459 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9460 return SDValue();
9461 OffsetMap[Offset] = ByteOffsetFromBase;
9462 }
9463
9464 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9465 assert(FirstStore && "First store must be set");
9466
9467 // Check that a store of the wide type is both allowed and fast on the target
9468 const DataLayout &Layout = DAG.getDataLayout();
9469 unsigned Fast = 0;
9470 bool Allowed = TLI.allowsMemoryAccess(Context, DL: Layout, VT: WideVT,
9471 MMO: *FirstStore->getMemOperand(), Fast: &Fast);
9472 if (!Allowed || !Fast)
9473 return SDValue();
9474
9475 // Check if the pieces of the value are going to the expected places in memory
9476 // to merge the stores.
9477 auto checkOffsets = [&](bool MatchLittleEndian) {
9478 if (MatchLittleEndian) {
9479 for (unsigned i = 0; i != NumStores; ++i)
9480 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9481 return false;
9482 } else { // MatchBigEndian by reversing loop counter.
9483 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9484 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9485 return false;
9486 }
9487 return true;
9488 };
9489
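  // For illustration: on a little-endian target, a pair of i16 stores that put
  // the high half of an i32 at the lower address match the big-endian order;
  // with NumStores == 2 that case is handled below by rotating the source by
  // 16 bits before emitting the wide store.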
9490 // Check if the offsets line up for the native data layout of this target.
9491 bool NeedBswap = false;
9492 bool NeedRotate = false;
9493 if (!checkOffsets(Layout.isLittleEndian())) {
9494 // Special-case: check if byte offsets line up for the opposite endian.
9495 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9496 NeedBswap = true;
9497 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9498 NeedRotate = true;
9499 else
9500 return SDValue();
9501 }
9502
9503 SDLoc DL(N);
9504 if (WideVT != SourceValue.getValueType()) {
9505 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9506 "Unexpected store value to merge");
9507 SourceValue = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: WideVT, Operand: SourceValue);
9508 }
9509
9510 // Before legalize we can introduce illegal bswaps/rotates which will be later
9511 // converted to an explicit bswap sequence. This way we end up with a single
9512 // store and byte shuffling instead of several stores and byte shuffling.
9513 if (NeedBswap) {
9514 SourceValue = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: WideVT, Operand: SourceValue);
9515 } else if (NeedRotate) {
9516 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9517 SDValue RotAmt = DAG.getConstant(Val: WideNumBits / 2, DL, VT: WideVT);
9518 SourceValue = DAG.getNode(Opcode: ISD::ROTR, DL, VT: WideVT, N1: SourceValue, N2: RotAmt);
9519 }
9520
9521 SDValue NewStore =
9522 DAG.getStore(Chain, dl: DL, Val: SourceValue, Ptr: FirstStore->getBasePtr(),
9523 PtrInfo: FirstStore->getPointerInfo(), Alignment: FirstStore->getAlign());
9524
9525 // Rely on other DAG combine rules to remove the other individual stores.
9526 DAG.ReplaceAllUsesWith(From: N, To: NewStore.getNode());
9527 return NewStore;
9528}
9529
9530/// Match a pattern where a wide type scalar value is loaded by several narrow
9531/// loads and combined by shifts and ors. Fold it into a single load or a load
9532/// and a BSWAP if the target supports it.
9533///
9534/// Assuming little endian target:
9535/// i8 *a = ...
9536/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9537/// =>
9538/// i32 val = *((i32)a)
9539///
9540/// i8 *a = ...
9541/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9542/// =>
9543/// i32 val = BSWAP(*((i32)a))
9544///
9545/// TODO: This rule matches complex patterns with OR node roots and doesn't
9546/// interact well with the worklist mechanism. When a part of the pattern is
9547/// updated (e.g. one of the loads) its direct users are put into the worklist,
9548/// but the root node of the pattern which triggers the load combine is not
9549/// necessarily a direct user of the changed node. For example, once the address
9550/// of t28 load is reassociated load combine won't be triggered:
9551/// t25: i32 = add t4, Constant:i32<2>
9552/// t26: i64 = sign_extend t25
9553/// t27: i64 = add t2, t26
9554/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9555/// t29: i32 = zero_extend t28
9556/// t32: i32 = shl t29, Constant:i8<8>
9557/// t33: i32 = or t23, t32
9558/// As a possible fix visitLoad can check if the load can be a part of a load
9559/// combine pattern and add corresponding OR roots to the worklist.
9560SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9561 assert(N->getOpcode() == ISD::OR &&
9562 "Can only match load combining against OR nodes");
9563
9564 // Handles simple types only
9565 EVT VT = N->getValueType(ResNo: 0);
9566 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9567 return SDValue();
9568 unsigned ByteWidth = VT.getSizeInBits() / 8;
9569
9570 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
9571 auto MemoryByteOffset = [&](SDByteProvider P) {
9572 assert(P.hasSrc() && "Must be a memory byte provider");
9573 auto *Load = cast<LoadSDNode>(Val: P.Src.value());
9574
9575 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9576
9577 assert(LoadBitWidth % 8 == 0 &&
9578           "can only analyze providers for individual bytes, not bits");
9579 unsigned LoadByteWidth = LoadBitWidth / 8;
9580 return IsBigEndianTarget ? bigEndianByteAt(BW: LoadByteWidth, i: P.DestOffset)
9581 : littleEndianByteAt(BW: LoadByteWidth, i: P.DestOffset);
9582 };
9583
9584 std::optional<BaseIndexOffset> Base;
9585 SDValue Chain;
9586
9587 SmallPtrSet<LoadSDNode *, 8> Loads;
9588 std::optional<SDByteProvider> FirstByteProvider;
9589 int64_t FirstOffset = INT64_MAX;
9590
9591 // Check if all the bytes of the OR we are looking at are loaded from the same
9592  // base address. Collect byte offsets from the Base address in ByteOffsets.
9593 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9594 unsigned ZeroExtendedBytes = 0;
9595 for (int i = ByteWidth - 1; i >= 0; --i) {
9596 auto P =
9597 calculateByteProvider(Op: SDValue(N, 0), Index: i, Depth: 0, /*VectorIndex*/ std::nullopt,
9598 /*StartingIndex*/ i);
9599 if (!P)
9600 return SDValue();
9601
9602 if (P->isConstantZero()) {
9603 // It's OK for the N most significant bytes to be 0, we can just
9604 // zero-extend the load.
9605 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9606 return SDValue();
9607 continue;
9608 }
9609 assert(P->hasSrc() && "provenance should either be memory or zero");
9610 auto *L = cast<LoadSDNode>(Val: P->Src.value());
9611
9612 // All loads must share the same chain
9613 SDValue LChain = L->getChain();
9614 if (!Chain)
9615 Chain = LChain;
9616 else if (Chain != LChain)
9617 return SDValue();
9618
9619 // Loads must share the same base address
9620 BaseIndexOffset Ptr = BaseIndexOffset::match(N: L, DAG);
9621 int64_t ByteOffsetFromBase = 0;
9622
9623 // For vector loads, the expected load combine pattern will have an
9624 // ExtractElement for each index in the vector. While each of these
9625 // ExtractElements will be accessing the same base address as determined
9626 // by the load instruction, the actual bytes they interact with will differ
9627 // due to different ExtractElement indices. To accurately determine the
9628 // byte position of an ExtractElement, we offset the base load ptr with
9629 // the index multiplied by the byte size of each element in the vector.
9630 if (L->getMemoryVT().isVector()) {
9631 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9632 if (LoadWidthInBit % 8 != 0)
9633 return SDValue();
9634 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9635 Ptr.addToOffset(VectorOff: ByteOffsetFromVector);
9636 }
9637
9638 if (!Base)
9639 Base = Ptr;
9641 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9642 return SDValue();
9643
9644 // Calculate the offset of the current byte from the base address
9645 ByteOffsetFromBase += MemoryByteOffset(*P);
9646 ByteOffsets[i] = ByteOffsetFromBase;
9647
9648 // Remember the first byte load
9649 if (ByteOffsetFromBase < FirstOffset) {
9650 FirstByteProvider = P;
9651 FirstOffset = ByteOffsetFromBase;
9652 }
9653
9654 Loads.insert(Ptr: L);
9655 }
9656
9657 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9658 "memory, so there must be at least one load which produces the value");
9659 assert(Base && "Base address of the accessed memory location must be set");
9660 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9661
9662 bool NeedsZext = ZeroExtendedBytes > 0;
9663
9664 EVT MemVT =
9665 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: (ByteWidth - ZeroExtendedBytes) * 8);
9666
9667 if (!MemVT.isSimple())
9668 return SDValue();
9669
9670 // Before legalize we can introduce too wide illegal loads which will be later
9671 // split into legal sized loads. This enables us to combine i64 load by i8
9672 // patterns to a couple of i32 loads on 32 bit targets.
9673 if (LegalOperations &&
9674 !TLI.isLoadExtLegal(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, ValVT: VT,
9675 MemVT))
9676 return SDValue();
9677
9678  // Check if the bytes of the OR we are looking at match either a big- or
9679  // little-endian value load.
9680 std::optional<bool> IsBigEndian = isBigEndian(
9681 ByteOffsets: ArrayRef(ByteOffsets).drop_back(N: ZeroExtendedBytes), FirstOffset);
9682 if (!IsBigEndian)
9683 return SDValue();
9684
9685 assert(FirstByteProvider && "must be set");
9686
9687  // Ensure that the first byte is loaded from offset zero of the first load,
9688  // so that the combined value can be loaded from the first load's address.
9689 if (MemoryByteOffset(*FirstByteProvider) != 0)
9690 return SDValue();
9691 auto *FirstLoad = cast<LoadSDNode>(Val: FirstByteProvider->Src.value());
9692
9693 // The node we are looking at matches with the pattern, check if we can
9694 // replace it with a single (possibly zero-extended) load and bswap + shift if
9695 // needed.
9696
9697  // If the load needs a byte swap, check if the target supports it.
9698 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9699
9700 // Before legalize we can introduce illegal bswaps which will be later
9701 // converted to an explicit bswap sequence. This way we end up with a single
9702 // load and byte shuffling instead of several loads and byte shuffling.
9703 // We do not introduce illegal bswaps when zero-extending as this tends to
9704 // introduce too many arithmetic instructions.
9705 if (NeedsBswap && (LegalOperations || NeedsZext) &&
9706 !TLI.isOperationLegal(Op: ISD::BSWAP, VT))
9707 return SDValue();
9708
9709 // If we need to bswap and zero extend, we have to insert a shift. Check that
9710 // it is legal.
9711 if (NeedsBswap && NeedsZext && LegalOperations &&
9712 !TLI.isOperationLegal(Op: ISD::SHL, VT))
9713 return SDValue();
9714
9715 // Check that a load of the wide type is both allowed and fast on the target
9716 unsigned Fast = 0;
9717 bool Allowed =
9718 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
9719 MMO: *FirstLoad->getMemOperand(), Fast: &Fast);
9720 if (!Allowed || !Fast)
9721 return SDValue();
9722
9723 SDValue NewLoad =
9724 DAG.getExtLoad(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, dl: SDLoc(N), VT,
9725 Chain, Ptr: FirstLoad->getBasePtr(),
9726 PtrInfo: FirstLoad->getPointerInfo(), MemVT, Alignment: FirstLoad->getAlign());
9727
9728 // Transfer chain users from old loads to the new load.
9729 for (LoadSDNode *L : Loads)
9730 DAG.makeEquivalentMemoryOrdering(OldLoad: L, NewMemOp: NewLoad);
9731
9732 if (!NeedsBswap)
9733 return NewLoad;
9734
9735 SDValue ShiftedLoad =
9736 NeedsZext ? DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: NewLoad,
9737 N2: DAG.getShiftAmountConstant(Val: ZeroExtendedBytes * 8,
9738 VT, DL: SDLoc(N)))
9739 : NewLoad;
9740 return DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: ShiftedLoad);
9741}
9742
9743// If the target has andn, bsl, or a similar bit-select instruction,
9744// we want to unfold masked merge, with canonical pattern of:
9745// | A | |B|
9746// ((x ^ y) & m) ^ y
9747// | D |
9748// Into:
9749// (x & m) | (y & ~m)
9750// If y is a constant, m is not a 'not', and the 'andn' does not work with
9751// immediates, we unfold into a different pattern:
9752// ~(~x & m) & (m | y)
9753// If x is a constant, m is a 'not', and the 'andn' does not work with
9754// immediates, we unfold into a different pattern:
9755// (x | ~m) & ~(~m & ~y)
9756// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9757// the very least that breaks andnpd / andnps patterns, and because those
9758// patterns are simplified in IR and shouldn't be created in the DAG
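// Hand-worked illustration of the canonical pattern (values chosen only for
// demonstration): with x = 0b1100, y = 0b1010, m = 0b0110,
//   ((x ^ y) & m) ^ y = (0b0110 & 0b0110) ^ 0b1010 = 0b1100
//   (x & m) | (y & ~m) = 0b0100 | 0b1000 = 0b1100
// i.e. bits of x are selected where m is set and bits of y elsewhere.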
9759SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9760 assert(N->getOpcode() == ISD::XOR);
9761
9762 // Don't touch 'not' (i.e. where y = -1).
9763 if (isAllOnesOrAllOnesSplat(V: N->getOperand(Num: 1)))
9764 return SDValue();
9765
9766 EVT VT = N->getValueType(ResNo: 0);
9767
9768 // There are 3 commutable operators in the pattern,
9769 // so we have to deal with 8 possible variants of the basic pattern.
9770 SDValue X, Y, M;
9771 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9772 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9773 return false;
9774 SDValue Xor = And.getOperand(i: XorIdx);
9775 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9776 return false;
9777 SDValue Xor0 = Xor.getOperand(i: 0);
9778 SDValue Xor1 = Xor.getOperand(i: 1);
9779 // Don't touch 'not' (i.e. where y = -1).
9780 if (isAllOnesOrAllOnesSplat(V: Xor1))
9781 return false;
9782 if (Other == Xor0)
9783 std::swap(a&: Xor0, b&: Xor1);
9784 if (Other != Xor1)
9785 return false;
9786 X = Xor0;
9787 Y = Xor1;
9788 M = And.getOperand(i: XorIdx ? 0 : 1);
9789 return true;
9790 };
9791
9792 SDValue N0 = N->getOperand(Num: 0);
9793 SDValue N1 = N->getOperand(Num: 1);
9794 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9795 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9796 return SDValue();
9797
  // Don't do anything if the mask is constant. This should not be reachable:
  // InstCombine should have already unfolded this pattern, and DAGCombiner
  // probably shouldn't produce it either.
9801 if (isa<ConstantSDNode>(Val: M.getNode()))
9802 return SDValue();
9803
9804 // We can transform if the target has AndNot
9805 if (!TLI.hasAndNot(X: M))
9806 return SDValue();
9807
9808 SDLoc DL(N);
9809
  // If Y is a constant, check that 'andn' works with immediates -- unless M is
  // a bitwise not, which would already allow ANDN to be used.
9812 if (!TLI.hasAndNot(X: Y) && !isBitwiseNot(V: M)) {
9813 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9814 // If not, we need to do a bit more work to make sure andn is still used.
9815 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
9816 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: M);
9817 SDValue NotLHS = DAG.getNOT(DL, Val: LHS, VT);
9818 SDValue RHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: M, N2: Y);
9819 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotLHS, N2: RHS);
9820 }
9821
9822 // If X is a constant and M is a bitwise not, check that 'andn' works with
9823 // immediates.
9824 if (!TLI.hasAndNot(X) && isBitwiseNot(V: M)) {
9825 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9826 // If not, we need to do a bit more work to make sure andn is still used.
9827 SDValue NotM = M.getOperand(i: 0);
9828 SDValue LHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: NotM);
9829 SDValue NotY = DAG.getNOT(DL, Val: Y, VT);
9830 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotM, N2: NotY);
9831 SDValue NotRHS = DAG.getNOT(DL, Val: RHS, VT);
9832 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: LHS, N2: NotRHS);
9833 }
9834
9835 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: M);
9836 SDValue NotM = DAG.getNOT(DL, Val: M, VT);
9837 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Y, N2: NotM);
9838
9839 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHS, N2: RHS);
9840}
9841
9842SDValue DAGCombiner::visitXOR(SDNode *N) {
9843 SDValue N0 = N->getOperand(Num: 0);
9844 SDValue N1 = N->getOperand(Num: 1);
9845 EVT VT = N0.getValueType();
9846 SDLoc DL(N);
9847
9848 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9849 if (N0.isUndef() && N1.isUndef())
9850 return DAG.getConstant(Val: 0, DL, VT);
9851
9852 // fold (xor x, undef) -> undef
9853 if (N0.isUndef())
9854 return N0;
9855 if (N1.isUndef())
9856 return N1;
9857
9858 // fold (xor c1, c2) -> c1^c2
9859 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::XOR, DL, VT, Ops: {N0, N1}))
9860 return C;
9861
9862 // canonicalize constant to RHS
9863 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
9864 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
9865 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
9866
9867 // fold vector ops
9868 if (VT.isVector()) {
9869 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9870 return FoldedVOp;
9871
9872 // fold (xor x, 0) -> x, vector edition
9873 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
9874 return N0;
9875 }
9876
9877 // fold (xor x, 0) -> x
9878 if (isNullConstant(V: N1))
9879 return N0;
9880
9881 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9882 return NewSel;
9883
9884 // reassociate xor
9885 if (SDValue RXOR = reassociateOps(Opc: ISD::XOR, DL, N0, N1, Flags: N->getFlags()))
9886 return RXOR;
9887
9888 // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9889 if (SDValue SD =
9890 reassociateReduction(RedOpc: ISD::VECREDUCE_XOR, Opc: ISD::XOR, DL, VT, N0, N1))
9891 return SD;
9892
9893 // fold (a^b) -> (a|b) iff a and b share no bits.
9894 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
9895 DAG.haveNoCommonBitsSet(A: N0, B: N1))
9896 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags: SDNodeFlags::Disjoint);
9897
9898 // look for 'add-like' folds:
9899 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
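  // (The signed-min constant only has the sign bit set, and flipping the sign
  //  bit with xor is the same as adding it modulo 2^BitWidth; e.g. for i8,
  //  x ^ 0x80 == x + 0x80.)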
9900 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
9901 isMinSignedConstant(V: N1))
9902 if (SDValue Combined = visitADDLike(N))
9903 return Combined;
9904
9905 // fold !(x cc y) -> (x !cc y)
9906 unsigned N0Opcode = N0.getOpcode();
9907 SDValue LHS, RHS, CC;
9908 if (TLI.isConstTrueVal(N: N1) &&
9909 isSetCCEquivalent(N: N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
9910 ISD::CondCode NotCC = ISD::getSetCCInverse(Operation: cast<CondCodeSDNode>(Val&: CC)->get(),
9911 Type: LHS.getValueType());
9912 if (!LegalOperations ||
9913 TLI.isCondCodeLegal(CC: NotCC, VT: LHS.getSimpleValueType())) {
9914 switch (N0Opcode) {
9915 default:
9916 llvm_unreachable("Unhandled SetCC Equivalent!");
9917 case ISD::SETCC:
9918 return DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC);
9919 case ISD::SELECT_CC:
9920 return DAG.getSelectCC(DL: SDLoc(N0), LHS, RHS, True: N0.getOperand(i: 2),
9921 False: N0.getOperand(i: 3), Cond: NotCC);
9922 case ISD::STRICT_FSETCC:
9923 case ISD::STRICT_FSETCCS: {
9924 if (N0.hasOneUse()) {
9925 // FIXME Can we handle multiple uses? Could we token factor the chain
9926 // results from the new/old setcc?
9927 SDValue SetCC =
9928 DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC,
9929 Chain: N0.getOperand(i: 0), IsSignaling: N0Opcode == ISD::STRICT_FSETCCS);
9930 CombineTo(N, Res: SetCC);
9931 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: SetCC.getValue(R: 1));
9932 recursivelyDeleteUnusedNodes(N: N0.getNode());
9933 return SDValue(N, 0); // Return N so it doesn't get rechecked!
9934 }
9935 break;
9936 }
9937 }
9938 }
9939 }
9940
9941 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9942 if (isOneConstant(V: N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9943 isSetCCEquivalent(N: N0.getOperand(i: 0), LHS, RHS, CC)){
9944 SDValue V = N0.getOperand(i: 0);
9945 SDLoc DL0(N0);
9946 V = DAG.getNode(Opcode: ISD::XOR, DL: DL0, VT: V.getValueType(), N1: V,
9947 N2: DAG.getConstant(Val: 1, DL: DL0, VT: V.getValueType()));
9948 AddToWorklist(N: V.getNode());
9949 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: V);
9950 }
9951
9952 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9953 // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are setcc
9954 if (isOneConstant(V: N1) && VT == MVT::i1 && N0.hasOneUse() &&
9955 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9956 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9957 if (isOneUseSetCC(N: N01) || isOneUseSetCC(N: N00)) {
9958 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9959 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9960 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9961 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9962 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9963 }
9964 }
9965 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9966 // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are constants
9967 if (isAllOnesConstant(V: N1) && N0.hasOneUse() &&
9968 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9969 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9970 if (isa<ConstantSDNode>(Val: N01) || isa<ConstantSDNode>(Val: N00)) {
9971 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9972 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9973 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9974 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9975 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9976 }
9977 }
9978
9979 // fold (not (neg x)) -> (add X, -1)
9980 // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9981 // Y is a constant or the subtract has a single use.
9982 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::SUB &&
9983 isNullConstant(V: N0.getOperand(i: 0))) {
9984 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
9985 N2: DAG.getAllOnesConstant(DL, VT));
9986 }
9987
9988 // fold (not (add X, -1)) -> (neg X)
9989 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
9990 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1))) {
9991 return DAG.getNegative(Val: N0.getOperand(i: 0), DL, VT);
9992 }
9993
9994 // fold (xor (and x, y), y) -> (and (not x), y)
9995 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(Num: 1) == N1) {
9996 SDValue X = N0.getOperand(i: 0);
9997 SDValue NotX = DAG.getNOT(DL: SDLoc(X), Val: X, VT);
9998 AddToWorklist(N: NotX.getNode());
9999 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: N1);
10000 }
10001
10002 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
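  // (This is the branch-free abs idiom: Y is 0 for non-negative X and -1 for
  //  negative X, so (X + Y) ^ Y is X in the first case and ~(X - 1) == -X in
  //  the second.)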
10003 if (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) {
10004 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
10005 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
10006 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
10007 SDValue A0 = A.getOperand(i: 0), A1 = A.getOperand(i: 1);
10008 SDValue S0 = S.getOperand(i: 0);
10009 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
10010 if (ConstantSDNode *C = isConstOrConstSplat(N: S.getOperand(i: 1)))
10011 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
10012 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: S0);
10013 }
10014 }
10015
10016 // fold (xor x, x) -> 0
10017 if (N0 == N1)
10018 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
10019
10020 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
10021 // Here is a concrete example of this equivalence:
10022 // i16 x == 14
10023 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
10024 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
10025 //
10026 // =>
10027 //
10028 // i16 ~1 == 0b1111111111111110
10029 // i16 rol(~1, 14) == 0b1011111111111111
10030 //
10031 // Some additional tips to help conceptualize this transform:
10032 // - Try to see the operation as placing a single zero in a value of all ones.
10033 // - There exists no value for x which would allow the result to contain zero.
10034 // - Values of x larger than the bitwidth are undefined and do not require a
10035 // consistent result.
  // - Pushing the zero left requires shifting one-bits in from the right.
10037 // A rotate left of ~1 is a nice way of achieving the desired result.
10038 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
10039 isAllOnesConstant(V: N1) && isOneConstant(V: N0.getOperand(i: 0))) {
10040 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: DAG.getSignedConstant(Val: ~1, DL, VT),
10041 N2: N0.getOperand(i: 1));
10042 }
10043
10044 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
10045 if (N0Opcode == N1.getOpcode())
10046 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
10047 return V;
10048
10049 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
10050 return R;
10051 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
10052 return R;
10053 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
10054 return R;
10055
10056 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
10057 if (SDValue MM = unfoldMaskedMerge(N))
10058 return MM;
10059
10060 // Simplify the expression using non-local knowledge.
10061 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10062 return SDValue(N, 0);
10063
10064 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
10065 return Combined;
10066
10067 return SDValue();
10068}
10069
10070/// If we have a shift-by-constant of a bitwise logic op that itself has a
10071/// shift-by-constant operand with identical opcode, we may be able to convert
10072/// that into 2 independent shifts followed by the logic op. This is a
10073/// throughput improvement.
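/// For example (illustrative only):
///   (shl (xor (shl X, 2), Y), 3) --> (xor (shl X, 5), (shl Y, 3))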
10074static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
10075 // Match a one-use bitwise logic op.
10076 SDValue LogicOp = Shift->getOperand(Num: 0);
10077 if (!LogicOp.hasOneUse())
10078 return SDValue();
10079
10080 unsigned LogicOpcode = LogicOp.getOpcode();
10081 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
10082 LogicOpcode != ISD::XOR)
10083 return SDValue();
10084
10085 // Find a matching one-use shift by constant.
10086 unsigned ShiftOpcode = Shift->getOpcode();
10087 SDValue C1 = Shift->getOperand(Num: 1);
10088 ConstantSDNode *C1Node = isConstOrConstSplat(N: C1);
10089 assert(C1Node && "Expected a shift with constant operand");
10090 const APInt &C1Val = C1Node->getAPIntValue();
10091 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
10092 const APInt *&ShiftAmtVal) {
10093 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
10094 return false;
10095
10096 ConstantSDNode *ShiftCNode = isConstOrConstSplat(N: V.getOperand(i: 1));
10097 if (!ShiftCNode)
10098 return false;
10099
10100 // Capture the shifted operand and shift amount value.
10101 ShiftOp = V.getOperand(i: 0);
10102 ShiftAmtVal = &ShiftCNode->getAPIntValue();
10103
10104 // Shift amount types do not have to match their operand type, so check that
10105 // the constants are the same width.
10106 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
10107 return false;
10108
10109 // The fold is not valid if the sum of the shift values doesn't fit in the
10110 // given shift amount type.
10111 bool Overflow = false;
10112 APInt NewShiftAmt = C1Val.uadd_ov(RHS: *ShiftAmtVal, Overflow);
10113 if (Overflow)
10114 return false;
10115
10116 // The fold is not valid if the sum of the shift values exceeds bitwidth.
10117 if (NewShiftAmt.uge(RHS: V.getScalarValueSizeInBits()))
10118 return false;
10119
10120 return true;
10121 };
10122
10123 // Logic ops are commutative, so check each operand for a match.
10124 SDValue X, Y;
10125 const APInt *C0Val;
10126 if (matchFirstShift(LogicOp.getOperand(i: 0), X, C0Val))
10127 Y = LogicOp.getOperand(i: 1);
10128 else if (matchFirstShift(LogicOp.getOperand(i: 1), X, C0Val))
10129 Y = LogicOp.getOperand(i: 0);
10130 else
10131 return SDValue();
10132
10133 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
10134 SDLoc DL(Shift);
10135 EVT VT = Shift->getValueType(ResNo: 0);
10136 EVT ShiftAmtVT = Shift->getOperand(Num: 1).getValueType();
10137 SDValue ShiftSumC = DAG.getConstant(Val: *C0Val + C1Val, DL, VT: ShiftAmtVT);
10138 SDValue NewShift1 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: X, N2: ShiftSumC);
10139 SDValue NewShift2 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: Y, N2: C1);
10140 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift1, N2: NewShift2,
10141 Flags: LogicOp->getFlags());
10142}
10143
10144/// Handle transforms common to the three shifts, when the shift amount is a
10145/// constant.
10146/// We are looking for: (shift being one of shl/sra/srl)
10147/// shift (binop X, C0), C1
10148/// And want to transform into:
10149/// binop (shift X, C1), (shift C0, C1)
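/// For example (illustrative only):
///   (shl (or X, 5), 3) --> (or (shl X, 3), 40)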
10150SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
10151 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
10152
10153 // Do not turn a 'not' into a regular xor.
10154 if (isBitwiseNot(V: N->getOperand(Num: 0)))
10155 return SDValue();
10156
10157 // The inner binop must be one-use, since we want to replace it.
10158 SDValue LHS = N->getOperand(Num: 0);
10159 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
10160 return SDValue();
10161
10162 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
10163 if (SDValue R = combineShiftOfShiftedLogic(Shift: N, DAG))
10164 return R;
10165
10166 // We want to pull some binops through shifts, so that we have (and (shift))
10167 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
10168 // thing happens with address calculations, so it's important to canonicalize
10169 // it.
10170 switch (LHS.getOpcode()) {
10171 default:
10172 return SDValue();
10173 case ISD::OR:
10174 case ISD::XOR:
10175 case ISD::AND:
10176 break;
10177 case ISD::ADD:
10178 if (N->getOpcode() != ISD::SHL)
10179 return SDValue(); // only shl(add) not sr[al](add).
10180 break;
10181 }
10182
  // FIXME: disable this unless the input to the binop is a shift by a constant
  // or is copy/select. Enable this in other cases once we have figured out
  // when it is actually profitable.
10186 SDValue BinOpLHSVal = LHS.getOperand(i: 0);
10187 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
10188 BinOpLHSVal.getOpcode() == ISD::SRA ||
10189 BinOpLHSVal.getOpcode() == ISD::SRL) &&
10190 isa<ConstantSDNode>(Val: BinOpLHSVal.getOperand(i: 1));
10191 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
10192 BinOpLHSVal.getOpcode() == ISD::SELECT;
10193
10194 if (!IsShiftByConstant && !IsCopyOrSelect)
10195 return SDValue();
10196
10197 if (IsCopyOrSelect && N->hasOneUse())
10198 return SDValue();
10199
10200 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
10201 SDLoc DL(N);
10202 EVT VT = N->getValueType(ResNo: 0);
10203 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
10204 Opcode: N->getOpcode(), DL, VT, Ops: {LHS.getOperand(i: 1), N->getOperand(Num: 1)})) {
10205 SDValue NewShift = DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: LHS.getOperand(i: 0),
10206 N2: N->getOperand(Num: 1));
10207 return DAG.getNode(Opcode: LHS.getOpcode(), DL, VT, N1: NewShift, N2: NewRHS);
10208 }
10209
10210 return SDValue();
10211}
10212
10213SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
10214 assert(N->getOpcode() == ISD::TRUNCATE);
10215 assert(N->getOperand(0).getOpcode() == ISD::AND);
10216
10217 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
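  // For example (illustrative only), with TruncVT == i32:
  //   (truncate (and X:i64, 255)) -> (and (truncate X), 255)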
10218 EVT TruncVT = N->getValueType(ResNo: 0);
10219 if (N->hasOneUse() && N->getOperand(Num: 0).hasOneUse() &&
10220 TLI.isTypeDesirableForOp(ISD::AND, VT: TruncVT)) {
10221 SDValue N01 = N->getOperand(Num: 0).getOperand(i: 1);
10222 if (isConstantOrConstantVector(N: N01, /* NoOpaques */ true)) {
10223 SDLoc DL(N);
10224 SDValue N00 = N->getOperand(Num: 0).getOperand(i: 0);
10225 SDValue Trunc00 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N00);
10226 SDValue Trunc01 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N01);
10227 AddToWorklist(N: Trunc00.getNode());
10228 AddToWorklist(N: Trunc01.getNode());
10229 return DAG.getNode(Opcode: ISD::AND, DL, VT: TruncVT, N1: Trunc00, N2: Trunc01);
10230 }
10231 }
10232
10233 return SDValue();
10234}
10235
10236SDValue DAGCombiner::visitRotate(SDNode *N) {
10237 SDLoc dl(N);
10238 SDValue N0 = N->getOperand(Num: 0);
10239 SDValue N1 = N->getOperand(Num: 1);
10240 EVT VT = N->getValueType(ResNo: 0);
10241 unsigned Bitsize = VT.getScalarSizeInBits();
10242
10243 // fold (rot x, 0) -> x
10244 if (isNullOrNullSplat(V: N1))
10245 return N0;
10246
10247 // fold (rot x, c) -> x iff (c % BitSize) == 0
10248 if (isPowerOf2_32(Value: Bitsize) && Bitsize > 1) {
10249 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
10250 if (DAG.MaskedValueIsZero(Op: N1, Mask: ModuloMask))
10251 return N0;
10252 }
10253
10254 // fold (rot x, c) -> (rot x, c % BitSize)
10255 bool OutOfRange = false;
10256 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
10257 OutOfRange |= C->getAPIntValue().uge(RHS: Bitsize);
10258 return true;
10259 };
10260 if (ISD::matchUnaryPredicate(Op: N1, Match: MatchOutOfRange) && OutOfRange) {
10261 EVT AmtVT = N1.getValueType();
10262 SDValue Bits = DAG.getConstant(Val: Bitsize, DL: dl, VT: AmtVT);
10263 if (SDValue Amt =
10264 DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: AmtVT, Ops: {N1, Bits}))
10265 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: Amt);
10266 }
10267
10268 // rot i16 X, 8 --> bswap X
10269 auto *RotAmtC = isConstOrConstSplat(N: N1);
10270 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
10271 VT.getScalarSizeInBits() == 16 && hasOperation(Opcode: ISD::BSWAP, VT))
10272 return DAG.getNode(Opcode: ISD::BSWAP, DL: dl, VT, Operand: N0);
10273
10274 // Simplify the operands using demanded-bits information.
10275 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10276 return SDValue(N, 0);
10277
10278 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
10279 if (N1.getOpcode() == ISD::TRUNCATE &&
10280 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10281 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10282 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: NewOp1);
10283 }
10284
10285 unsigned NextOp = N0.getOpcode();
10286
10287 // fold (rot* (rot* x, c2), c1)
10288 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
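  // e.g. for i8 (illustrative): (rotl (rotr x, 3), 5)
  //   -> (rotl x, ((5 % 8) - (3 % 8) + 8) % 8) == (rotl x, 2)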
10289 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
10290 bool C1 = DAG.isConstantIntBuildVectorOrConstantInt(N: N1);
10291 bool C2 = DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1));
10292 if (C1 && C2 && N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
10293 EVT ShiftVT = N1.getValueType();
10294 bool SameSide = (N->getOpcode() == NextOp);
10295 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
10296 SDValue BitsizeC = DAG.getConstant(Val: Bitsize, DL: dl, VT: ShiftVT);
10297 SDValue Norm1 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
10298 Ops: {N1, BitsizeC});
10299 SDValue Norm2 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
10300 Ops: {N0.getOperand(i: 1), BitsizeC});
10301 if (Norm1 && Norm2)
10302 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
10303 Opcode: CombineOp, DL: dl, VT: ShiftVT, Ops: {Norm1, Norm2})) {
10304 CombinedShift = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL: dl, VT: ShiftVT,
10305 Ops: {CombinedShift, BitsizeC});
10306 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
10307 Opcode: ISD::UREM, DL: dl, VT: ShiftVT, Ops: {CombinedShift, BitsizeC});
10308 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0->getOperand(Num: 0),
10309 N2: CombinedShiftNorm);
10310 }
10311 }
10312 }
10313 return SDValue();
10314}
10315
10316SDValue DAGCombiner::visitSHL(SDNode *N) {
10317 SDValue N0 = N->getOperand(Num: 0);
10318 SDValue N1 = N->getOperand(Num: 1);
10319 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10320 return V;
10321
10322 SDLoc DL(N);
10323 EVT VT = N0.getValueType();
10324 EVT ShiftVT = N1.getValueType();
10325 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10326
10327 // fold (shl c1, c2) -> c1<<c2
10328 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N0, N1}))
10329 return C;
10330
10331 // fold vector ops
10332 if (VT.isVector()) {
10333 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10334 return FoldedVOp;
10335
10336 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(Val&: N1);
10337 // If setcc produces all-one true value then:
10338 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
10339 if (N1CV && N1CV->isConstant()) {
10340 if (N0.getOpcode() == ISD::AND) {
10341 SDValue N00 = N0->getOperand(Num: 0);
10342 SDValue N01 = N0->getOperand(Num: 1);
10343 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(Val&: N01);
10344
10345 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
10346 TLI.getBooleanContents(Type: N00.getOperand(i: 0).getValueType()) ==
10347 TargetLowering::ZeroOrNegativeOneBooleanContent) {
10348 if (SDValue C =
10349 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N01, N1}))
10350 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N00, N2: C);
10351 }
10352 }
10353 }
10354 }
10355
10356 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10357 return NewSel;
10358
10359 // if (shl x, c) is known to be zero, return 0
10360 if (DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
10361 return DAG.getConstant(Val: 0, DL, VT);
10362
10363 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
10364 if (N1.getOpcode() == ISD::TRUNCATE &&
10365 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10366 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10367 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: NewOp1);
10368 }
10369
10370 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
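  // e.g. for i8 (illustrative): (shl (shl x, 5), 4) -> 0 because 5 + 4 >= 8,
  // while (shl (shl x, 2), 3) -> (shl x, 5).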
10371 if (N0.getOpcode() == ISD::SHL) {
10372 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10373 ConstantSDNode *RHS) {
10374 APInt c1 = LHS->getAPIntValue();
10375 APInt c2 = RHS->getAPIntValue();
10376 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10377 return (c1 + c2).uge(RHS: OpSizeInBits);
10378 };
10379 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
10380 return DAG.getConstant(Val: 0, DL, VT);
10381
10382 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10383 ConstantSDNode *RHS) {
10384 APInt c1 = LHS->getAPIntValue();
10385 APInt c2 = RHS->getAPIntValue();
10386 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10387 return (c1 + c2).ult(RHS: OpSizeInBits);
10388 };
10389 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
10390 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
10391 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
10392 }
10393 }
10394
10395 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
10396 // For this to be valid, the second form must not preserve any of the bits
10397 // that are shifted out by the inner shift in the first form. This means
10398 // the outer shift size must be >= the number of bits added by the ext.
10399 // As a corollary, we don't care what kind of ext it is.
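  // e.g. for i8 extended to i32 (illustrative): (shl (zext (shl x, 3)), 24)
  // -> (shl (zext x), 27), since the outer shift of 24 already pushes out all
  // bits that the inner i8 shift would have discarded.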
10400 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
10401 N0.getOpcode() == ISD::ANY_EXTEND ||
10402 N0.getOpcode() == ISD::SIGN_EXTEND) &&
10403 N0.getOperand(i: 0).getOpcode() == ISD::SHL) {
10404 SDValue N0Op0 = N0.getOperand(i: 0);
10405 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
10406 EVT InnerVT = N0Op0.getValueType();
10407 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
10408
10409 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10410 ConstantSDNode *RHS) {
10411 APInt c1 = LHS->getAPIntValue();
10412 APInt c2 = RHS->getAPIntValue();
10413 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10414 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
10415 (c1 + c2).uge(RHS: OpSizeInBits);
10416 };
10417 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchOutOfRange,
10418 /*AllowUndefs*/ false,
10419 /*AllowTypeMismatch*/ true))
10420 return DAG.getConstant(Val: 0, DL, VT);
10421
10422 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10423 ConstantSDNode *RHS) {
10424 APInt c1 = LHS->getAPIntValue();
10425 APInt c2 = RHS->getAPIntValue();
10426 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10427 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
10428 (c1 + c2).ult(RHS: OpSizeInBits);
10429 };
10430 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchInRange,
10431 /*AllowUndefs*/ false,
10432 /*AllowTypeMismatch*/ true)) {
10433 SDValue Ext = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0Op0.getOperand(i: 0));
10434 SDValue Sum = DAG.getZExtOrTrunc(Op: InnerShiftAmt, DL, VT: ShiftVT);
10435 Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1: Sum, N2: N1);
10436 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Ext, N2: Sum);
10437 }
10438 }
10439
10440 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
10441 // Only fold this if the inner zext has no other uses to avoid increasing
10442 // the total number of instructions.
10443 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10444 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
10445 SDValue N0Op0 = N0.getOperand(i: 0);
10446 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
10447
10448 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10449 APInt c1 = LHS->getAPIntValue();
10450 APInt c2 = RHS->getAPIntValue();
10451 zeroExtendToMatch(LHS&: c1, RHS&: c2);
10452 return c1.ult(RHS: VT.getScalarSizeInBits()) && (c1 == c2);
10453 };
10454 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchEqual,
10455 /*AllowUndefs*/ false,
10456 /*AllowTypeMismatch*/ true)) {
10457 EVT InnerShiftAmtVT = N0Op0.getOperand(i: 1).getValueType();
10458 SDValue NewSHL = DAG.getZExtOrTrunc(Op: N1, DL, VT: InnerShiftAmtVT);
10459 NewSHL = DAG.getNode(Opcode: ISD::SHL, DL, VT: N0Op0.getValueType(), N1: N0Op0, N2: NewSHL);
10460 AddToWorklist(N: NewSHL.getNode());
10461 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N0), VT, Operand: NewSHL);
10462 }
10463 }
10464
10465 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
10466 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10467 ConstantSDNode *RHS) {
10468 const APInt &LHSC = LHS->getAPIntValue();
10469 const APInt &RHSC = RHS->getAPIntValue();
10470 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
10471 LHSC.getZExtValue() <= RHSC.getZExtValue();
10472 };
10473
10474 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
10475 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10476 if (N0->getFlags().hasExact()) {
10477 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10478 /*AllowUndefs*/ false,
10479 /*AllowTypeMismatch*/ true)) {
10480 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10481 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10482 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10483 }
10484 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10485 /*AllowUndefs*/ false,
10486 /*AllowTypeMismatch*/ true)) {
10487 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10488 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10489 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10490 }
10491 }
10492
    // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
    //                               (and (srl x, (sub c1, c2)), MASK)
10495 // Only fold this if the inner shift has no other uses -- if it does,
10496 // folding this will increase the total number of instructions.
10497 if (N0.getOpcode() == ISD::SRL &&
10498 (N0.getOperand(i: 1) == N1 || N0.hasOneUse()) &&
10499 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10500 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10501 /*AllowUndefs*/ false,
10502 /*AllowTypeMismatch*/ true)) {
10503 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10504 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10505 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10506 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N01);
10507 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: Diff);
10508 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10509 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10510 }
10511 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10512 /*AllowUndefs*/ false,
10513 /*AllowTypeMismatch*/ true)) {
10514 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10515 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10516 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10517 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N1);
10518 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10519 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10520 }
10521 }
10522 }
10523
10524 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10525 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(i: 1) &&
10526 isConstantOrConstantVector(N: N1, /* No Opaques */ NoOpaques: true)) {
10527 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10528 SDValue HiBitsMask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllBits, N2: N1);
10529 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: HiBitsMask);
10530 }
10531
10532 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10533 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
  // This is a variant of the fold done on multiply, except that a mul by a
  // power of 2 is turned into a shift.
10536 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10537 TLI.isDesirableToCommuteWithShift(N, Level)) {
10538 SDValue N01 = N0.getOperand(i: 1);
10539 if (SDValue Shl1 =
10540 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1})) {
10541 SDValue Shl0 = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
10542 AddToWorklist(N: Shl0.getNode());
10543 SDNodeFlags Flags;
10544 // Preserve the disjoint flag for Or.
10545 if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10546 Flags |= SDNodeFlags::Disjoint;
10547 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: Shl0, N2: Shl1, Flags);
10548 }
10549 }
10550
10551 // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10552 // TODO: Add zext/add_nuw variant with suitable test coverage
10553 // TODO: Should we limit this with isLegalAddImmediate?
10554 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10555 N0.getOperand(i: 0).getOpcode() == ISD::ADD &&
10556 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap() &&
10557 TLI.isDesirableToCommuteWithShift(N, Level)) {
10558 SDValue Add = N0.getOperand(i: 0);
10559 SDLoc DL(N0);
10560 if (SDValue ExtC = DAG.FoldConstantArithmetic(Opcode: N0.getOpcode(), DL, VT,
10561 Ops: {Add.getOperand(i: 1)})) {
10562 if (SDValue ShlC =
10563 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {ExtC, N1})) {
10564 SDValue ExtX = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: Add.getOperand(i: 0));
10565 SDValue ShlX = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ExtX, N2: N1);
10566 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ShlX, N2: ShlC);
10567 }
10568 }
10569 }
10570
10571 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10572 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10573 SDValue N01 = N0.getOperand(i: 1);
10574 if (SDValue Shl =
10575 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1}))
10576 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: Shl);
10577 }
10578
10579 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10580 if (N1C && !N1C->isOpaque())
10581 if (SDValue NewSHL = visitShiftByConstant(N))
10582 return NewSHL;
10583
10584 // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10585 // target.
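  // (Y & -Y isolates the lowest set bit of Y, i.e. it equals 1 << cttz(Y) for
  //  any nonzero Y; e.g. Y = 12 (0b1100) gives Y & -Y = 4 with cttz(Y) = 2.)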
10586 if (((N1.getOpcode() == ISD::CTTZ &&
10587 VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
10588 N1.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
10589 N1.hasOneUse() && !TLI.isOperationLegalOrCustom(Op: ISD::CTTZ, VT: ShiftVT) &&
10590 TLI.isOperationLegalOrCustom(Op: ISD::MUL, VT)) {
10591 SDValue Y = N1.getOperand(i: 0);
10592 SDLoc DL(N);
10593 SDValue NegY = DAG.getNegative(Val: Y, DL, VT: ShiftVT);
10594 SDValue And =
10595 DAG.getZExtOrTrunc(Op: DAG.getNode(Opcode: ISD::AND, DL, VT: ShiftVT, N1: Y, N2: NegY), DL, VT);
10596 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: And, N2: N0);
10597 }
10598
10599 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10600 return SDValue(N, 0);
10601
10602 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10603 if (N0.getOpcode() == ISD::VSCALE && N1C) {
10604 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10605 const APInt &C1 = N1C->getAPIntValue();
10606 return DAG.getVScale(DL, VT, MulImm: C0 << C1);
10607 }
10608
10609 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10610 APInt ShlVal;
10611 if (N0.getOpcode() == ISD::STEP_VECTOR &&
10612 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ShlVal)) {
10613 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10614 if (ShlVal.ult(RHS: C0.getBitWidth())) {
10615 APInt NewStep = C0 << ShlVal;
10616 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
10617 }
10618 }
10619
10620 return SDValue();
10621}
10622
10623// Transform a right shift of a multiply into a multiply-high.
10624// Examples:
// (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
// (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
10627static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
10628 const TargetLowering &TLI) {
10629 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10630 "SRL or SRA node is required here!");
10631
10632 // Check the shift amount. Proceed with the transformation if the shift
10633 // amount is constant.
10634 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N: N->getOperand(Num: 1));
10635 if (!ShiftAmtSrc)
10636 return SDValue();
10637
10638 // The operation feeding into the shift must be a multiply.
10639 SDValue ShiftOperand = N->getOperand(Num: 0);
10640 if (ShiftOperand.getOpcode() != ISD::MUL)
10641 return SDValue();
10642
10643 // Both operands must be equivalent extend nodes.
10644 SDValue LeftOp = ShiftOperand.getOperand(i: 0);
10645 SDValue RightOp = ShiftOperand.getOperand(i: 1);
10646
10647 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10648 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10649
10650 if (!IsSignExt && !IsZeroExt)
10651 return SDValue();
10652
10653 EVT NarrowVT = LeftOp.getOperand(i: 0).getValueType();
10654 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10655
10656 // return true if U may use the lower bits of its operands
10657 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10658 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10659 return true;
10660 }
10661 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(N: U->getOperand(Num: 1));
10662 if (!UShiftAmtSrc) {
10663 return true;
10664 }
10665 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10666 return UShiftAmt < NarrowVTSize;
10667 };
10668
  // If the lower part of the MUL is also used and MUL_LOHI is supported, do
  // not introduce the MULH in favor of MUL_LOHI.
10671 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10672 if (!ShiftOperand.hasOneUse() &&
10673 TLI.isOperationLegalOrCustom(Op: MulLoHiOp, VT: NarrowVT) &&
10674 llvm::any_of(Range: ShiftOperand->users(), P: UserOfLowerBits)) {
10675 return SDValue();
10676 }
10677
10678 SDValue MulhRightOp;
10679 if (ConstantSDNode *Constant = isConstOrConstSplat(N: RightOp)) {
10680 unsigned ActiveBits = IsSignExt
10681 ? Constant->getAPIntValue().getSignificantBits()
10682 : Constant->getAPIntValue().getActiveBits();
10683 if (ActiveBits > NarrowVTSize)
10684 return SDValue();
10685 MulhRightOp = DAG.getConstant(
10686 Val: Constant->getAPIntValue().trunc(width: NarrowVT.getScalarSizeInBits()), DL,
10687 VT: NarrowVT);
10688 } else {
10689 if (LeftOp.getOpcode() != RightOp.getOpcode())
10690 return SDValue();
10691 // Check that the two extend nodes are the same type.
10692 if (NarrowVT != RightOp.getOperand(i: 0).getValueType())
10693 return SDValue();
10694 MulhRightOp = RightOp.getOperand(i: 0);
10695 }
10696
10697 EVT WideVT = LeftOp.getValueType();
10698 // Proceed with the transformation if the wide types match.
10699 assert((WideVT == RightOp.getValueType()) &&
10700 "Cannot have a multiply node with two different operand types.");
10701
10702 // Proceed with the transformation if the wide type is twice as large
10703 // as the narrow type.
10704 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10705 return SDValue();
10706
10707 // Check the shift amount with the narrow type size.
10708 // Proceed with the transformation if the shift amount is the width
10709 // of the narrow type.
10710 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10711 if (ShiftAmt != NarrowVTSize)
10712 return SDValue();
10713
  // If the operation feeding into the MUL is a sign extend (sext), we use
  // mulhs. Otherwise, zero extends (zext) use mulhu.
10716 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10717
  // Combine to mulh if mulh is legal/custom for the narrow type on the target,
  // or, if it is a vector type, if we can transform to an acceptable type and
  // rely on legalization to split/combine the result.
10721 if (NarrowVT.isVector()) {
10722 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: NarrowVT);
10723 if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10724 !TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: TransformVT))
10725 return SDValue();
10726 } else {
10727 if (!TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: NarrowVT))
10728 return SDValue();
10729 }
10730
10731 SDValue Result =
10732 DAG.getNode(Opcode: MulhOpcode, DL, VT: NarrowVT, N1: LeftOp.getOperand(i: 0), N2: MulhRightOp);
10733 bool IsSigned = N->getOpcode() == ISD::SRA;
10734 return DAG.getExtOrTrunc(IsSigned, Op: Result, DL, VT: WideVT);
10735}
10736
// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
// This helper function accepts SDNodes with opcode ISD::BSWAP or
// ISD::BITREVERSE.
10739static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10740 unsigned Opcode = N->getOpcode();
10741 if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10742 return SDValue();
10743
10744 SDValue N0 = N->getOperand(Num: 0);
10745 EVT VT = N->getValueType(ResNo: 0);
10746 SDLoc DL(N);
10747 SDValue X, Y;
10748
10749 // If both operands are bswap/bitreverse, ignore the multiuse
10750 if (sd_match(N: N0, P: m_OneUse(P: m_BitwiseLogic(L: m_UnaryOp(Opc: Opcode, Op: m_Value(N&: X)),
10751 R: m_UnaryOp(Opc: Opcode, Op: m_Value(N&: Y))))))
10752 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: X, N2: Y);
10753
10754 // Otherwise need to ensure logic_op and bswap/bitreverse(x) have one use.
10755 if (sd_match(N: N0, P: m_OneUse(P: m_BitwiseLogic(
10756 L: m_OneUse(P: m_UnaryOp(Opc: Opcode, Op: m_Value(N&: X))), R: m_Value(N&: Y))))) {
10757 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: Y);
10758 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: X, N2: NewBitReorder);
10759 }
10760
10761 return SDValue();
10762}
10763
10764SDValue DAGCombiner::visitSRA(SDNode *N) {
10765 SDValue N0 = N->getOperand(Num: 0);
10766 SDValue N1 = N->getOperand(Num: 1);
10767 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10768 return V;
10769
10770 SDLoc DL(N);
10771 EVT VT = N0.getValueType();
10772 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10773
  // fold (sra c1, c2) -> c1 >>s c2
10775 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRA, DL, VT, Ops: {N0, N1}))
10776 return C;
10777
10778 // Arithmetic shifting an all-sign-bit value is a no-op.
10779 // fold (sra 0, x) -> 0
10780 // fold (sra -1, x) -> -1
10781 if (DAG.ComputeNumSignBits(Op: N0) == OpSizeInBits)
10782 return N0;
10783
10784 // fold vector ops
10785 if (VT.isVector())
10786 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10787 return FoldedVOp;
10788
10789 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10790 return NewSel;
10791
10792 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10793
10794 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10795 // clamp (add c1, c2) to max shift.
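  // e.g. for i8 (illustrative): (sra (sra x, 3), 6) sums to 9 and is clamped
  // to 7, since an arithmetic shift saturates once only the sign bit remains.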
10796 if (N0.getOpcode() == ISD::SRA) {
10797 EVT ShiftVT = N1.getValueType();
10798 EVT ShiftSVT = ShiftVT.getScalarType();
10799 SmallVector<SDValue, 16> ShiftValues;
10800
10801 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10802 APInt c1 = LHS->getAPIntValue();
10803 APInt c2 = RHS->getAPIntValue();
10804 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10805 APInt Sum = c1 + c2;
10806 unsigned ShiftSum =
10807 Sum.uge(RHS: OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10808 ShiftValues.push_back(Elt: DAG.getConstant(Val: ShiftSum, DL, VT: ShiftSVT));
10809 return true;
10810 };
10811 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: SumOfShifts)) {
10812 SDValue ShiftValue;
10813 if (N1.getOpcode() == ISD::BUILD_VECTOR)
10814 ShiftValue = DAG.getBuildVector(VT: ShiftVT, DL, Ops: ShiftValues);
10815 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10816 assert(ShiftValues.size() == 1 &&
10817 "Expected matchBinaryPredicate to return one element for "
10818 "SPLAT_VECTORs");
10819 ShiftValue = DAG.getSplatVector(VT: ShiftVT, DL, Op: ShiftValues[0]);
10820 } else
10821 ShiftValue = ShiftValues[0];
10822 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0), N2: ShiftValue);
10823 }
10824 }
10825
  // fold (sra (shl X, m), (sub result_size, n))
  // -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
  //   result_size - n != m.
  // If truncate is free for the target, sext(shl) is likely to result in
  // better code.
10831 if (N0.getOpcode() == ISD::SHL && N1C) {
10832 // Get the two constants of the shifts, CN0 = m, CN = n.
10833 const ConstantSDNode *N01C = isConstOrConstSplat(N: N0.getOperand(i: 1));
10834 if (N01C) {
10835 LLVMContext &Ctx = *DAG.getContext();
10836 // Determine what the truncate's result bitsize and type would be.
10837 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - N1C->getZExtValue());
10838
10839 if (VT.isVector())
10840 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10841
10842 // Determine the residual right-shift amount.
10843 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10844
      // If the shift is not a no-op (in which case this should be just a sign
      // extend already), the truncated-to type is legal, sign_extend is legal
      // on that type, and the truncate to that type is both legal and free,
      // then perform the transform.
10849 if ((ShiftAmt > 0) &&
10850 TLI.isOperationLegalOrCustom(Op: ISD::SIGN_EXTEND, VT: TruncVT) &&
10851 TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT) &&
10852 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10853 SDValue Amt = DAG.getShiftAmountConstant(Val: ShiftAmt, VT, DL);
10854 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT,
10855 N1: N0.getOperand(i: 0), N2: Amt);
10856 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT,
10857 Operand: Shift);
10858 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL,
10859 VT: N->getValueType(ResNo: 0), Operand: Trunc);
10860 }
10861 }
10862 }
10863
10864 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10865 // sra (add (shl X, N1C), AddC), N1C -->
10866 // sext (add (trunc X to (width - N1C)), AddC')
10867 // sra (sub AddC, (shl X, N1C)), N1C -->
10868 // sext (sub AddC1',(trunc X to (width - N1C)))
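  // e.g. for i32 with N1C == 16 (illustrative):
  //   sra (add (shl X, 16), 0x10000), 16 --> sext (add (trunc X to i16), 1)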
10869 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10870 N0.hasOneUse()) {
10871 bool IsAdd = N0.getOpcode() == ISD::ADD;
10872 SDValue Shl = N0.getOperand(i: IsAdd ? 0 : 1);
10873 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(i: 1) == N1 &&
10874 Shl.hasOneUse()) {
10875 // TODO: AddC does not need to be a splat.
10876 if (ConstantSDNode *AddC =
10877 isConstOrConstSplat(N: N0.getOperand(i: IsAdd ? 1 : 0))) {
10878 // Determine what the truncate's type would be and ask the target if
10879 // that is a free operation.
10880 LLVMContext &Ctx = *DAG.getContext();
10881 unsigned ShiftAmt = N1C->getZExtValue();
10882 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - ShiftAmt);
10883 if (VT.isVector())
10884 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10885
10886 // TODO: The simple type check probably belongs in the default hook
10887 // implementation and/or target-specific overrides (because
10888 // non-simple types likely require masking when legalized), but
10889 // that restriction may conflict with other transforms.
10890 if (TruncVT.isSimple() && isTypeLegal(VT: TruncVT) &&
10891 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10892 SDValue Trunc = DAG.getZExtOrTrunc(Op: Shl.getOperand(i: 0), DL, VT: TruncVT);
10893 SDValue ShiftC =
10894 DAG.getConstant(Val: AddC->getAPIntValue().lshr(shiftAmt: ShiftAmt).trunc(
10895 width: TruncVT.getScalarSizeInBits()),
10896 DL, VT: TruncVT);
10897 SDValue Add;
10898 if (IsAdd)
10899 Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: TruncVT, N1: Trunc, N2: ShiftC);
10900 else
10901 Add = DAG.getNode(Opcode: ISD::SUB, DL, VT: TruncVT, N1: ShiftC, N2: Trunc);
10902 return DAG.getSExtOrTrunc(Op: Add, DL, VT);
10903 }
10904 }
10905 }
10906 }
10907
10908 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10909 if (N1.getOpcode() == ISD::TRUNCATE &&
10910 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10911 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10912 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0, N2: NewOp1);
10913 }
10914
10915 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10916 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10917 // if c1 is equal to the number of bits the trunc removes
10918 // TODO - support non-uniform vector shift amounts.
10919 if (N0.getOpcode() == ISD::TRUNCATE &&
10920 (N0.getOperand(i: 0).getOpcode() == ISD::SRL ||
10921 N0.getOperand(i: 0).getOpcode() == ISD::SRA) &&
10922 N0.getOperand(i: 0).hasOneUse() &&
10923 N0.getOperand(i: 0).getOperand(i: 1).hasOneUse() && N1C) {
10924 SDValue N0Op0 = N0.getOperand(i: 0);
10925 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N: N0Op0.getOperand(i: 1))) {
10926 EVT LargeVT = N0Op0.getValueType();
10927 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10928 if (LargeShift->getAPIntValue() == TruncBits) {
10929 EVT LargeShiftVT = getShiftAmountTy(LHSTy: LargeVT);
10930 SDValue Amt = DAG.getZExtOrTrunc(Op: N1, DL, VT: LargeShiftVT);
10931 Amt = DAG.getNode(Opcode: ISD::ADD, DL, VT: LargeShiftVT, N1: Amt,
10932 N2: DAG.getConstant(Val: TruncBits, DL, VT: LargeShiftVT));
10933 SDValue SRA =
10934 DAG.getNode(Opcode: ISD::SRA, DL, VT: LargeVT, N1: N0Op0.getOperand(i: 0), N2: Amt);
10935 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SRA);
10936 }
10937 }
10938 }
10939
10940 // Simplify, based on bits shifted out of the LHS.
10941 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10942 return SDValue(N, 0);
10943
10944 // If the sign bit is known to be zero, switch this to a SRL.
10945 if (DAG.SignBitIsZero(Op: N0))
10946 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: N1);
10947
10948 if (N1C && !N1C->isOpaque())
10949 if (SDValue NewSRA = visitShiftByConstant(N))
10950 return NewSRA;
10951
10952 // Try to transform this shift into a multiply-high if
10953 // it matches the appropriate pattern detected in combineShiftToMULH.
10954 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10955 return MULH;
10956
10957 // Attempt to convert a sra of a load into a narrower sign-extending load.
10958 if (SDValue NarrowLoad = reduceLoadWidth(N))
10959 return NarrowLoad;
10960
10961 if (SDValue AVG = foldShiftToAvg(N))
10962 return AVG;
10963
10964 return SDValue();
10965}
10966
10967SDValue DAGCombiner::visitSRL(SDNode *N) {
10968 SDValue N0 = N->getOperand(Num: 0);
10969 SDValue N1 = N->getOperand(Num: 1);
10970 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10971 return V;
10972
10973 SDLoc DL(N);
10974 EVT VT = N0.getValueType();
10975 EVT ShiftVT = N1.getValueType();
10976 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10977
10978 // fold (srl c1, c2) -> c1 >>u c2
10979 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRL, DL, VT, Ops: {N0, N1}))
10980 return C;
10981
10982 // fold vector ops
10983 if (VT.isVector())
10984 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10985 return FoldedVOp;
10986
10987 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10988 return NewSel;
10989
10990 // if (srl x, c) is known to be zero, return 0
10991 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10992 if (N1C &&
10993 DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
10994 return DAG.getConstant(Val: 0, DL, VT);
10995
10996 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
10997 if (N0.getOpcode() == ISD::SRL) {
10998 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10999 ConstantSDNode *RHS) {
11000 APInt c1 = LHS->getAPIntValue();
11001 APInt c2 = RHS->getAPIntValue();
11002 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
11003 return (c1 + c2).uge(RHS: OpSizeInBits);
11004 };
11005 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
11006 return DAG.getConstant(Val: 0, DL, VT);
11007
11008 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
11009 ConstantSDNode *RHS) {
11010 APInt c1 = LHS->getAPIntValue();
11011 APInt c2 = RHS->getAPIntValue();
11012 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
11013 return (c1 + c2).ult(RHS: OpSizeInBits);
11014 };
11015 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
11016 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
11017 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
11018 }
11019 }
11020
11021 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
11022 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
11023 SDValue InnerShift = N0.getOperand(i: 0);
11024 // TODO - support non-uniform vector shift amounts.
11025 if (auto *N001C = isConstOrConstSplat(N: InnerShift.getOperand(i: 1))) {
11026 uint64_t c1 = N001C->getZExtValue();
11027 uint64_t c2 = N1C->getZExtValue();
11028 EVT InnerShiftVT = InnerShift.getValueType();
11029 EVT ShiftAmtVT = InnerShift.getOperand(i: 1).getValueType();
11030 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
11031 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
11032 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
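      // e.g. (illustrative) with x: i64, c1 == 32, truncation to i32, c2 == 8:
      //   srl (trunc (srl x, 32)), 8 --> trunc (srl x, 40)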
11033 if (c1 + OpSizeInBits == InnerShiftSize) {
11034 if (c1 + c2 >= InnerShiftSize)
11035 return DAG.getConstant(Val: 0, DL, VT);
11036 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
11037 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
11038 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
11039 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewShift);
11040 }
11041 // In the more general case, we can clear the high bits after the shift:
11042 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
11043 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
11044 c1 + c2 < InnerShiftSize) {
11045 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
11046 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
11047 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
11048 SDValue Mask = DAG.getConstant(Val: APInt::getLowBitsSet(numBits: InnerShiftSize,
11049 loBitsSet: OpSizeInBits - c2),
11050 DL, VT: InnerShiftVT);
11051 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: InnerShiftVT, N1: NewShift, N2: Mask);
11052 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: And);
11053 }
11054 }
11055 }
11056
  // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
  //                               (and (srl x, (sub c2, c1)), MASK)
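  // Illustrative example (i8): (srl (shl x, 5), 3) -> (and (shl x, 2), 0x1C),
  // where the mask is ((-1 >>u 5) << 2), and
  // (srl (shl x, 3), 5) -> (and (srl x, 2), (-1 >>u 5)).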
11059 if (N0.getOpcode() == ISD::SHL &&
11060 (N0.getOperand(i: 1) == N1 || N0->hasOneUse()) &&
11061 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
11062 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
11063 ConstantSDNode *RHS) {
11064 const APInt &LHSC = LHS->getAPIntValue();
11065 const APInt &RHSC = RHS->getAPIntValue();
11066 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
11067 LHSC.getZExtValue() <= RHSC.getZExtValue();
11068 };
11069 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
11070 /*AllowUndefs*/ false,
11071 /*AllowTypeMismatch*/ true)) {
11072 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
11073 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
11074 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11075 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N01);
11076 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: Diff);
11077 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
11078 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
11079 }
11080 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
11081 /*AllowUndefs*/ false,
11082 /*AllowTypeMismatch*/ true)) {
11083 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
11084 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
11085 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11086 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N1);
11087 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
11088 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
11089 }
11090 }
11091
11092 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
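  // e.g. with x: i16 any-extended to i32, (srl (any_extend x), 4) becomes
  // (and (any_extend (srl x, 4)), 0x0FFFFFFF); the mask re-zeros the top
  // ShiftAmt bits that the original 32-bit srl would have cleared.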
11093 // TODO - support non-uniform vector shift amounts.
11094 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
11095 // Shifting in all undef bits?
11096 EVT SmallVT = N0.getOperand(i: 0).getValueType();
11097 unsigned BitSize = SmallVT.getScalarSizeInBits();
11098 if (N1C->getAPIntValue().uge(RHS: BitSize))
11099 return DAG.getUNDEF(VT);
11100
11101 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, VT: SmallVT)) {
11102 uint64_t ShiftAmt = N1C->getZExtValue();
11103 SDLoc DL0(N0);
11104 SDValue SmallShift =
11105 DAG.getNode(Opcode: ISD::SRL, DL: DL0, VT: SmallVT, N1: N0.getOperand(i: 0),
11106 N2: DAG.getShiftAmountConstant(Val: ShiftAmt, VT: SmallVT, DL: DL0));
11107 AddToWorklist(N: SmallShift.getNode());
11108 APInt Mask = APInt::getLowBitsSet(numBits: OpSizeInBits, loBitsSet: OpSizeInBits - ShiftAmt);
11109 return DAG.getNode(Opcode: ISD::AND, DL, VT,
11110 N1: DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: SmallShift),
11111 N2: DAG.getConstant(Val: Mask, DL, VT));
11112 }
11113 }
11114
11115 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
11116 // bit, which is unmodified by sra.
11117 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
11118 if (N0.getOpcode() == ISD::SRA)
11119 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
11120 }
11121
  // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x
  // has a power of two bitwidth. The "5" represents (log2 (bitwidth x)).
11124 if (N1C && N0.getOpcode() == ISD::CTLZ &&
11125 isPowerOf2_32(Value: OpSizeInBits) &&
11126 N1C->getAPIntValue() == Log2_32(Value: OpSizeInBits)) {
11127 KnownBits Known = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
11128
11129 // If any of the input bits are KnownOne, then the input couldn't be all
11130 // zeros, thus the result of the srl will always be zero.
11131 if (Known.One.getBoolValue()) return DAG.getConstant(Val: 0, DL: SDLoc(N0), VT);
11132
    // If all of the bits input to the ctlz node are known to be zero, then
11134 // the result of the ctlz is "32" and the result of the shift is one.
11135 APInt UnknownBits = ~Known.Zero;
11136 if (UnknownBits == 0) return DAG.getConstant(Val: 1, DL: SDLoc(N0), VT);
11137
11138 // Otherwise, check to see if there is exactly one bit input to the ctlz.
11139 if (UnknownBits.isPowerOf2()) {
      // Okay, we know that only the single bit specified by UnknownBits
11141 // could be set on input to the CTLZ node. If this bit is set, the SRL
11142 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
11143 // to an SRL/XOR pair, which is likely to simplify more.
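      // Illustrative example (i32): if only bit 3 of the ctlz input can be
      // set, (srl (ctlz x), 5) is 1 when x == 0 and 0 when bit 3 is set,
      // which is exactly (xor (srl x, 3), 1).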
11144 unsigned ShAmt = UnknownBits.countr_zero();
11145 SDValue Op = N0.getOperand(i: 0);
11146
11147 if (ShAmt) {
11148 SDLoc DL(N0);
11149 Op = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Op,
11150 N2: DAG.getShiftAmountConstant(Val: ShAmt, VT, DL));
11151 AddToWorklist(N: Op.getNode());
11152 }
11153 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Op, N2: DAG.getConstant(Val: 1, DL, VT));
11154 }
11155 }
11156
11157 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
11158 if (N1.getOpcode() == ISD::TRUNCATE &&
11159 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
11160 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
11161 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: NewOp1);
11162 }
11163
11164 // fold (srl (logic_op x, (shl (zext y), c1)), c1)
11165 // -> (logic_op (srl x, c1), (zext y))
11166 // c1 <= leadingzeros(zext(y))
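  // e.g. with y: i32 zero-extended to i64 and c1 = 16:
  // (srl (or x, (shl (zext y), 16)), 16) -> (or (srl x, 16), (zext y)),
  // since the srl undoes the shl without dropping any bits of (zext y)
  // (16 <= 32 leading zeros).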
11167 SDValue X, ZExtY;
11168 if (N1C && sd_match(N: N0, P: m_OneUse(P: m_BitwiseLogic(
11169 L: m_Value(N&: X),
11170 R: m_OneUse(P: m_Shl(L: m_AllOf(preds: m_Value(N&: ZExtY),
11171 preds: m_Opc(Opcode: ISD::ZERO_EXTEND)),
11172 R: m_Specific(N: N1))))))) {
11173 unsigned NumLeadingZeros = ZExtY.getScalarValueSizeInBits() -
11174 ZExtY.getOperand(i: 0).getScalarValueSizeInBits();
11175 if (N1C->getZExtValue() <= NumLeadingZeros)
11176 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N0), VT,
11177 N1: DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N0), VT, N1: X, N2: N1), N2: ZExtY);
11178 }
11179
11180 // fold operands of srl based on knowledge that the low bits are not
11181 // demanded.
11182 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
11183 return SDValue(N, 0);
11184
11185 if (N1C && !N1C->isOpaque())
11186 if (SDValue NewSRL = visitShiftByConstant(N))
11187 return NewSRL;
11188
11189 // Attempt to convert a srl of a load into a narrower zero-extending load.
11190 if (SDValue NarrowLoad = reduceLoadWidth(N))
11191 return NarrowLoad;
11192
11193 // Here is a common situation. We want to optimize:
11194 //
11195 // %a = ...
11196 // %b = and i32 %a, 2
11197 // %c = srl i32 %b, 1
11198 // brcond i32 %c ...
11199 //
11200 // into
11201 //
11202 // %a = ...
11203 // %b = and %a, 2
11204 // %c = setcc eq %b, 0
11205 // brcond %c ...
11206 //
  // However, after the source operand of SRL is optimized into AND, the SRL
  // itself may not be optimized further. Look for it and add the BRCOND to
  // the worklist.
11210 //
  // This also tends to happen for binary operations when SimplifyDemandedBits
11212 // is involved.
11213 //
  // FIXME: This is unnecessary if we process the DAG in topological order,
11215 // which we plan to do. This workaround can be removed once the DAG is
11216 // processed in topological order.
11217 if (N->hasOneUse()) {
11218 SDNode *User = *N->user_begin();
11219
    // Look past the truncate.
11221 if (User->getOpcode() == ISD::TRUNCATE && User->hasOneUse())
11222 User = *User->user_begin();
11223
11224 if (User->getOpcode() == ISD::BRCOND || User->getOpcode() == ISD::AND ||
11225 User->getOpcode() == ISD::OR || User->getOpcode() == ISD::XOR)
11226 AddToWorklist(N: User);
11227 }
11228
11229 // Try to transform this shift into a multiply-high if
11230 // it matches the appropriate pattern detected in combineShiftToMULH.
11231 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11232 return MULH;
11233
11234 if (SDValue AVG = foldShiftToAvg(N))
11235 return AVG;
11236
11237 return SDValue();
11238}
11239
11240SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
11241 EVT VT = N->getValueType(ResNo: 0);
11242 SDValue N0 = N->getOperand(Num: 0);
11243 SDValue N1 = N->getOperand(Num: 1);
11244 SDValue N2 = N->getOperand(Num: 2);
11245 bool IsFSHL = N->getOpcode() == ISD::FSHL;
11246 unsigned BitWidth = VT.getScalarSizeInBits();
11247 SDLoc DL(N);
11248
11249 // fold (fshl N0, N1, 0) -> N0
11250 // fold (fshr N0, N1, 0) -> N1
11251 if (isPowerOf2_32(Value: BitWidth))
11252 if (DAG.MaskedValueIsZero(
11253 Op: N2, Mask: APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
11254 return IsFSHL ? N0 : N1;
11255
11256 auto IsUndefOrZero = [](SDValue V) {
11257 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
11258 };
11259
11260 // TODO - support non-uniform vector shift amounts.
11261 if (ConstantSDNode *Cst = isConstOrConstSplat(N: N2)) {
11262 EVT ShAmtTy = N2.getValueType();
11263
11264 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
11265 if (Cst->getAPIntValue().uge(RHS: BitWidth)) {
11266 uint64_t RotAmt = Cst->getAPIntValue().urem(RHS: BitWidth);
11267 return DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: N0, N2: N1,
11268 N3: DAG.getConstant(Val: RotAmt, DL, VT: ShAmtTy));
11269 }
11270
11271 unsigned ShAmt = Cst->getZExtValue();
11272 if (ShAmt == 0)
11273 return IsFSHL ? N0 : N1;
11274
11275 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
11276 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
11277 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
11278 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
11279 if (IsUndefOrZero(N0))
11280 return DAG.getNode(
11281 Opcode: ISD::SRL, DL, VT, N1,
11282 N2: DAG.getConstant(Val: IsFSHL ? BitWidth - ShAmt : ShAmt, DL, VT: ShAmtTy));
11283 if (IsUndefOrZero(N1))
11284 return DAG.getNode(
11285 Opcode: ISD::SHL, DL, VT, N1: N0,
11286 N2: DAG.getConstant(Val: IsFSHL ? ShAmt : BitWidth - ShAmt, DL, VT: ShAmtTy));
11287
11288 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11289 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11290 // TODO - bigendian support once we have test coverage.
    // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
11292 // TODO - permit LHS EXTLOAD if extensions are shifted out.
11293 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
11294 !DAG.getDataLayout().isBigEndian()) {
11295 auto *LHS = dyn_cast<LoadSDNode>(Val&: N0);
11296 auto *RHS = dyn_cast<LoadSDNode>(Val&: N1);
11297 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
11298 LHS->getAddressSpace() == RHS->getAddressSpace() &&
11299 (LHS->hasNUsesOfValue(NUses: 1, Value: 0) || RHS->hasNUsesOfValue(NUses: 1, Value: 0)) &&
11300 ISD::isNON_EXTLoad(N: RHS) && ISD::isNON_EXTLoad(N: LHS)) {
11301 if (DAG.areNonVolatileConsecutiveLoads(LD: LHS, Base: RHS, Bytes: BitWidth / 8, Dist: 1)) {
11302 SDLoc DL(RHS);
11303 uint64_t PtrOff =
11304 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
11305 Align NewAlign = commonAlignment(A: RHS->getAlign(), Offset: PtrOff);
11306 unsigned Fast = 0;
11307 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
11308 AddrSpace: RHS->getAddressSpace(), Alignment: NewAlign,
11309 Flags: RHS->getMemOperand()->getFlags(), Fast: &Fast) &&
11310 Fast) {
11311 SDValue NewPtr = DAG.getMemBasePlusOffset(
11312 Base: RHS->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL);
11313 AddToWorklist(N: NewPtr.getNode());
11314 SDValue Load = DAG.getLoad(
11315 VT, dl: DL, Chain: RHS->getChain(), Ptr: NewPtr,
11316 PtrInfo: RHS->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
11317 MMOFlags: RHS->getMemOperand()->getFlags(), AAInfo: RHS->getAAInfo());
11318 DAG.makeEquivalentMemoryOrdering(OldLoad: LHS, NewMemOp: Load.getValue(R: 1));
11319 DAG.makeEquivalentMemoryOrdering(OldLoad: RHS, NewMemOp: Load.getValue(R: 1));
11320 return Load;
11321 }
11322 }
11323 }
11324 }
11325 }
11326
11327 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
11328 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
  // iff we know the shift amount is in range.
11330 // TODO: when is it worth doing SUB(BW, N2) as well?
11331 if (isPowerOf2_32(Value: BitWidth)) {
11332 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
11333 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
11334 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1, N2);
11335 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
11336 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2);
11337 }
11338
11339 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
11340 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
11341 // TODO: Investigate flipping this rotate if only one is legal.
11342 // If funnel shift is legal as well we might be better off avoiding
11343 // non-constant (BW - N2).
11344 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
11345 if (N0 == N1 && hasOperation(Opcode: RotOpc, VT))
11346 return DAG.getNode(Opcode: RotOpc, DL, VT, N1: N0, N2);
11347
11348 // Simplify, based on bits shifted out of N0/N1.
11349 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
11350 return SDValue(N, 0);
11351
11352 return SDValue();
11353}
11354
11355SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
11356 SDValue N0 = N->getOperand(Num: 0);
11357 SDValue N1 = N->getOperand(Num: 1);
11358 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
11359 return V;
11360
11361 SDLoc DL(N);
11362 EVT VT = N0.getValueType();
11363
11364 // fold (*shlsat c1, c2) -> c1<<c2
11365 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL, VT, Ops: {N0, N1}))
11366 return C;
11367
11368 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
11369
11370 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::SHL, VT)) {
11371 // fold (sshlsat x, c) -> (shl x, c)
11372 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
11373 N1C->getAPIntValue().ult(RHS: DAG.ComputeNumSignBits(Op: N0)))
11374 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: N1);
11375
11376 // fold (ushlsat x, c) -> (shl x, c)
11377 if (N->getOpcode() == ISD::USHLSAT && N1C &&
11378 N1C->getAPIntValue().ule(
11379 RHS: DAG.computeKnownBits(Op: N0).countMinLeadingZeros()))
11380 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: N1);
11381 }
11382
11383 return SDValue();
11384}
11385
// Given an ABS node, detect the following patterns:
11387// (ABS (SUB (EXTEND a), (EXTEND b))).
11388// (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
11389// Generates UABD/SABD instruction.
11390SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
11391 EVT SrcVT = N->getValueType(ResNo: 0);
11392
11393 if (N->getOpcode() == ISD::TRUNCATE)
11394 N = N->getOperand(Num: 0).getNode();
11395
11396 EVT VT = N->getValueType(ResNo: 0);
11397 SDValue Op0, Op1;
11398
11399 if (!sd_match(N, P: m_Abs(Op: m_Sub(L: m_Value(N&: Op0), R: m_Value(N&: Op1)))))
11400 return SDValue();
11401
11402 SDValue AbsOp0 = N->getOperand(Num: 0);
11403 unsigned Opc0 = Op0.getOpcode();
11404
11405 // Check if the operands of the sub are (zero|sign)-extended.
11406 // TODO: Should we use ValueTracking instead?
11407 if (Opc0 != Op1.getOpcode() ||
11408 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
11409 Opc0 != ISD::SIGN_EXTEND_INREG)) {
11410 // fold (abs (sub nsw x, y)) -> abds(x, y)
11411 // Don't fold this for unsupported types as we lose the NSW handling.
11412 if (AbsOp0->getFlags().hasNoSignedWrap() && hasOperation(Opcode: ISD::ABDS, VT) &&
11413 TLI.preferABDSToABSWithNSW(VT)) {
11414 SDValue ABD = DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: Op0, N2: Op1);
11415 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11416 }
11417 return SDValue();
11418 }
11419
11420 EVT VT0, VT1;
11421 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
11422 VT0 = cast<VTSDNode>(Val: Op0.getOperand(i: 1))->getVT();
11423 VT1 = cast<VTSDNode>(Val: Op1.getOperand(i: 1))->getVT();
11424 } else {
11425 VT0 = Op0.getOperand(i: 0).getValueType();
11426 VT1 = Op1.getOperand(i: 0).getValueType();
11427 }
11428 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
11429
11430 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
11431 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
11432 EVT MaxVT = VT0.bitsGT(VT: VT1) ? VT0 : VT1;
11433 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
11434 (VT1 == MaxVT || Op1->hasOneUse()) &&
11435 (!LegalTypes || hasOperation(Opcode: ABDOpcode, VT: MaxVT))) {
11436 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT: MaxVT,
11437 N1: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op0),
11438 N2: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op1));
11439 ABD = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ABD);
11440 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11441 }
11442
11443 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
11444 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
11445 if (!LegalOperations || hasOperation(Opcode: ABDOpcode, VT)) {
11446 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT, N1: Op0, N2: Op1);
11447 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11448 }
11449
11450 return SDValue();
11451}
11452
11453SDValue DAGCombiner::visitABS(SDNode *N) {
11454 SDValue N0 = N->getOperand(Num: 0);
11455 EVT VT = N->getValueType(ResNo: 0);
11456 SDLoc DL(N);
11457
11458 // fold (abs c1) -> c2
11459 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ABS, DL, VT, Ops: {N0}))
11460 return C;
11461 // fold (abs (abs x)) -> (abs x)
11462 if (N0.getOpcode() == ISD::ABS)
11463 return N0;
11464 // fold (abs x) -> x iff not-negative
11465 if (DAG.SignBitIsZero(Op: N0))
11466 return N0;
11467
11468 if (SDValue ABD = foldABSToABD(N, DL))
11469 return ABD;
11470
11471 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11472 // iff zero_extend/truncate are free.
11473 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11474 EVT ExtVT = cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT();
11475 if (TLI.isTruncateFree(FromVT: VT, ToVT: ExtVT) && TLI.isZExtFree(FromTy: ExtVT, ToTy: VT) &&
11476 TLI.isTypeDesirableForOp(ISD::ABS, VT: ExtVT) &&
11477 hasOperation(Opcode: ISD::ABS, VT: ExtVT)) {
11478 return DAG.getNode(
11479 Opcode: ISD::ZERO_EXTEND, DL, VT,
11480 Operand: DAG.getNode(Opcode: ISD::ABS, DL, VT: ExtVT,
11481 Operand: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N0.getOperand(i: 0))));
11482 }
11483 }
11484
11485 return SDValue();
11486}
11487
11488SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11489 SDValue N0 = N->getOperand(Num: 0);
11490 EVT VT = N->getValueType(ResNo: 0);
11491 SDLoc DL(N);
11492
11493 // fold (bswap c1) -> c2
11494 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BSWAP, DL, VT, Ops: {N0}))
11495 return C;
11496 // fold (bswap (bswap x)) -> x
11497 if (N0.getOpcode() == ISD::BSWAP)
11498 return N0.getOperand(i: 0);
11499
11500 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11501 // isn't supported, it will be expanded to bswap followed by a manual reversal
11502 // of bits in each byte. By placing bswaps before bitreverse, we can remove
11503 // the two bswaps if the bitreverse gets expanded.
11504 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11505 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11506 return DAG.getNode(Opcode: ISD::BITREVERSE, DL, VT, Operand: BSwap);
11507 }
11508
11509 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
  // iff c >= bw/2 (i.e. the lower half of the result is known zero)
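  // e.g. (bswap (shl i32 x, 16)) -> (zext (bswap (trunc i16 x))), i.e. only
  // the low half of x is byte-swapped and the result is zero-extended.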
11511 unsigned BW = VT.getScalarSizeInBits();
11512 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11513 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11514 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW / 2);
11515 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11516 ShAmt->getZExtValue() >= (BW / 2) &&
11517 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(VT: HalfVT) &&
11518 TLI.isTruncateFree(FromVT: VT, ToVT: HalfVT) &&
11519 (!LegalOperations || hasOperation(Opcode: ISD::BSWAP, VT: HalfVT))) {
11520 SDValue Res = N0.getOperand(i: 0);
11521 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11522 Res = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Res,
11523 N2: DAG.getShiftAmountConstant(Val: NewShAmt, VT, DL));
11524 Res = DAG.getZExtOrTrunc(Op: Res, DL, VT: HalfVT);
11525 Res = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: HalfVT, Operand: Res);
11526 return DAG.getZExtOrTrunc(Op: Res, DL, VT);
11527 }
11528 }
11529
11530 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11531 // inverse-shift-of-bswap:
11532 // bswap (X u<< C) --> (bswap X) u>> C
11533 // bswap (X u>> C) --> (bswap X) u<< C
11534 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11535 N0.hasOneUse()) {
11536 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11537 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11538 ShAmt->getZExtValue() % 8 == 0) {
11539 SDValue NewSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11540 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11541 return DAG.getNode(Opcode: InverseShift, DL, VT, N1: NewSwap, N2: N0.getOperand(i: 1));
11542 }
11543 }
11544
11545 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11546 return V;
11547
11548 return SDValue();
11549}
11550
11551SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11552 SDValue N0 = N->getOperand(Num: 0);
11553 EVT VT = N->getValueType(ResNo: 0);
11554 SDLoc DL(N);
11555
11556 // fold (bitreverse c1) -> c2
11557 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BITREVERSE, DL, VT, Ops: {N0}))
11558 return C;
11559
11560 // fold (bitreverse (bitreverse x)) -> x
11561 if (N0.getOpcode() == ISD::BITREVERSE)
11562 return N0.getOperand(i: 0);
11563
11564 SDValue X, Y;
11565
11566 // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
11567 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
11568 sd_match(N, P: m_BitReverse(Op: m_Srl(L: m_BitReverse(Op: m_Value(N&: X)), R: m_Value(N&: Y)))))
11569 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: X, N2: Y);
11570
11571 // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
11572 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SRL, VT)) &&
11573 sd_match(N, P: m_BitReverse(Op: m_Shl(L: m_BitReverse(Op: m_Value(N&: X)), R: m_Value(N&: Y)))))
11574 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X, N2: Y);
11575
11576 return SDValue();
11577}
11578
11579SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11580 SDValue N0 = N->getOperand(Num: 0);
11581 EVT VT = N->getValueType(ResNo: 0);
11582 SDLoc DL(N);
11583
11584 // fold (ctlz c1) -> c2
11585 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ, DL, VT, Ops: {N0}))
11586 return C;
11587
11588 // If the value is known never to be zero, switch to the undef version.
11589 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ_ZERO_UNDEF, VT))
11590 if (DAG.isKnownNeverZero(Op: N0))
11591 return DAG.getNode(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Operand: N0);
11592
11593 return SDValue();
11594}
11595
11596SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11597 SDValue N0 = N->getOperand(Num: 0);
11598 EVT VT = N->getValueType(ResNo: 0);
11599 SDLoc DL(N);
11600
11601 // fold (ctlz_zero_undef c1) -> c2
11602 if (SDValue C =
11603 DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11604 return C;
11605 return SDValue();
11606}
11607
11608SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11609 SDValue N0 = N->getOperand(Num: 0);
11610 EVT VT = N->getValueType(ResNo: 0);
11611 SDLoc DL(N);
11612
11613 // fold (cttz c1) -> c2
11614 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ, DL, VT, Ops: {N0}))
11615 return C;
11616
11617 // If the value is known never to be zero, switch to the undef version.
11618 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ_ZERO_UNDEF, VT))
11619 if (DAG.isKnownNeverZero(Op: N0))
11620 return DAG.getNode(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Operand: N0);
11621
11622 return SDValue();
11623}
11624
11625SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11626 SDValue N0 = N->getOperand(Num: 0);
11627 EVT VT = N->getValueType(ResNo: 0);
11628 SDLoc DL(N);
11629
11630 // fold (cttz_zero_undef c1) -> c2
11631 if (SDValue C =
11632 DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11633 return C;
11634 return SDValue();
11635}
11636
11637SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11638 SDValue N0 = N->getOperand(Num: 0);
11639 EVT VT = N->getValueType(ResNo: 0);
11640 unsigned NumBits = VT.getScalarSizeInBits();
11641 SDLoc DL(N);
11642
11643 // fold (ctpop c1) -> c2
11644 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTPOP, DL, VT, Ops: {N0}))
11645 return C;
11646
11647 // If the source is being shifted, but doesn't affect any active bits,
11648 // then we can call CTPOP on the shift source directly.
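  // e.g. (ctpop (srl x, 3)) -> (ctpop x) when the low 3 bits of x are known
  // to be zero, since the shift then discards no set bits.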
11649 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
11650 if (ConstantSDNode *AmtC = isConstOrConstSplat(N: N0.getOperand(i: 1))) {
11651 const APInt &Amt = AmtC->getAPIntValue();
11652 if (Amt.ult(RHS: NumBits)) {
11653 KnownBits KnownSrc = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
11654 if ((N0.getOpcode() == ISD::SRL &&
11655 Amt.ule(RHS: KnownSrc.countMinTrailingZeros())) ||
11656 (N0.getOpcode() == ISD::SHL &&
11657 Amt.ule(RHS: KnownSrc.countMinLeadingZeros()))) {
11658 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: N0.getOperand(i: 0));
11659 }
11660 }
11661 }
11662 }
11663
  // If the upper bits are known to be zero, then see if it's profitable to
11665 // only count the lower bits.
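  // e.g. (ctpop i64 x) -> (zext (ctpop (trunc i32 x))) when the upper 32 bits
  // of x are known to be zero and a narrow i32 ctpop is available.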
11666 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11667 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumBits / 2);
11668 if (hasOperation(Opcode: ISD::CTPOP, VT: HalfVT) &&
11669 TLI.isTypeDesirableForOp(ISD::CTPOP, VT: HalfVT) &&
11670 TLI.isTruncateFree(Val: N0, VT2: HalfVT) && TLI.isZExtFree(FromTy: HalfVT, ToTy: VT)) {
11671 APInt UpperBits = APInt::getHighBitsSet(numBits: NumBits, hiBitsSet: NumBits / 2);
11672 if (DAG.MaskedValueIsZero(Op: N0, Mask: UpperBits)) {
11673 SDValue PopCnt = DAG.getNode(Opcode: ISD::CTPOP, DL, VT: HalfVT,
11674 Operand: DAG.getZExtOrTrunc(Op: N0, DL, VT: HalfVT));
11675 return DAG.getZExtOrTrunc(Op: PopCnt, DL, VT);
11676 }
11677 }
11678 }
11679
11680 return SDValue();
11681}
11682
11683static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11684 SDValue RHS, const SDNodeFlags Flags,
11685 const TargetLowering &TLI) {
11686 EVT VT = LHS.getValueType();
11687 if (!VT.isFloatingPoint())
11688 return false;
11689
11690 const TargetOptions &Options = DAG.getTarget().Options;
11691
11692 return (Flags.hasNoSignedZeros() || Options.NoSignedZerosFPMath) &&
11693 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11694 (Flags.hasNoNaNs() ||
11695 (DAG.isKnownNeverNaN(Op: RHS) && DAG.isKnownNeverNaN(Op: LHS)));
11696}
11697
11698static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11699 SDValue RHS, SDValue True, SDValue False,
11700 ISD::CondCode CC,
11701 const TargetLowering &TLI,
11702 SelectionDAG &DAG) {
11703 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT);
11704 switch (CC) {
11705 case ISD::SETOLT:
11706 case ISD::SETOLE:
11707 case ISD::SETLT:
11708 case ISD::SETLE:
11709 case ISD::SETULT:
11710 case ISD::SETULE: {
    // Since it's already known never to be NaN to get here, either fminnum or
    // fminnum_ieee is OK. Try the ieee version first, since fminnum is
    // expanded in terms of it.
11714 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11715 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11716 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11717
11718 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11719 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11720 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11721 return SDValue();
11722 }
11723 case ISD::SETOGT:
11724 case ISD::SETOGE:
11725 case ISD::SETGT:
11726 case ISD::SETGE:
11727 case ISD::SETUGT:
11728 case ISD::SETUGE: {
11729 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11730 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11731 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11732
11733 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11734 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11735 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11736 return SDValue();
11737 }
11738 default:
11739 return SDValue();
11740 }
11741}
11742
11743SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
11744 const unsigned Opcode = N->getOpcode();
11745
  // Convert (sr[al] (add n[su]w x, y), 1) -> (avgfloor[su] x, y)
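  // e.g. (srl (add nuw x, y), 1) -> (avgflooru x, y) and
  //      (sra (add nsw x, y), 1) -> (avgfloors x, y).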
11747 if (Opcode != ISD::SRA && Opcode != ISD::SRL)
11748 return SDValue();
11749
11750 unsigned FloorISD = 0;
11751 auto VT = N->getValueType(ResNo: 0);
11752 bool IsUnsigned = false;
11753
  // Decide whether signed or unsigned.
11755 switch (Opcode) {
11756 case ISD::SRA:
11757 if (!hasOperation(Opcode: ISD::AVGFLOORS, VT))
11758 return SDValue();
11759 FloorISD = ISD::AVGFLOORS;
11760 break;
11761 case ISD::SRL:
11762 IsUnsigned = true;
11763 if (!hasOperation(Opcode: ISD::AVGFLOORU, VT))
11764 return SDValue();
11765 FloorISD = ISD::AVGFLOORU;
11766 break;
11767 default:
11768 return SDValue();
11769 }
11770
11771 // Captured values.
11772 SDValue A, B, Add;
11773
11774 // Match floor average as it is common to both floor/ceil avgs.
11775 if (!sd_match(N, P: m_BinOp(Opc: Opcode,
11776 L: m_AllOf(preds: m_Value(N&: Add), preds: m_Add(L: m_Value(N&: A), R: m_Value(N&: B))),
11777 R: m_One())))
11778 return SDValue();
11779
11780 // Can't optimize adds that may wrap.
11781 if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
11782 return SDValue();
11783
11784 if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
11785 return SDValue();
11786
11787 return DAG.getNode(Opcode: FloorISD, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: {A, B});
11788}
11789
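/// Fold a bitwise logic op whose second operand is an add/sub of a NOT:
///   logic_op(x, add(not(y), z)) --> logic_op(x, not(sub(y, z)))
///   logic_op(x, sub(not(y), z)) --> logic_op(x, not(add(y, z)))
/// using the identities ~y + z == ~(y - z) and ~y - z == ~(y + z).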
11790SDValue DAGCombiner::foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT) {
11791 unsigned Opc = N->getOpcode();
11792 SDValue X, Y, Z;
11793 if (sd_match(
11794 N, P: m_BitwiseLogic(L: m_Value(N&: X), R: m_Add(L: m_Not(V: m_Value(N&: Y)), R: m_Value(N&: Z)))))
11795 return DAG.getNode(Opcode: Opc, DL, VT, N1: X,
11796 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Y, N2: Z), VT));
11797
11798 if (sd_match(N, P: m_BitwiseLogic(L: m_Value(N&: X), R: m_Sub(L: m_OneUse(P: m_Not(V: m_Value(N&: Y))),
11799 R: m_Value(N&: Z)))))
11800 return DAG.getNode(Opcode: Opc, DL, VT, N1: X,
11801 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Y, N2: Z), VT));
11802
11803 return SDValue();
11804}
11805
11806/// Generate Min/Max node
11807SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11808 SDValue RHS, SDValue True,
11809 SDValue False, ISD::CondCode CC) {
11810 if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11811 return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11812
11813 // If we can't directly match this, try to see if we can pull an fneg out of
11814 // the select.
11815 SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11816 Op: True, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11817 if (!NegTrue)
11818 return SDValue();
11819
11820 HandleSDNode NegTrueHandle(NegTrue);
11821
11822 // Try to unfold an fneg from the select if we are comparing the negated
11823 // constant.
11824 //
11825 // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
11826 //
11827 // TODO: Handle fabs
11828 if (LHS == NegTrue) {
11829 // If we can't directly match this, try to see if we can pull an fneg out of
11830 // the select.
11831 SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11832 Op: RHS, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11833 if (NegRHS) {
11834 HandleSDNode NegRHSHandle(NegRHS);
11835 if (NegRHS == False) {
11836 SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True: NegTrue,
11837 False, CC, TLI, DAG);
11838 if (Combined)
11839 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Combined);
11840 }
11841 }
11842 }
11843
11844 return SDValue();
11845}
11846
11847/// If a (v)select has a condition value that is a sign-bit test, try to smear
11848/// the condition operand sign-bit across the value width and use it as a mask.
11849static SDValue foldSelectOfConstantsUsingSra(SDNode *N, const SDLoc &DL,
11850 SelectionDAG &DAG) {
11851 SDValue Cond = N->getOperand(Num: 0);
11852 SDValue C1 = N->getOperand(Num: 1);
11853 SDValue C2 = N->getOperand(Num: 2);
11854 if (!isConstantOrConstantVector(N: C1) || !isConstantOrConstantVector(N: C2))
11855 return SDValue();
11856
11857 EVT VT = N->getValueType(ResNo: 0);
11858 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11859 VT != Cond.getOperand(i: 0).getValueType())
11860 return SDValue();
11861
11862 // The inverted-condition + commuted-select variants of these patterns are
11863 // canonicalized to these forms in IR.
11864 SDValue X = Cond.getOperand(i: 0);
11865 SDValue CondC = Cond.getOperand(i: 1);
11866 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11867 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CondC) &&
11868 isAllOnesOrAllOnesSplat(V: C2)) {
11869 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11870 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11871 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11872 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: C1);
11873 }
11874 if (CC == ISD::SETLT && isNullOrNullSplat(V: CondC) && isNullOrNullSplat(V: C2)) {
11875 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11876 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11877 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11878 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: C1);
11879 }
11880 return SDValue();
11881}
11882
11883static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11884 const TargetLowering &TLI) {
11885 if (!TLI.convertSelectOfConstantsToMath(VT))
11886 return false;
11887
11888 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11889 return true;
11890 if (!TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))
11891 return true;
11892
11893 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11894 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond.getOperand(i: 1)))
11895 return true;
11896 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond.getOperand(i: 1)))
11897 return true;
11898
11899 return false;
11900}
11901
11902SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11903 SDValue Cond = N->getOperand(Num: 0);
11904 SDValue N1 = N->getOperand(Num: 1);
11905 SDValue N2 = N->getOperand(Num: 2);
11906 EVT VT = N->getValueType(ResNo: 0);
11907 EVT CondVT = Cond.getValueType();
11908 SDLoc DL(N);
11909
11910 if (!VT.isInteger())
11911 return SDValue();
11912
11913 auto *C1 = dyn_cast<ConstantSDNode>(Val&: N1);
11914 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N2);
11915 if (!C1 || !C2)
11916 return SDValue();
11917
11918 if (CondVT != MVT::i1 || LegalOperations) {
11919 // fold (select Cond, 0, 1) -> (xor Cond, 1)
11920 // We can't do this reliably if integer based booleans have different contents
11921 // to floating point based booleans. This is because we can't tell whether we
11922 // have an integer-based boolean or a floating-point-based boolean unless we
11923 // can find the SETCC that produced it and inspect its operands. This is
11924 // fairly easy if C is the SETCC node, but it can potentially be
11925 // undiscoverable (or not reasonably discoverable). For example, it could be
11926 // in another basic block or it could require searching a complicated
11927 // expression.
11928 if (CondVT.isInteger() &&
11929 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11930 TargetLowering::ZeroOrOneBooleanContent &&
11931 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11932 TargetLowering::ZeroOrOneBooleanContent &&
11933 C1->isZero() && C2->isOne()) {
11934 SDValue NotCond =
11935 DAG.getNode(Opcode: ISD::XOR, DL, VT: CondVT, N1: Cond, N2: DAG.getConstant(Val: 1, DL, VT: CondVT));
11936 if (VT.bitsEq(VT: CondVT))
11937 return NotCond;
11938 return DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11939 }
11940
11941 return SDValue();
11942 }
11943
11944 // Only do this before legalization to avoid conflicting with target-specific
11945 // transforms in the other direction (create a select from a zext/sext). There
11946 // is also a target-independent combine here in DAGCombiner in the other
11947 // direction for (select Cond, -1, 0) when the condition is not i1.
11948 assert(CondVT == MVT::i1 && !LegalOperations);
11949
11950 // select Cond, 1, 0 --> zext (Cond)
11951 if (C1->isOne() && C2->isZero())
11952 return DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11953
11954 // select Cond, -1, 0 --> sext (Cond)
11955 if (C1->isAllOnes() && C2->isZero())
11956 return DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11957
11958 // select Cond, 0, 1 --> zext (!Cond)
11959 if (C1->isZero() && C2->isOne()) {
11960 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
11961 NotCond = DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11962 return NotCond;
11963 }
11964
11965 // select Cond, 0, -1 --> sext (!Cond)
11966 if (C1->isZero() && C2->isAllOnes()) {
11967 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
11968 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11969 return NotCond;
11970 }
11971
11972 // Use a target hook because some targets may prefer to transform in the
11973 // other direction.
11974 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11975 return SDValue();
11976
11977 // For any constants that differ by 1, we can transform the select into
11978 // an extend and add.
11979 const APInt &C1Val = C1->getAPIntValue();
11980 const APInt &C2Val = C2->getAPIntValue();
11981
11982 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11983 if (C1Val - 1 == C2Val) {
11984 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11985 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11986 }
11987
11988 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11989 if (C1Val + 1 == C2Val) {
11990 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11991 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11992 }
11993
11994 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
11995 if (C1Val.isPowerOf2() && C2Val.isZero()) {
11996 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11997 SDValue ShAmtC =
11998 DAG.getShiftAmountConstant(Val: C1Val.exactLogBase2(), VT, DL);
11999 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Cond, N2: ShAmtC);
12000 }
12001
12002 // select Cond, -1, C --> or (sext Cond), C
12003 if (C1->isAllOnes()) {
12004 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
12005 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Cond, N2);
12006 }
12007
12008 // select Cond, C, -1 --> or (sext (not Cond)), C
12009 if (C2->isAllOnes()) {
12010 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
12011 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
12012 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: NotCond, N2: N1);
12013 }
12014
12015 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
12016 return V;
12017
12018 return SDValue();
12019}
12020
12021template <class MatchContextClass>
12022static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
12023 SelectionDAG &DAG) {
12024 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
12025 N->getOpcode() == ISD::VP_SELECT) &&
12026 "Expected a (v)(vp.)select");
12027 SDValue Cond = N->getOperand(Num: 0);
12028 SDValue T = N->getOperand(Num: 1), F = N->getOperand(Num: 2);
12029 EVT VT = N->getValueType(ResNo: 0);
12030 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12031 MatchContextClass matcher(DAG, TLI, N);
12032
12033 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
12034 return SDValue();
12035
12036 // select Cond, Cond, F --> or Cond, freeze(F)
12037 // select Cond, 1, F --> or Cond, freeze(F)
12038 if (Cond == T || isOneOrOneSplat(V: T, /* AllowUndefs */ true))
12039 return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(V: F));
12040
12041 // select Cond, T, Cond --> and Cond, freeze(T)
12042 // select Cond, T, 0 --> and Cond, freeze(T)
12043 if (Cond == F || isNullOrNullSplat(V: F, /* AllowUndefs */ true))
12044 return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(V: T));
12045
12046 // select Cond, T, 1 --> or (not Cond), freeze(T)
12047 if (isOneOrOneSplat(V: F, /* AllowUndefs */ true)) {
12048 SDValue NotCond =
12049 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12050 return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(V: T));
12051 }
12052
12053 // select Cond, 0, F --> and (not Cond), freeze(F)
12054 if (isNullOrNullSplat(V: T, /* AllowUndefs */ true)) {
12055 SDValue NotCond =
12056 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12057 return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(V: F));
12058 }
12059
12060 return SDValue();
12061}
12062
12063static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
12064 SDValue N0 = N->getOperand(Num: 0);
12065 SDValue N1 = N->getOperand(Num: 1);
12066 SDValue N2 = N->getOperand(Num: 2);
12067 EVT VT = N->getValueType(ResNo: 0);
12068 unsigned EltSizeInBits = VT.getScalarSizeInBits();
12069
12070 SDValue Cond0, Cond1;
12071 ISD::CondCode CC;
12072 if (!sd_match(N: N0, P: m_OneUse(P: m_SetCC(LHS: m_Value(N&: Cond0), RHS: m_Value(N&: Cond1),
12073 CC: m_CondCode(CC)))) ||
12074 VT != Cond0.getValueType())
12075 return SDValue();
12076
12077 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
12078 // compare is inverted from that pattern ("Cond0 s> -1").
12079 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond1))
12080 ; // This is the pattern we are looking for.
12081 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond1))
12082 std::swap(a&: N1, b&: N2);
12083 else
12084 return SDValue();
12085
12086 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
12087 if (isNullOrNullSplat(V: N2)) {
12088 SDLoc DL(N);
12089 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: EltSizeInBits - 1, VT, DL);
12090 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
12091 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: DAG.getFreeze(V: N1));
12092 }
12093
12094 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
12095 if (isAllOnesOrAllOnesSplat(V: N1)) {
12096 SDLoc DL(N);
12097 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: EltSizeInBits - 1, VT, DL);
12098 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
12099 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: DAG.getFreeze(V: N2));
12100 }
12101
12102 // If we have to invert the sign bit mask, only do that transform if the
12103 // target has a bitwise 'and not' instruction (the invert is free).
  // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
12105 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12106 if (isNullOrNullSplat(V: N1) && TLI.hasAndNot(X: N1)) {
12107 SDLoc DL(N);
12108 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: EltSizeInBits - 1, VT, DL);
12109 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
12110 SDValue Not = DAG.getNOT(DL, Val: Sra, VT);
12111 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Not, N2: DAG.getFreeze(V: N2));
12112 }
12113
12114 // TODO: There's another pattern in this family, but it may require
12115 // implementing hasOrNot() to check for profitability:
12116 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)
12117
12118 return SDValue();
12119}
12120
12121// Match SELECTs with absolute difference patterns.
12122// (select (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12123// (select (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12124// (select (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12125// (select (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12126SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
12127 SDValue False, ISD::CondCode CC,
12128 const SDLoc &DL) {
12129 bool IsSigned = isSignedIntSetCC(Code: CC);
12130 unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12131 EVT VT = LHS.getValueType();
12132
12133 if (LegalOperations && !hasOperation(Opcode: ABDOpc, VT))
12134 return SDValue();
12135
12136 switch (CC) {
12137 case ISD::SETGT:
12138 case ISD::SETGE:
12139 case ISD::SETUGT:
12140 case ISD::SETUGE:
12141 if (sd_match(N: True, P: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS))) &&
12142 sd_match(N: False, P: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS))))
12143 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12144 if (sd_match(N: True, P: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS))) &&
12145 sd_match(N: False, P: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS))) &&
12146 hasOperation(Opcode: ABDOpc, VT))
12147 return DAG.getNegative(Val: DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS), DL, VT);
12148 break;
12149 case ISD::SETLT:
12150 case ISD::SETLE:
12151 case ISD::SETULT:
12152 case ISD::SETULE:
12153 if (sd_match(N: True, P: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS))) &&
12154 sd_match(N: False, P: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS))))
12155 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12156 if (sd_match(N: True, P: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS))) &&
12157 sd_match(N: False, P: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS))) &&
12158 hasOperation(Opcode: ABDOpc, VT))
12159 return DAG.getNegative(Val: DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS), DL, VT);
12160 break;
12161 default:
12162 break;
12163 }
12164
12165 return SDValue();
12166}
12167
12168SDValue DAGCombiner::visitSELECT(SDNode *N) {
12169 SDValue N0 = N->getOperand(Num: 0);
12170 SDValue N1 = N->getOperand(Num: 1);
12171 SDValue N2 = N->getOperand(Num: 2);
12172 EVT VT = N->getValueType(ResNo: 0);
12173 EVT VT0 = N0.getValueType();
12174 SDLoc DL(N);
12175 SDNodeFlags Flags = N->getFlags();
12176
12177 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12178 return V;
12179
12180 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
12181 return V;
12182
12183 // select (not Cond), N1, N2 -> select Cond, N2, N1
12184 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false)) {
12185 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
12186 SelectOp->setFlags(Flags);
12187 return SelectOp;
12188 }
12189
12190 if (SDValue V = foldSelectOfConstants(N))
12191 return V;
12192
12193 // If we can fold this based on the true/false value, do so.
12194 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
12195 return SDValue(N, 0); // Don't revisit N.
12196
12197 if (VT0 == MVT::i1) {
12198 // The code in this block deals with the following 2 equivalences:
12199 // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
12200 // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
12201 // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However, we always transform
    // to the right if we find that the inner select already exists in the DAG,
    // and we always transform to the left side if we know that we can further
    // optimize the combination of the conditions.
12206 bool normalizeToSequence =
12207 TLI.shouldNormalizeToSelectSequence(Context&: *DAG.getContext(), VT);
12208 // select (and Cond0, Cond1), X, Y
12209 // -> select Cond0, (select Cond1, X, Y), Y
12210 if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
12211 SDValue Cond0 = N0->getOperand(Num: 0);
12212 SDValue Cond1 = N0->getOperand(Num: 1);
12213 SDValue InnerSelect =
12214 DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond1, N2: N1, N3: N2, Flags);
12215 if (normalizeToSequence || !InnerSelect.use_empty())
12216 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0,
12217 N2: InnerSelect, N3: N2, Flags);
12218 // Cleanup on failure.
12219 if (InnerSelect.use_empty())
12220 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
12221 }
12222 // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
12223 if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
12224 SDValue Cond0 = N0->getOperand(Num: 0);
12225 SDValue Cond1 = N0->getOperand(Num: 1);
12226 SDValue InnerSelect = DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(),
12227 N1: Cond1, N2: N1, N3: N2, Flags);
12228 if (normalizeToSequence || !InnerSelect.use_empty())
12229 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0, N2: N1,
12230 N3: InnerSelect, Flags);
12231 // Cleanup on failure.
12232 if (InnerSelect.use_empty())
12233 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
12234 }
12235
12236 // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
12237 if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
12238 SDValue N1_0 = N1->getOperand(Num: 0);
12239 SDValue N1_1 = N1->getOperand(Num: 1);
12240 SDValue N1_2 = N1->getOperand(Num: 2);
12241 if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
12242 // Create the actual and node if we can generate good code for it.
12243 if (!normalizeToSequence) {
12244 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N0, N2: N1_0);
12245 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: And, N2: N1_1,
12246 N3: N2, Flags);
12247 }
12248 // Otherwise see if we can optimize the "and" to a better pattern.
12249 if (SDValue Combined = visitANDLike(N0, N1: N1_0, N)) {
12250 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1_1,
12251 N3: N2, Flags);
12252 }
12253 }
12254 }
12255 // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
12256 if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
12257 SDValue N2_0 = N2->getOperand(Num: 0);
12258 SDValue N2_1 = N2->getOperand(Num: 1);
12259 SDValue N2_2 = N2->getOperand(Num: 2);
12260 if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
12261 // Create the actual or node if we can generate good code for it.
12262 if (!normalizeToSequence) {
12263 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: N0.getValueType(), N1: N0, N2: N2_0);
12264 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Or, N2: N1,
12265 N3: N2_2, Flags);
12266 }
12267 // Otherwise see if we can optimize to a better pattern.
12268 if (SDValue Combined = visitORLike(N0, N1: N2_0, DL))
12269 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1,
12270 N3: N2_2, Flags);
12271 }
12272 }
12273
12274 // select usubo(x, y).overflow, (sub y, x), (usubo x, y) -> abdu(x, y)
12275 if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
12276 N2.getNode() == N0.getNode() && N2.getResNo() == 0 &&
12277 N1.getOpcode() == ISD::SUB && N2.getOperand(i: 0) == N1.getOperand(i: 1) &&
12278 N2.getOperand(i: 1) == N1.getOperand(i: 0) &&
12279 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ABDU, VT)))
12280 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1));
12281
12282 // select usubo(x, y).overflow, (usubo x, y), (sub y, x) -> neg (abdu x, y)
12283 if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
12284 N1.getNode() == N0.getNode() && N1.getResNo() == 0 &&
12285 N2.getOpcode() == ISD::SUB && N2.getOperand(i: 0) == N1.getOperand(i: 1) &&
12286 N2.getOperand(i: 1) == N1.getOperand(i: 0) &&
12287 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ABDU, VT)))
12288 return DAG.getNegative(
12289 Val: DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1)),
12290 DL, VT);
12291 }
12292
12293 // Fold selects based on a setcc into other things, such as min/max/abs.
12294 if (N0.getOpcode() == ISD::SETCC) {
12295 SDValue Cond0 = N0.getOperand(i: 0), Cond1 = N0.getOperand(i: 1);
12296 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
12297
12298 // select (fcmp lt x, y), x, y -> fminnum x, y
12299 // select (fcmp gt x, y), x, y -> fmaxnum x, y
12300 //
12301 // This is OK if we don't care what happens if either operand is a NaN.
12302 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS: N1, RHS: N2, Flags, TLI))
12303 if (SDValue FMinMax =
12304 combineMinNumMaxNum(DL, VT, LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC))
12305 return FMinMax;
12306
12307 // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
12308 // This is conservatively limited to pre-legal-operations to give targets
12309 // a chance to reverse the transform if they want to do that. Also, it is
12310 // unlikely that the pattern would be formed late, so it's probably not
12311 // worth going through the other checks.
12312 if (!LegalOperations && TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT) &&
12313 CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
12314 N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(i: 0)) {
12315 auto *C = dyn_cast<ConstantSDNode>(Val: N2.getOperand(i: 1));
12316 auto *NotC = dyn_cast<ConstantSDNode>(Val&: Cond1);
12317 if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
12318 // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
12319 // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
12320 //
12321 // The IR equivalent of this transform would have this form:
12322 // %a = add %x, C
12323 // %c = icmp ugt %x, ~C
12324 // %r = select %c, -1, %a
12325 // =>
12326 // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
12327 // %u0 = extractvalue %u, 0
12328 // %u1 = extractvalue %u, 1
12329 // %r = select %u1, -1, %u0
12330 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT0);
12331 SDValue UAO = DAG.getNode(Opcode: ISD::UADDO, DL, VTList: VTs, N1: Cond0, N2: N2.getOperand(i: 1));
12332 return DAG.getSelect(DL, VT, Cond: UAO.getValue(R: 1), LHS: N1, RHS: UAO.getValue(R: 0));
12333 }
12334 }
12335
12336 if (TLI.isOperationLegal(Op: ISD::SELECT_CC, VT) ||
12337 (!LegalOperations &&
12338 TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))) {
12339 // Any flags available in a select/setcc fold will be on the setcc as they
12340 // migrated from fcmp
12341 Flags = N0->getFlags();
12342 SDValue SelectNode = DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT, N1: Cond0, N2: Cond1, N3: N1,
12343 N4: N2, N5: N0.getOperand(i: 2));
12344 SelectNode->setFlags(Flags);
12345 return SelectNode;
12346 }
12347
12348 if (SDValue ABD = foldSelectToABD(LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC, DL))
12349 return ABD;
12350
12351 if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
12352 return NewSel;
12353
12354 // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
12355 // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
12356 APInt C;
12357 if (sd_match(N: Cond1, P: m_ConstInt(V&: C)) && hasUMin(VT)) {
12358 if (CC == ISD::SETUGT && Cond0 == N2 &&
12359 sd_match(N: N1, P: m_Add(L: m_Specific(N: N2), R: m_SpecificInt(V: ~C)))) {
12360 // The resulting code relies on an unsigned wrap in ADD.
12361       // Recreate the ADD to drop possible nuw/nsw flags.
12362 SDValue AddC = DAG.getConstant(Val: ~C, DL, VT);
12363 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N2, N2: AddC);
12364 return DAG.getNode(Opcode: ISD::UMIN, DL, VT, N1: Add, N2);
12365 }
12366 if (CC == ISD::SETULT && Cond0 == N1 &&
12367 sd_match(N: N2, P: m_Add(L: m_Specific(N: N1), R: m_SpecificInt(V: -C)))) {
12368 // Ditto.
12369 SDValue AddC = DAG.getConstant(Val: -C, DL, VT);
12370 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: AddC);
12371 return DAG.getNode(Opcode: ISD::UMIN, DL, VT, N1, N2: Add);
12372 }
12373 }
12374 }
12375
12376 if (!VT.isVector())
12377 if (SDValue BinOp = foldSelectOfBinops(N))
12378 return BinOp;
12379
12380 if (SDValue R = combineSelectAsExtAnd(Cond: N0, T: N1, F: N2, DL, DAG))
12381 return R;
12382
12383 return SDValue();
12384}
12385
12386 // This function assumes all the vselect's arguments are CONCAT_VECTORS
12387// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
12388static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
12389 SDLoc DL(N);
12390 SDValue Cond = N->getOperand(Num: 0);
12391 SDValue LHS = N->getOperand(Num: 1);
12392 SDValue RHS = N->getOperand(Num: 2);
12393 EVT VT = N->getValueType(ResNo: 0);
12394 int NumElems = VT.getVectorNumElements();
12395 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
12396 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
12397 Cond.getOpcode() == ISD::BUILD_VECTOR);
12398
12399   // CONCAT_VECTORS can take an arbitrary number of arguments. We only care about
12400 // binary ones here.
12401 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
12402 return SDValue();
12403
12404 // We're sure we have an even number of elements due to the
12405 // concat_vectors we have as arguments to vselect.
12406   // Skip BV elements until we find one that's not an UNDEF.
12407   // After we find the first non-UNDEF element, keep looping until we get to
12408   // half the length of the BV and check that all the non-undef elements match.
12409 ConstantSDNode *BottomHalf = nullptr;
12410 for (int i = 0; i < NumElems / 2; ++i) {
12411 if (Cond->getOperand(Num: i)->isUndef())
12412 continue;
12413
12414 if (BottomHalf == nullptr)
12415 BottomHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
12416 else if (Cond->getOperand(Num: i).getNode() != BottomHalf)
12417 return SDValue();
12418 }
12419
12420 // Do the same for the second half of the BuildVector
12421 ConstantSDNode *TopHalf = nullptr;
12422 for (int i = NumElems / 2; i < NumElems; ++i) {
12423 if (Cond->getOperand(Num: i)->isUndef())
12424 continue;
12425
12426 if (TopHalf == nullptr)
12427 TopHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
12428 else if (Cond->getOperand(Num: i).getNode() != TopHalf)
12429 return SDValue();
12430 }
12431
12432 assert(TopHalf && BottomHalf &&
12433 "One half of the selector was all UNDEFs and the other was all the "
12434 "same value. This should have been addressed before this function.");
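  // A zero (false) half of the condition picks the corresponding operand of the
  // false value (RHS); a non-zero half picks it from the true value (LHS).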
12435 return DAG.getNode(
12436 Opcode: ISD::CONCAT_VECTORS, DL, VT,
12437 N1: BottomHalf->isZero() ? RHS->getOperand(Num: 0) : LHS->getOperand(Num: 0),
12438 N2: TopHalf->isZero() ? RHS->getOperand(Num: 1) : LHS->getOperand(Num: 1));
12439}
12440
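/// If the gather/scatter index contains a uniform (splat) component, add that
/// component to the scalar BasePtr and strip it from the Index. Returns true
/// and updates BasePtr/Index in place when the transformation applies.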
12441bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
12442 SelectionDAG &DAG, const SDLoc &DL) {
12443
12444 // Only perform the transformation when existing operands can be reused.
12445 if (IndexIsScaled)
12446 return false;
12447
12448 if (!isNullConstant(V: BasePtr) && !Index.hasOneUse())
12449 return false;
12450
12451 EVT VT = BasePtr.getValueType();
12452
12453 if (SDValue SplatVal = DAG.getSplatValue(V: Index);
12454 SplatVal && !isNullConstant(V: SplatVal) &&
12455 SplatVal.getValueType() == VT) {
12456 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
12457 Index = DAG.getSplat(VT: Index.getValueType(), DL, Op: DAG.getConstant(Val: 0, DL, VT));
12458 return true;
12459 }
12460
12461 if (Index.getOpcode() != ISD::ADD)
12462 return false;
12463
12464 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 0));
12465 SplatVal && SplatVal.getValueType() == VT) {
12466 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
12467 Index = Index.getOperand(i: 1);
12468 return true;
12469 }
12470 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 1));
12471 SplatVal && SplatVal.getValueType() == VT) {
12472 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
12473 Index = Index.getOperand(i: 0);
12474 return true;
12475 }
12476 return false;
12477}
12478
12479// Fold sext/zext of index into index type.
12480bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
12481 SelectionDAG &DAG) {
12482 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12483
12484 // It's always safe to look through zero extends.
12485 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
12486 if (TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
12487 IndexType = ISD::UNSIGNED_SCALED;
12488 Index = Index.getOperand(i: 0);
12489 return true;
12490 }
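    // Even if the extend is kept, a zero-extended index is known non-negative,
    // so a signed index type can safely be treated as unsigned.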
12491 if (ISD::isIndexTypeSigned(IndexType)) {
12492 IndexType = ISD::UNSIGNED_SCALED;
12493 return true;
12494 }
12495 }
12496
12497 // It's only safe to look through sign extends when Index is signed.
12498 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
12499 ISD::isIndexTypeSigned(IndexType) &&
12500 TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
12501 Index = Index.getOperand(i: 0);
12502 return true;
12503 }
12504
12505 return false;
12506}
12507
12508SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
12509 VPScatterSDNode *MSC = cast<VPScatterSDNode>(Val: N);
12510 SDValue Mask = MSC->getMask();
12511 SDValue Chain = MSC->getChain();
12512 SDValue Index = MSC->getIndex();
12513 SDValue Scale = MSC->getScale();
12514 SDValue StoreVal = MSC->getValue();
12515 SDValue BasePtr = MSC->getBasePtr();
12516 SDValue VL = MSC->getVectorLength();
12517 ISD::MemIndexType IndexType = MSC->getIndexType();
12518 SDLoc DL(N);
12519
12520 // Zap scatters with a zero mask.
12521 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12522 return Chain;
12523
12524 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
12525 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12526 return DAG.getScatterVP(VTs: DAG.getVTList(VT: MVT::Other), VT: MSC->getMemoryVT(),
12527 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType);
12528 }
12529
12530 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
12531 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12532 return DAG.getScatterVP(VTs: DAG.getVTList(VT: MVT::Other), VT: MSC->getMemoryVT(),
12533 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType);
12534 }
12535
12536 return SDValue();
12537}
12538
12539SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
12540 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Val: N);
12541 SDValue Mask = MSC->getMask();
12542 SDValue Chain = MSC->getChain();
12543 SDValue Index = MSC->getIndex();
12544 SDValue Scale = MSC->getScale();
12545 SDValue StoreVal = MSC->getValue();
12546 SDValue BasePtr = MSC->getBasePtr();
12547 ISD::MemIndexType IndexType = MSC->getIndexType();
12548 SDLoc DL(N);
12549
12550 // Zap scatters with a zero mask.
12551 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12552 return Chain;
12553
12554 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
12555 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
12556 return DAG.getMaskedScatter(VTs: DAG.getVTList(VT: MVT::Other), MemVT: MSC->getMemoryVT(),
12557 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType,
12558 IsTruncating: MSC->isTruncatingStore());
12559 }
12560
12561 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
12562 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
12563 return DAG.getMaskedScatter(VTs: DAG.getVTList(VT: MVT::Other), MemVT: MSC->getMemoryVT(),
12564 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType,
12565 IsTruncating: MSC->isTruncatingStore());
12566 }
12567
12568 return SDValue();
12569}
12570
12571SDValue DAGCombiner::visitMSTORE(SDNode *N) {
12572 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(Val: N);
12573 SDValue Mask = MST->getMask();
12574 SDValue Chain = MST->getChain();
12575 SDValue Value = MST->getValue();
12576 SDValue Ptr = MST->getBasePtr();
12577
12578 // Zap masked stores with a zero mask.
12579 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12580 return Chain;
12581
12582 // Remove a masked store if base pointers and masks are equal.
12583 if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Val&: Chain)) {
12584 if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
12585 MST1->isSimple() && MST1->getBasePtr() == Ptr &&
12586 !MST->getBasePtr().isUndef() &&
12587 ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
12588 MST1->getMemoryVT().getStoreSize()) ||
12589 ISD::isConstantSplatVectorAllOnes(N: Mask.getNode())) &&
12590 TypeSize::isKnownLE(LHS: MST1->getMemoryVT().getStoreSize(),
12591 RHS: MST->getMemoryVT().getStoreSize())) {
12592 CombineTo(N: MST1, Res: MST1->getChain());
12593 if (N->getOpcode() != ISD::DELETED_NODE)
12594 AddToWorklist(N);
12595 return SDValue(N, 0);
12596 }
12597 }
12598
12599   // If this is a masked store with an all-ones mask, we can use an unmasked store.
12600 // FIXME: Can we do this for indexed, compressing, or truncating stores?
12601 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MST->isUnindexed() &&
12602 !MST->isCompressingStore() && !MST->isTruncatingStore())
12603 return DAG.getStore(Chain: MST->getChain(), dl: SDLoc(N), Val: MST->getValue(),
12604 Ptr: MST->getBasePtr(), PtrInfo: MST->getPointerInfo(),
12605 Alignment: MST->getBaseAlign(), MMOFlags: MST->getMemOperand()->getFlags(),
12606 AAInfo: MST->getAAInfo());
12607
12608 // Try transforming N to an indexed store.
12609 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12610 return SDValue(N, 0);
12611
12612 if (MST->isTruncatingStore() && MST->isUnindexed() &&
12613 Value.getValueType().isInteger() &&
12614 (!isa<ConstantSDNode>(Val: Value) ||
12615 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
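    // A truncating store only demands the low bits of Value that actually make
    // it to memory.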
12616 APInt TruncDemandedBits =
12617 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
12618 loBitsSet: MST->getMemoryVT().getScalarSizeInBits());
12619
12620 // See if we can simplify the operation with
12621 // SimplifyDemandedBits, which only works if the value has a single use.
12622 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
12623       // Re-visit the store if anything changed and the store hasn't been merged
12624       // with another node (N is deleted). SimplifyDemandedBits will add Value's
12625       // node back to the worklist if necessary, but we also need to re-visit
12626       // the Store node itself.
12627 if (N->getOpcode() != ISD::DELETED_NODE)
12628 AddToWorklist(N);
12629 return SDValue(N, 0);
12630 }
12631 }
12632
12633 // If this is a TRUNC followed by a masked store, fold this into a masked
12634 // truncating store. We can do this even if this is already a masked
12635 // truncstore.
12636   // TODO: Try combining to a masked compress store if possible.
12637 if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
12638 MST->isUnindexed() && !MST->isCompressingStore() &&
12639 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
12640 MemVT: MST->getMemoryVT(), LegalOnly: LegalOperations)) {
12641 auto Mask = TLI.promoteTargetBoolean(DAG, Bool: MST->getMask(),
12642 ValVT: Value.getOperand(i: 0).getValueType());
12643 return DAG.getMaskedStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Base: Ptr,
12644 Offset: MST->getOffset(), Mask, MemVT: MST->getMemoryVT(),
12645 MMO: MST->getMemOperand(), AM: MST->getAddressingMode(),
12646 /*IsTruncating=*/true);
12647 }
12648
12649 return SDValue();
12650}
12651
12652SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
12653 auto *SST = cast<VPStridedStoreSDNode>(Val: N);
12654 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
12655 // Combine strided stores with unit-stride to a regular VP store.
12656 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SST->getStride());
12657 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12658 return DAG.getStoreVP(Chain: SST->getChain(), dl: SDLoc(N), Val: SST->getValue(),
12659 Ptr: SST->getBasePtr(), Offset: SST->getOffset(), Mask: SST->getMask(),
12660 EVL: SST->getVectorLength(), MemVT: SST->getMemoryVT(),
12661 MMO: SST->getMemOperand(), AM: SST->getAddressingMode(),
12662 IsTruncating: SST->isTruncatingStore(), IsCompressing: SST->isCompressingStore());
12663 }
12664 return SDValue();
12665}
12666
12667SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
12668 SDLoc DL(N);
12669 SDValue Vec = N->getOperand(Num: 0);
12670 SDValue Mask = N->getOperand(Num: 1);
12671 SDValue Passthru = N->getOperand(Num: 2);
12672 EVT VecVT = Vec.getValueType();
12673
12674 bool HasPassthru = !Passthru.isUndef();
12675
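  // A constant splat mask makes the compress trivial: an all-true mask keeps
  // the whole source vector, any other splat yields the passthru.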
12676 APInt SplatVal;
12677 if (ISD::isConstantSplatVector(N: Mask.getNode(), SplatValue&: SplatVal))
12678 return TLI.isConstTrueVal(N: Mask) ? Vec : Passthru;
12679
12680 if (Vec.isUndef() || Mask.isUndef())
12681 return Passthru;
12682
12683 // No need for potentially expensive compress if the mask is constant.
12684 if (ISD::isBuildVectorOfConstantSDNodes(N: Mask.getNode())) {
12685 SmallVector<SDValue, 16> Ops;
12686 EVT ScalarVT = VecVT.getVectorElementType();
12687 unsigned NumSelected = 0;
12688 unsigned NumElmts = VecVT.getVectorNumElements();
12689 for (unsigned I = 0; I < NumElmts; ++I) {
12690 SDValue MaskI = Mask.getOperand(i: I);
12691 // We treat undef mask entries as "false".
12692 if (MaskI.isUndef())
12693 continue;
12694
12695 if (TLI.isConstTrueVal(N: MaskI)) {
12696 SDValue VecI = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: Vec,
12697 N2: DAG.getVectorIdxConstant(Val: I, DL));
12698 Ops.push_back(Elt: VecI);
12699 NumSelected++;
12700 }
12701 }
12702 for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
12703 SDValue Val =
12704 HasPassthru
12705 ? DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: Passthru,
12706 N2: DAG.getVectorIdxConstant(Val: Rest, DL))
12707 : DAG.getUNDEF(VT: ScalarVT);
12708 Ops.push_back(Elt: Val);
12709 }
12710 return DAG.getBuildVector(VT: VecVT, DL, Ops);
12711 }
12712
12713 return SDValue();
12714}
12715
12716SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
12717 VPGatherSDNode *MGT = cast<VPGatherSDNode>(Val: N);
12718 SDValue Mask = MGT->getMask();
12719 SDValue Chain = MGT->getChain();
12720 SDValue Index = MGT->getIndex();
12721 SDValue Scale = MGT->getScale();
12722 SDValue BasePtr = MGT->getBasePtr();
12723 SDValue VL = MGT->getVectorLength();
12724 ISD::MemIndexType IndexType = MGT->getIndexType();
12725 SDLoc DL(N);
12726
12727 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
12728 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12729 return DAG.getGatherVP(
12730 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), VT: MGT->getMemoryVT(), dl: DL,
12731 Ops, MMO: MGT->getMemOperand(), IndexType);
12732 }
12733
12734 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
12735 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12736 return DAG.getGatherVP(
12737 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), VT: MGT->getMemoryVT(), dl: DL,
12738 Ops, MMO: MGT->getMemOperand(), IndexType);
12739 }
12740
12741 return SDValue();
12742}
12743
12744SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12745 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Val: N);
12746 SDValue Mask = MGT->getMask();
12747 SDValue Chain = MGT->getChain();
12748 SDValue Index = MGT->getIndex();
12749 SDValue Scale = MGT->getScale();
12750 SDValue PassThru = MGT->getPassThru();
12751 SDValue BasePtr = MGT->getBasePtr();
12752 ISD::MemIndexType IndexType = MGT->getIndexType();
12753 SDLoc DL(N);
12754
12755 // Zap gathers with a zero mask.
12756 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12757 return CombineTo(N, Res0: PassThru, Res1: MGT->getChain());
12758
12759 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
12760 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12761 return DAG.getMaskedGather(
12762 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), MemVT: MGT->getMemoryVT(), dl: DL,
12763 Ops, MMO: MGT->getMemOperand(), IndexType, ExtTy: MGT->getExtensionType());
12764 }
12765
12766 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
12767 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12768 return DAG.getMaskedGather(
12769 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), MemVT: MGT->getMemoryVT(), dl: DL,
12770 Ops, MMO: MGT->getMemOperand(), IndexType, ExtTy: MGT->getExtensionType());
12771 }
12772
12773 return SDValue();
12774}
12775
12776SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12777 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(Val: N);
12778 SDValue Mask = MLD->getMask();
12779
12780 // Zap masked loads with a zero mask.
12781 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12782 return CombineTo(N, Res0: MLD->getPassThru(), Res1: MLD->getChain());
12783
12784   // If this is a masked load with an all-ones mask, we can use an unmasked load.
12785 // FIXME: Can we do this for indexed, expanding, or extending loads?
12786 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MLD->isUnindexed() &&
12787 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12788 SDValue NewLd = DAG.getLoad(
12789 VT: N->getValueType(ResNo: 0), dl: SDLoc(N), Chain: MLD->getChain(), Ptr: MLD->getBasePtr(),
12790 PtrInfo: MLD->getPointerInfo(), Alignment: MLD->getBaseAlign(),
12791 MMOFlags: MLD->getMemOperand()->getFlags(), AAInfo: MLD->getAAInfo(), Ranges: MLD->getRanges());
12792 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
12793 }
12794
12795 // Try transforming N to an indexed load.
12796 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12797 return SDValue(N, 0);
12798
12799 return SDValue();
12800}
12801
12802SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
12803 MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(Val: N);
12804 SDValue Chain = HG->getChain();
12805 SDValue Inc = HG->getInc();
12806 SDValue Mask = HG->getMask();
12807 SDValue BasePtr = HG->getBasePtr();
12808 SDValue Index = HG->getIndex();
12809 SDLoc DL(HG);
12810
12811 EVT MemVT = HG->getMemoryVT();
12812 MachineMemOperand *MMO = HG->getMemOperand();
12813 ISD::MemIndexType IndexType = HG->getIndexType();
12814
12815 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12816 return Chain;
12817
12818 SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
12819 HG->getScale(), HG->getIntID()};
12820 if (refineUniformBase(BasePtr, Index, IndexIsScaled: HG->isIndexScaled(), DAG, DL))
12821 return DAG.getMaskedHistogram(VTs: DAG.getVTList(VT: MVT::Other), MemVT, dl: DL, Ops,
12822 MMO, IndexType);
12823
12824 EVT DataVT = Index.getValueType();
12825 if (refineIndexType(Index, IndexType, DataVT, DAG))
12826 return DAG.getMaskedHistogram(VTs: DAG.getVTList(VT: MVT::Other), MemVT, dl: DL, Ops,
12827 MMO, IndexType);
12828 return SDValue();
12829}
12830
12831SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12832 if (SDValue Res = foldPartialReduceMLAMulOp(N))
12833 return Res;
12834 if (SDValue Res = foldPartialReduceAdd(N))
12835 return Res;
12836 return SDValue();
12837}
12838
12839// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
12840// -> partial_reduce_*mla(acc, a, b)
12841//
12842// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12843// -> partial_reduce_*mla(acc, x, C)
12844SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12845 SDLoc DL(N);
12846 auto *Context = DAG.getContext();
12847 SDValue Acc = N->getOperand(Num: 0);
12848 SDValue Op1 = N->getOperand(Num: 1);
12849 SDValue Op2 = N->getOperand(Num: 2);
12850
12851 APInt C;
12852 if (Op1->getOpcode() != ISD::MUL ||
12853 !ISD::isConstantSplatVector(N: Op2.getNode(), SplatValue&: C) || !C.isOne())
12854 return SDValue();
12855
12856 SDValue LHS = Op1->getOperand(Num: 0);
12857 SDValue RHS = Op1->getOperand(Num: 1);
12858 unsigned LHSOpcode = LHS->getOpcode();
12859 if (!ISD::isExtOpcode(Opcode: LHSOpcode))
12860 return SDValue();
12861
12862 SDValue LHSExtOp = LHS->getOperand(Num: 0);
12863 EVT LHSExtOpVT = LHSExtOp.getValueType();
12864
12865 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12866 // -> partial_reduce_*mla(acc, x, C)
12867 if (ISD::isConstantSplatVector(N: RHS.getNode(), SplatValue&: C)) {
12868 // TODO: Make use of partial_reduce_sumla here
12869 APInt CTrunc = C.trunc(width: LHSExtOpVT.getScalarSizeInBits());
12870 unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
12871 if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(width: LHSBits) != C) &&
12872 (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(width: LHSBits) != C))
12873 return SDValue();
12874
12875 unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12876 ? ISD::PARTIAL_REDUCE_SMLA
12877 : ISD::PARTIAL_REDUCE_UMLA;
12878
12879 // Only perform these combines if the target supports folding
12880 // the extends into the operation.
12881 if (!TLI.isPartialReduceMLALegalOrCustom(
12882 Opc: NewOpcode, AccVT: TLI.getTypeToTransformTo(Context&: *Context, VT: N->getValueType(ResNo: 0)),
12883 InputVT: TLI.getTypeToTransformTo(Context&: *Context, VT: LHSExtOpVT)))
12884 return SDValue();
12885
12886 return DAG.getNode(Opcode: NewOpcode, DL, VT: N->getValueType(ResNo: 0), N1: Acc, N2: LHSExtOp,
12887 N3: DAG.getConstant(Val: CTrunc, DL, VT: LHSExtOpVT));
12888 }
12889
12890 unsigned RHSOpcode = RHS->getOpcode();
12891 if (!ISD::isExtOpcode(Opcode: RHSOpcode))
12892 return SDValue();
12893
12894 SDValue RHSExtOp = RHS->getOperand(Num: 0);
12895 if (LHSExtOpVT != RHSExtOp.getValueType())
12896 return SDValue();
12897
12898 unsigned NewOpc;
12899 if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12900 NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12901 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12902 NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12903 else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12904 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12905 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
12906 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12907 std::swap(a&: LHSExtOp, b&: RHSExtOp);
12908 } else
12909 return SDValue();
12910   // For a 2-stage extend, the signedness of both of the extends must match.
12911 // If the mul has the same type, there is no outer extend, and thus we
12912 // can simply use the inner extends to pick the result node.
12913 // TODO: extend to handle nonneg zext as sext
12914 EVT AccElemVT = Acc.getValueType().getVectorElementType();
12915 if (Op1.getValueType().getVectorElementType() != AccElemVT &&
12916 NewOpc != N->getOpcode())
12917 return SDValue();
12918
12919 // Only perform these combines if the target supports folding
12920 // the extends into the operation.
12921 if (!TLI.isPartialReduceMLALegalOrCustom(
12922 Opc: NewOpc, AccVT: TLI.getTypeToTransformTo(Context&: *Context, VT: N->getValueType(ResNo: 0)),
12923 InputVT: TLI.getTypeToTransformTo(Context&: *Context, VT: LHSExtOpVT)))
12924 return SDValue();
12925
12926 return DAG.getNode(Opcode: NewOpc, DL, VT: N->getValueType(ResNo: 0), N1: Acc, N2: LHSExtOp, N3: RHSExtOp);
12927}
12928
12929// partial.reduce.umla(acc, zext(op), splat(1))
12930// -> partial.reduce.umla(acc, op, splat(trunc(1)))
12931// partial.reduce.smla(acc, sext(op), splat(1))
12932// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12933// partial.reduce.sumla(acc, sext(op), splat(1))
12934// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12935SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
12936 SDLoc DL(N);
12937 SDValue Acc = N->getOperand(Num: 0);
12938 SDValue Op1 = N->getOperand(Num: 1);
12939 SDValue Op2 = N->getOperand(Num: 2);
12940
12941 APInt ConstantOne;
12942 if (!ISD::isConstantSplatVector(N: Op2.getNode(), SplatValue&: ConstantOne) ||
12943 !ConstantOne.isOne())
12944 return SDValue();
12945
12946 unsigned Op1Opcode = Op1.getOpcode();
12947 if (!ISD::isExtOpcode(Opcode: Op1Opcode))
12948 return SDValue();
12949
12950 bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12951 bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
12952 EVT AccElemVT = Acc.getValueType().getVectorElementType();
12953 if (Op1IsSigned != NodeIsSigned &&
12954 Op1.getValueType().getVectorElementType() != AccElemVT)
12955 return SDValue();
12956
12957 unsigned NewOpcode =
12958 Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12959
12960 SDValue UnextOp1 = Op1.getOperand(i: 0);
12961 EVT UnextOp1VT = UnextOp1.getValueType();
12962 auto *Context = DAG.getContext();
12963 if (!TLI.isPartialReduceMLALegalOrCustom(
12964 Opc: NewOpcode, AccVT: TLI.getTypeToTransformTo(Context&: *Context, VT: N->getValueType(ResNo: 0)),
12965 InputVT: TLI.getTypeToTransformTo(Context&: *Context, VT: UnextOp1VT)))
12966 return SDValue();
12967
12968 return DAG.getNode(Opcode: NewOpcode, DL, VT: N->getValueType(ResNo: 0), N1: Acc, N2: UnextOp1,
12969 N3: DAG.getConstant(Val: 1, DL, VT: UnextOp1VT));
12970}
12971
12972SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12973 auto *SLD = cast<VPStridedLoadSDNode>(Val: N);
12974 EVT EltVT = SLD->getValueType(ResNo: 0).getVectorElementType();
12975 // Combine strided loads with unit-stride to a regular VP load.
12976 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SLD->getStride());
12977 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12978 SDValue NewLd = DAG.getLoadVP(
12979 AM: SLD->getAddressingMode(), ExtType: SLD->getExtensionType(), VT: SLD->getValueType(ResNo: 0),
12980 dl: SDLoc(N), Chain: SLD->getChain(), Ptr: SLD->getBasePtr(), Offset: SLD->getOffset(),
12981 Mask: SLD->getMask(), EVL: SLD->getVectorLength(), MemVT: SLD->getMemoryVT(),
12982 MMO: SLD->getMemOperand(), IsExpanding: SLD->isExpandingLoad());
12983 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
12984 }
12985 return SDValue();
12986}
12987
12988/// A vector select of 2 constant vectors can be simplified to math/logic to
12989/// avoid a variable select instruction and possibly avoid constant loads.
12990SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12991 SDValue Cond = N->getOperand(Num: 0);
12992 SDValue N1 = N->getOperand(Num: 1);
12993 SDValue N2 = N->getOperand(Num: 2);
12994 EVT VT = N->getValueType(ResNo: 0);
12995 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
12996 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
12997 !ISD::isBuildVectorOfConstantSDNodes(N: N1.getNode()) ||
12998 !ISD::isBuildVectorOfConstantSDNodes(N: N2.getNode()))
12999 return SDValue();
13000
13001 // Check if we can use the condition value to increment/decrement a single
13002 // constant value. This simplifies a select to an add and removes a constant
13003 // load/materialization from the general case.
13004 bool AllAddOne = true;
13005 bool AllSubOne = true;
13006 unsigned Elts = VT.getVectorNumElements();
13007 for (unsigned i = 0; i != Elts; ++i) {
13008 SDValue N1Elt = N1.getOperand(i);
13009 SDValue N2Elt = N2.getOperand(i);
13010 if (N1Elt.isUndef())
13011 continue;
13012 // N2 should not contain undef values since it will be reused in the fold.
13013 if (N2Elt.isUndef() || N1Elt.getValueType() != N2Elt.getValueType()) {
13014 AllAddOne = false;
13015 AllSubOne = false;
13016 break;
13017 }
13018
13019 const APInt &C1 = N1Elt->getAsAPIntVal();
13020 const APInt &C2 = N2Elt->getAsAPIntVal();
13021 if (C1 != C2 + 1)
13022 AllAddOne = false;
13023 if (C1 != C2 - 1)
13024 AllSubOne = false;
13025 }
13026
13027 // Further simplifications for the extra-special cases where the constants are
13028 // all 0 or all -1 should be implemented as folds of these patterns.
13029 SDLoc DL(N);
13030 if (AllAddOne || AllSubOne) {
13031 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
13032 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
13033 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
13034 SDValue ExtendedCond = DAG.getNode(Opcode: ExtendOpcode, DL, VT, Operand: Cond);
13035 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ExtendedCond, N2);
13036 }
13037
13038 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
13039 APInt Pow2C;
13040 if (ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: Pow2C) && Pow2C.isPowerOf2() &&
13041 isNullOrNullSplat(V: N2)) {
13042 SDValue ZextCond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
13043 SDValue ShAmtC = DAG.getConstant(Val: Pow2C.exactLogBase2(), DL, VT);
13044 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ZextCond, N2: ShAmtC);
13045 }
13046
13047 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
13048 return V;
13049
13050 // The general case for select-of-constants:
13051 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
13052 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
13053 // leave that to a machine-specific pass.
13054 return SDValue();
13055}
13056
13057SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
13058 SDValue N0 = N->getOperand(Num: 0);
13059 SDValue N1 = N->getOperand(Num: 1);
13060 SDValue N2 = N->getOperand(Num: 2);
13061 SDLoc DL(N);
13062
13063 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
13064 return V;
13065
13066 if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DL, DAG))
13067 return V;
13068
13069 return SDValue();
13070}
13071
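/// Try to fold a vselect whose true and/or false operand is an all-ones or
/// all-zeros vector into bitwise logic (or a plain bitcast) on the condition,
/// provided the condition is a sign-splat of the same element width.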
13072static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
13073 SDValue FVal,
13074 const TargetLowering &TLI,
13075 SelectionDAG &DAG,
13076 const SDLoc &DL) {
13077 EVT VT = TVal.getValueType();
13078 if (!TLI.isTypeLegal(VT))
13079 return SDValue();
13080
13081 EVT CondVT = Cond.getValueType();
13082 assert(CondVT.isVector() && "Vector select expects a vector selector!");
13083
13084 bool IsTAllZero = ISD::isBuildVectorAllZeros(N: TVal.getNode());
13085 bool IsTAllOne = ISD::isBuildVectorAllOnes(N: TVal.getNode());
13086 bool IsFAllZero = ISD::isBuildVectorAllZeros(N: FVal.getNode());
13087 bool IsFAllOne = ISD::isBuildVectorAllOnes(N: FVal.getNode());
13088
13089   // If neither arm is an all-zeros or all-ones vector, there is nothing to fold.
13090 if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
13091 return SDValue();
13092
13093 // select Cond, 0, 0 → 0
13094 if (IsTAllZero && IsFAllZero) {
13095 return VT.isFloatingPoint() ? DAG.getConstantFP(Val: 0.0, DL, VT)
13096 : DAG.getConstant(Val: 0, DL, VT);
13097 }
13098
13099   // Bail out on select (setgt lhs, -1), 1, -1 so it can instead be lowered as
  // or (sra lhs, bitwidth - 1), 1.
13100 APInt TValAPInt;
13101 if (Cond.getOpcode() == ISD::SETCC &&
13102 Cond.getOperand(i: 2) == DAG.getCondCode(Cond: ISD::SETGT) &&
13103 Cond.getOperand(i: 0).getValueType() == VT && VT.isSimple() &&
13104 ISD::isConstantSplatVector(N: TVal.getNode(), SplatValue&: TValAPInt) &&
13105 TValAPInt.isOne() &&
13106 ISD::isConstantSplatVectorAllOnes(N: Cond.getOperand(i: 1).getNode()) &&
13107 ISD::isConstantSplatVectorAllOnes(N: FVal.getNode())) {
13108 return SDValue();
13109 }
13110
13111 // To use the condition operand as a bitwise mask, it must have elements that
13112 // are the same size as the select elements. i.e, the condition operand must
13113 // have already been promoted from the IR select condition type <N x i1>.
13114 // Don't check if the types themselves are equal because that excludes
13115 // vector floating-point selects.
13116 if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
13117 return SDValue();
13118
13119 // Cond value must be 'sign splat' to be converted to a logical op.
13120 if (DAG.ComputeNumSignBits(Op: Cond) != CondVT.getScalarSizeInBits())
13121 return SDValue();
13122
13123 // Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form
13124 if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
13125 Cond.getOpcode() == ISD::SETCC &&
13126 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT) ==
13127 CondVT) {
13128 if (IsTAllZero || IsFAllOne) {
13129 SDValue CC = Cond.getOperand(i: 2);
13130 ISD::CondCode InverseCC = ISD::getSetCCInverse(
13131 Operation: cast<CondCodeSDNode>(Val&: CC)->get(), Type: Cond.getOperand(i: 0).getValueType());
13132 Cond = DAG.getSetCC(DL, VT: CondVT, LHS: Cond.getOperand(i: 0), RHS: Cond.getOperand(i: 1),
13133 Cond: InverseCC);
13134 std::swap(a&: TVal, b&: FVal);
13135 std::swap(a&: IsTAllOne, b&: IsFAllOne);
13136 std::swap(a&: IsTAllZero, b&: IsFAllZero);
13137 }
13138 }
13139
13140 assert(DAG.ComputeNumSignBits(Cond) == CondVT.getScalarSizeInBits() &&
13141 "Select condition no longer all-sign bits");
13142
13143 // select Cond, -1, 0 → bitcast Cond
13144 if (IsTAllOne && IsFAllZero)
13145 return DAG.getBitcast(VT, V: Cond);
13146
13147 // select Cond, -1, x → or Cond, x
13148 if (IsTAllOne) {
13149 SDValue X = DAG.getBitcast(VT: CondVT, V: FVal);
13150 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: CondVT, N1: Cond, N2: X);
13151 return DAG.getBitcast(VT, V: Or);
13152 }
13153
13154 // select Cond, x, 0 → and Cond, x
13155 if (IsFAllZero) {
13156 SDValue X = DAG.getBitcast(VT: CondVT, V: TVal);
13157 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: CondVT, N1: Cond, N2: X);
13158 return DAG.getBitcast(VT, V: And);
13159 }
13160
13161 return SDValue();
13162}
13163
13164SDValue DAGCombiner::visitVSELECT(SDNode *N) {
13165 SDValue N0 = N->getOperand(Num: 0);
13166 SDValue N1 = N->getOperand(Num: 1);
13167 SDValue N2 = N->getOperand(Num: 2);
13168 EVT VT = N->getValueType(ResNo: 0);
13169 SDLoc DL(N);
13170
13171 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
13172 return V;
13173
13174 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
13175 return V;
13176
13177 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
13178 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false))
13179 return DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
13180
13181   // select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
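  // A lane of the sign-extended mask is either all-ones (keeping C in the AND,
  // so X+C is selected) or all-zeros (dropping C, so X is selected).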
13182 if (N1.getOpcode() == ISD::ADD && N1.getOperand(i: 0) == N2 && N1->hasOneUse() &&
13183 DAG.isConstantIntBuildVectorOrConstantInt(N: N1.getOperand(i: 1)) &&
13184 N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
13185 TLI.getBooleanContents(Type: N0.getValueType()) ==
13186 TargetLowering::ZeroOrNegativeOneBooleanContent) {
13187 return DAG.getNode(
13188 Opcode: ISD::ADD, DL, VT: N1.getValueType(), N1: N2,
13189 N2: DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N1.getOperand(i: 1), N2: N0));
13190 }
13191
13192 // Canonicalize integer abs.
13193 // vselect (setg[te] X, 0), X, -X ->
13194 // vselect (setgt X, -1), X, -X ->
13195 // vselect (setl[te] X, 0), -X, X ->
13196 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
13197 if (N0.getOpcode() == ISD::SETCC) {
13198 SDValue LHS = N0.getOperand(i: 0), RHS = N0.getOperand(i: 1);
13199 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
13200 bool isAbs = false;
13201 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(N: RHS.getNode());
13202
13203 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
13204 (ISD::isBuildVectorAllOnes(N: RHS.getNode()) && CC == ISD::SETGT)) &&
13205 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(i: 1))
13206 isAbs = ISD::isBuildVectorAllZeros(N: N2.getOperand(i: 0).getNode());
13207 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
13208 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(i: 1))
13209 isAbs = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
13210
13211 if (isAbs) {
13212 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
13213 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: LHS);
13214
13215 SDValue Shift = DAG.getNode(
13216 Opcode: ISD::SRA, DL, VT, N1: LHS,
13217 N2: DAG.getShiftAmountConstant(Val: VT.getScalarSizeInBits() - 1, VT, DL));
13218 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LHS, N2: Shift);
13219 AddToWorklist(N: Shift.getNode());
13220 AddToWorklist(N: Add.getNode());
13221 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Add, N2: Shift);
13222 }
13223
13224 // vselect x, y (fcmp lt x, y) -> fminnum x, y
13225 // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
13226 //
13227 // This is OK if we don't care about what happens if either operand is a
13228 // NaN.
13229 //
13230 if (N0.hasOneUse() &&
13231 isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, Flags: N->getFlags(), TLI)) {
13232 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, True: N1, False: N2, CC))
13233 return FMinMax;
13234 }
13235
13236 if (SDValue S = PerformMinMaxFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
13237 return S;
13238 if (SDValue S = PerformUMinFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
13239 return S;
13240
13241 // If this select has a condition (setcc) with narrower operands than the
13242 // select, try to widen the compare to match the select width.
13243 // TODO: This should be extended to handle any constant.
13244 // TODO: This could be extended to handle non-loading patterns, but that
13245 // requires thorough testing to avoid regressions.
13246 if (isNullOrNullSplat(V: RHS)) {
13247 EVT NarrowVT = LHS.getValueType();
13248 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
13249 EVT SetCCVT = getSetCCResultType(VT: LHS.getValueType());
13250 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
13251 unsigned WideWidth = WideVT.getScalarSizeInBits();
13252 bool IsSigned = isSignedIntSetCC(Code: CC);
13253 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13254 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
13255 SetCCWidth != 1 && SetCCWidth < WideWidth &&
13256 TLI.isLoadExtLegalOrCustom(ExtType: LoadExtOpcode, ValVT: WideVT, MemVT: NarrowVT) &&
13257 TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: WideVT)) {
13258 // Both compare operands can be widened for free. The LHS can use an
13259 // extended load, and the RHS is a constant:
13260 // vselect (ext (setcc load(X), C)), N1, N2 -->
13261 // vselect (setcc extload(X), C'), N1, N2
13262 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13263 SDValue WideLHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: LHS);
13264 SDValue WideRHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: RHS);
13265 EVT WideSetCCVT = getSetCCResultType(VT: WideVT);
13266 SDValue WideSetCC = DAG.getSetCC(DL, VT: WideSetCCVT, LHS: WideLHS, RHS: WideRHS, Cond: CC);
13267 return DAG.getSelect(DL, VT: N1.getValueType(), Cond: WideSetCC, LHS: N1, RHS: N2);
13268 }
13269 }
13270
13271 if (SDValue ABD = foldSelectToABD(LHS, RHS, True: N1, False: N2, CC, DL))
13272 return ABD;
13273
13274 // Match VSELECTs into add with unsigned saturation.
13275 if (hasOperation(Opcode: ISD::UADDSAT, VT)) {
13276       // Check if one of the arms of the VSELECT is a vector with all bits set.
13277       // If it's on the left side, invert the predicate to simplify logic below.
13278 SDValue Other;
13279 ISD::CondCode SatCC = CC;
13280 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode())) {
13281 Other = N2;
13282 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
13283 } else if (ISD::isConstantSplatVectorAllOnes(N: N2.getNode())) {
13284 Other = N1;
13285 }
13286
13287 if (Other && Other.getOpcode() == ISD::ADD) {
13288 SDValue CondLHS = LHS, CondRHS = RHS;
13289 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
13290
13291 // Canonicalize condition operands.
13292 if (SatCC == ISD::SETUGE) {
13293 std::swap(a&: CondLHS, b&: CondRHS);
13294 SatCC = ISD::SETULE;
13295 }
13296
13297 // We can test against either of the addition operands.
13298 // x <= x+y ? x+y : ~0 --> uaddsat x, y
13299 // x+y >= x ? x+y : ~0 --> uaddsat x, y
13300 if (SatCC == ISD::SETULE && Other == CondRHS &&
13301 (OpLHS == CondLHS || OpRHS == CondLHS))
13302 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13303
13304 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
13305 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13306 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
13307 CondLHS == OpLHS) {
13308 // If the RHS is a constant we have to reverse the const
13309 // canonicalization.
13310 // x >= ~C ? x+C : ~0 --> uaddsat x, C
13311 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13312 return Cond->getAPIntValue() == ~Op->getAPIntValue();
13313 };
13314 if (SatCC == ISD::SETULE &&
13315 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUADDSAT))
13316 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13317 }
13318 }
13319 }
13320
13321 // Match VSELECTs into sub with unsigned saturation.
13322 if (hasOperation(Opcode: ISD::USUBSAT, VT)) {
13323 // Check if one of the arms of the VSELECT is a zero vector. If it's on
13324 // the left side invert the predicate to simplify logic below.
13325 SDValue Other;
13326 ISD::CondCode SatCC = CC;
13327 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
13328 Other = N2;
13329 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
13330 } else if (ISD::isConstantSplatVectorAllZeros(N: N2.getNode())) {
13331 Other = N1;
13332 }
13333
13334 // zext(x) >= y ? trunc(zext(x) - y) : 0
13335 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13336 // zext(x) > y ? trunc(zext(x) - y) : 0
13337 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13338 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
13339 Other.getOperand(i: 0).getOpcode() == ISD::SUB &&
13340 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
13341 SDValue OpLHS = Other.getOperand(i: 0).getOperand(i: 0);
13342 SDValue OpRHS = Other.getOperand(i: 0).getOperand(i: 1);
13343 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
13344 if (SDValue R = getTruncatedUSUBSAT(DstVT: VT, SrcVT: LHS.getValueType(), LHS, RHS,
13345 DAG, DL))
13346 return R;
13347 }
13348
13349 if (Other && Other.getNumOperands() == 2) {
13350 SDValue CondRHS = RHS;
13351 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
13352
13353 if (OpLHS == LHS) {
13354 // Look for a general sub with unsigned saturation first.
13355 // x >= y ? x-y : 0 --> usubsat x, y
13356 // x > y ? x-y : 0 --> usubsat x, y
13357 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
13358 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
13359 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13360
13361 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13362 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13363 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
13364 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13365 // If the RHS is a constant we have to reverse the const
13366 // canonicalization.
13367 // x > C-1 ? x+-C : 0 --> usubsat x, C
13368 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13369 return (!Op && !Cond) ||
13370 (Op && Cond &&
13371 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
13372 };
13373 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
13374 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUSUBSAT,
13375 /*AllowUndefs*/ true)) {
13376 OpRHS = DAG.getNegative(Val: OpRHS, DL, VT);
13377 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13378 }
13379
13380 // Another special case: If C was a sign bit, the sub has been
13381 // canonicalized into a xor.
13382 // FIXME: Would it be better to use computeKnownBits to
13383 // determine whether it's safe to decanonicalize the xor?
13384 // x s< 0 ? x^C : 0 --> usubsat x, C
13385 APInt SplatValue;
13386 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
13387 ISD::isConstantSplatVector(N: OpRHS.getNode(), SplatValue) &&
13388 ISD::isConstantSplatVectorAllZeros(N: CondRHS.getNode()) &&
13389 SplatValue.isSignMask()) {
13390 // Note that we have to rebuild the RHS constant here to
13391 // ensure we don't rely on particular values of undef lanes.
13392 OpRHS = DAG.getConstant(Val: SplatValue, DL, VT);
13393 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13394 }
13395 }
13396 }
13397 }
13398 }
13399 }
13400 }
13401
13402 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
13403 return SDValue(N, 0); // Don't revisit N.
13404
13405 // Fold (vselect all_ones, N1, N2) -> N1
13406 if (ISD::isConstantSplatVectorAllOnes(N: N0.getNode()))
13407 return N1;
13408 // Fold (vselect all_zeros, N1, N2) -> N2
13409 if (ISD::isConstantSplatVectorAllZeros(N: N0.getNode()))
13410 return N2;
13411
13412   // The ConvertSelectToConcatVector function assumes both of the above
13413   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
13414   // and addressed.
13415 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
13416 N2.getOpcode() == ISD::CONCAT_VECTORS &&
13417 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
13418 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
13419 return CV;
13420 }
13421
13422 if (SDValue V = foldVSelectOfConstants(N))
13423 return V;
13424
13425 if (hasOperation(Opcode: ISD::SRA, VT))
13426 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
13427 return V;
13428
13429 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
13430 return SDValue(N, 0);
13431
13432 if (SDValue V = combineVSelectWithAllOnesOrZeros(Cond: N0, TVal: N1, FVal: N2, TLI, DAG, DL))
13433 return V;
13434
13435 return SDValue();
13436}
13437
13438SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
13439 SDValue N0 = N->getOperand(Num: 0);
13440 SDValue N1 = N->getOperand(Num: 1);
13441 SDValue N2 = N->getOperand(Num: 2);
13442 SDValue N3 = N->getOperand(Num: 3);
13443 SDValue N4 = N->getOperand(Num: 4);
13444 ISD::CondCode CC = cast<CondCodeSDNode>(Val&: N4)->get();
13445 SDLoc DL(N);
13446
13447 // fold select_cc lhs, rhs, x, x, cc -> x
13448 if (N2 == N3)
13449 return N2;
13450
13451 // select_cc bool, 0, x, y, seteq -> select bool, y, x
13452 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
13453 isNullConstant(V: N1))
13454 return DAG.getSelect(DL, VT: N2.getValueType(), Cond: N0, LHS: N3, RHS: N2);
13455
13456 // Determine if the condition we're dealing with is constant
13457 if (SDValue SCC = SimplifySetCC(VT: getSetCCResultType(VT: N0.getValueType()), N0, N1,
13458 Cond: CC, DL, foldBooleans: false)) {
13459 AddToWorklist(N: SCC.getNode());
13460
13461 // cond always true -> true val
13462 // cond always false -> false val
13463 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val: SCC.getNode()))
13464 return SCCC->isZero() ? N3 : N2;
13465
13466     // When the condition is UNDEF, just return the first operand. This is
13467     // consistent with DAG creation: no setcc node is created in this case.
13468 if (SCC->isUndef())
13469 return N2;
13470
13471 // Fold to a simpler select_cc
13472 if (SCC.getOpcode() == ISD::SETCC) {
13473 SDValue SelectOp =
13474 DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT: N2.getValueType(), N1: SCC.getOperand(i: 0),
13475 N2: SCC.getOperand(i: 1), N3: N2, N4: N3, N5: SCC.getOperand(i: 2));
13476 SelectOp->setFlags(SCC->getFlags());
13477 return SelectOp;
13478 }
13479 }
13480
13481 // If we can fold this based on the true/false value, do so.
13482 if (SimplifySelectOps(SELECT: N, LHS: N2, RHS: N3))
13483 return SDValue(N, 0); // Don't revisit N.
13484
13485 // fold select_cc into other things, such as min/max/abs
13486 return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
13487}
13488
13489SDValue DAGCombiner::visitSETCC(SDNode *N) {
13490   // setcc is very commonly used as an argument to brcond. This pattern
13491   // also lends itself to numerous combines and, as a result, it is desirable
13492   // to keep the argument to a brcond as a setcc as much as possible.
13493 bool PreferSetCC =
13494 N->hasOneUse() && N->user_begin()->getOpcode() == ISD::BRCOND;
13495
13496 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N->getOperand(Num: 2))->get();
13497 EVT VT = N->getValueType(ResNo: 0);
13498 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
13499 SDLoc DL(N);
13500
13501 if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, foldBooleans: !PreferSetCC)) {
13502 // If we prefer to have a setcc, and we don't, we'll try our best to
13503 // recreate one using rebuildSetCC.
13504 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
13505 SDValue NewSetCC = rebuildSetCC(N: Combined);
13506
13507 // We don't have anything interesting to combine to.
13508 if (NewSetCC.getNode() == N)
13509 return SDValue();
13510
13511 if (NewSetCC)
13512 return NewSetCC;
13513 }
13514 return Combined;
13515 }
13516
13517 // Optimize
13518 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
13519 // or
13520 // 2) (icmp eq/ne X, (rotate X, C1))
13521   // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
13522   // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
13523   // then:
13524   // If C1 is a power of 2, the rotate and shift+and versions are
13525   // equivalent, so we can interchange them depending on target preference.
13526   // Otherwise, if we have the shift+and version, we can interchange srl/shl,
13527   // which in turn affects the constant C0. We can use this to get better
13528   // constants again determined by target preference.
13529 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
13530 auto IsAndWithShift = [](SDValue A, SDValue B) {
13531 return A.getOpcode() == ISD::AND &&
13532 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
13533 A.getOperand(i: 0) == B.getOperand(i: 0);
13534 };
13535 auto IsRotateWithOp = [](SDValue A, SDValue B) {
13536 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
13537 B.getOperand(i: 0) == A;
13538 };
13539 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
13540 bool IsRotate = false;
13541
13542 // Find either shift+and or rotate pattern.
13543 if (IsAndWithShift(N0, N1)) {
13544 AndOrOp = N0;
13545 ShiftOrRotate = N1;
13546 } else if (IsAndWithShift(N1, N0)) {
13547 AndOrOp = N1;
13548 ShiftOrRotate = N0;
13549 } else if (IsRotateWithOp(N0, N1)) {
13550 IsRotate = true;
13551 AndOrOp = N0;
13552 ShiftOrRotate = N1;
13553 } else if (IsRotateWithOp(N1, N0)) {
13554 IsRotate = true;
13555 AndOrOp = N1;
13556 ShiftOrRotate = N0;
13557 }
13558
13559 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
13560 (IsRotate || AndOrOp.hasOneUse())) {
13561 EVT OpVT = N0.getValueType();
13562       // Get the constant shift/rotate amount and possibly the mask (if it's the
13563       // shift+and variant).
13564 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
13565 ConstantSDNode *CNode = isConstOrConstSplat(N: Op, /*AllowUndefs*/ false,
13566 /*AllowTrunc*/ AllowTruncation: false);
13567 if (CNode == nullptr)
13568 return std::nullopt;
13569 return CNode->getAPIntValue();
13570 };
13571 std::optional<APInt> AndCMask =
13572 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(i: 1));
13573 std::optional<APInt> ShiftCAmt =
13574 GetAPIntValue(ShiftOrRotate.getOperand(i: 1));
13575 unsigned NumBits = OpVT.getScalarSizeInBits();
13576
13577 // We found constants.
13578 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(RHS: NumBits)) {
13579 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
13580 // Check that the constants meet the constraints.
13581 bool CanTransform = IsRotate;
13582 if (!CanTransform) {
13583           // Check that the mask and shift complement each other
13584 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
13585 // Check that we are comparing all bits
13586 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
13587 // Check that the and mask is correct for the shift
13588 CanTransform &=
13589 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
13590 }
13591
13592 // See if target prefers another shift/rotate opcode.
13593 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
13594 VT: OpVT, ShiftOpc, MayTransformRotate: ShiftCAmt->isPowerOf2(), ShiftOrRotateAmt: *ShiftCAmt, AndMask: AndCMask);
13595 // Transform is valid and we have a new preference.
13596 if (CanTransform && NewShiftOpc != ShiftOpc) {
13597 SDValue NewShiftOrRotate =
13598 DAG.getNode(Opcode: NewShiftOpc, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
13599 N2: ShiftOrRotate.getOperand(i: 1));
13600 SDValue NewAndOrOp = SDValue();
13601
13602 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
13603 APInt NewMask =
13604 NewShiftOpc == ISD::SHL
13605 ? APInt::getHighBitsSet(numBits: NumBits,
13606 hiBitsSet: NumBits - ShiftCAmt->getZExtValue())
13607 : APInt::getLowBitsSet(numBits: NumBits,
13608 loBitsSet: NumBits - ShiftCAmt->getZExtValue());
13609 NewAndOrOp =
13610 DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
13611 N2: DAG.getConstant(Val: NewMask, DL, VT: OpVT));
13612 } else {
13613 NewAndOrOp = ShiftOrRotate.getOperand(i: 0);
13614 }
13615
13616 return DAG.getSetCC(DL, VT, LHS: NewAndOrOp, RHS: NewShiftOrRotate, Cond);
13617 }
13618 }
13619 }
13620 }
13621 return SDValue();
13622}
13623
13624SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
13625 SDValue LHS = N->getOperand(Num: 0);
13626 SDValue RHS = N->getOperand(Num: 1);
13627 SDValue Carry = N->getOperand(Num: 2);
13628 SDValue Cond = N->getOperand(Num: 3);
13629
13630 // If Carry is false, fold to a regular SETCC.
13631 if (isNullConstant(V: Carry))
13632 return DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N), VTList: N->getVTList(), N1: LHS, N2: RHS, N3: Cond);
13633
13634 return SDValue();
13635}
13636
13637 /// Check if N satisfies:
13638 /// N is used once.
13639 /// N is a Load.
13640 /// The load is compatible with ExtOpcode: if the load has an explicit
13641 /// zero/sign extension, ExtOpcode must have the same extension; otherwise
13642 /// any extension opcode is compatible.
13644static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
13645 if (!N.hasOneUse())
13646 return false;
13647
13648 if (!isa<LoadSDNode>(Val: N))
13649 return false;
13650
13651 LoadSDNode *Load = cast<LoadSDNode>(Val&: N);
13652 ISD::LoadExtType LoadExt = Load->getExtensionType();
13653 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
13654 return true;
13655
13656 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
13657 // extension.
13658 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
13659 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
13660 return false;
13661
13662 return true;
13663}
13664
13665/// Fold
13666/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
13667/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
13668/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
13669/// This function is called by the DAGCombiner when visiting sext/zext/aext
13670/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
13671static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
13672 SelectionDAG &DAG, const SDLoc &DL,
13673 CombineLevel Level) {
13674 unsigned Opcode = N->getOpcode();
13675 SDValue N0 = N->getOperand(Num: 0);
13676 EVT VT = N->getValueType(ResNo: 0);
13677 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
13678 Opcode == ISD::ANY_EXTEND) &&
13679 "Expected EXTEND dag node in input!");
13680
13681 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
13682 !N0.hasOneUse())
13683 return SDValue();
13684
13685 SDValue Op1 = N0->getOperand(Num: 1);
13686 SDValue Op2 = N0->getOperand(Num: 2);
13687 if (!isCompatibleLoad(N: Op1, ExtOpcode: Opcode) || !isCompatibleLoad(N: Op2, ExtOpcode: Opcode))
13688 return SDValue();
13689
13690 auto ExtLoadOpcode = ISD::EXTLOAD;
13691 if (Opcode == ISD::SIGN_EXTEND)
13692 ExtLoadOpcode = ISD::SEXTLOAD;
13693 else if (Opcode == ISD::ZERO_EXTEND)
13694 ExtLoadOpcode = ISD::ZEXTLOAD;
13695
  // An illegal VSELECT may fail instruction selection if it is created after
  // legalization (DAG Combine2), so conservatively check the OperationAction.
13698 LoadSDNode *Load1 = cast<LoadSDNode>(Val&: Op1);
13699 LoadSDNode *Load2 = cast<LoadSDNode>(Val&: Op2);
13700 if (!TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load1->getMemoryVT()) ||
13701 !TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load2->getMemoryVT()) ||
13702 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
13703 TLI.getOperationAction(Op: ISD::VSELECT, VT) != TargetLowering::Legal))
13704 return SDValue();
13705
13706 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Operand: Op1);
13707 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Operand: Op2);
13708 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0), LHS: Ext1, RHS: Ext2);
13709}
13710
13711/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
13712/// a build_vector of constants.
13713/// This function is called by the DAGCombiner when visiting sext/zext/aext
13714/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
13715/// Vector extends are not folded if operations are legal; this is to
13716/// avoid introducing illegal build_vector dag nodes.
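///
/// Illustrative example (constants chosen for exposition):
///   (v4i32 zext (v4i16 build_vector 1, 2, 3, undef))
/// folds to
///   (v4i32 build_vector 1, 2, 3, 0)
/// where the undef lane is conservatively replaced by 0 for zero extension.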
13717static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
13718 const TargetLowering &TLI,
13719 SelectionDAG &DAG, bool LegalTypes) {
13720 unsigned Opcode = N->getOpcode();
13721 SDValue N0 = N->getOperand(Num: 0);
13722 EVT VT = N->getValueType(ResNo: 0);
13723
13724 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
13725 "Expected EXTEND dag node in input!");
13726
13727 // fold (sext c1) -> c1
13728 // fold (zext c1) -> c1
13729 // fold (aext c1) -> c1
13730 if (isa<ConstantSDNode>(Val: N0))
13731 return DAG.getNode(Opcode, DL, VT, Operand: N0);
13732
13733 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
13734 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
13735 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
13736 if (N0->getOpcode() == ISD::SELECT) {
13737 SDValue Op1 = N0->getOperand(Num: 1);
13738 SDValue Op2 = N0->getOperand(Num: 2);
13739 if (isa<ConstantSDNode>(Val: Op1) && isa<ConstantSDNode>(Val: Op2) &&
13740 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
      // For any_extend, choose sign extension of the constants to allow a
      // possible further transform to sign_extend_inreg, i.e.
13743 //
13744 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
13745 // t2: i64 = any_extend t1
13746 // -->
13747 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
13748 // -->
13749 // t4: i64 = sign_extend_inreg t3
13750 unsigned FoldOpc = Opcode;
13751 if (FoldOpc == ISD::ANY_EXTEND)
13752 FoldOpc = ISD::SIGN_EXTEND;
13753 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0),
13754 LHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op1),
13755 RHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op2));
13756 }
13757 }
13758
  // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
13762 EVT SVT = VT.getScalarType();
13763 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(VT: SVT)) &&
13764 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())))
13765 return SDValue();
13766
13767 // We can fold this node into a build_vector.
13768 unsigned VTBits = SVT.getSizeInBits();
13769 unsigned EVTBits = N0->getValueType(ResNo: 0).getScalarSizeInBits();
13770 SmallVector<SDValue, 8> Elts;
13771 unsigned NumElts = VT.getVectorNumElements();
13772
13773 for (unsigned i = 0; i != NumElts; ++i) {
13774 SDValue Op = N0.getOperand(i);
13775 if (Op.isUndef()) {
13776 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
13777 Elts.push_back(Elt: DAG.getUNDEF(VT: SVT));
13778 else
13779 Elts.push_back(Elt: DAG.getConstant(Val: 0, DL, VT: SVT));
13780 continue;
13781 }
13782
13783 SDLoc DL(Op);
13784 // Get the constant value and if needed trunc it to the size of the type.
13785 // Nodes like build_vector might have constants wider than the scalar type.
13786 APInt C = Op->getAsAPIntVal().zextOrTrunc(width: EVTBits);
13787 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
13788 Elts.push_back(Elt: DAG.getConstant(Val: C.sext(width: VTBits), DL, VT: SVT));
13789 else
13790 Elts.push_back(Elt: DAG.getConstant(Val: C.zext(width: VTBits), DL, VT: SVT));
13791 }
13792
13793 return DAG.getBuildVector(VT, DL, Ops: Elts);
13794}
13795
// ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable the
// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
// transformation. Returns true if the extensions are possible and the
// above-mentioned transformation is profitable.
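//
// Illustrative example (hypothetical): if (i16 load x) is used both by
// (i32 sext ...) and by (setcc (i16 load x), 42), the setcc is recorded in
// ExtendNodes so its operands can later be rewritten to compare the extended
// value, leaving a single i32 sextload instead of both an i16 load and an
// i32 sextload.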
13800static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
13801 unsigned ExtOpc,
13802 SmallVectorImpl<SDNode *> &ExtendNodes,
13803 const TargetLowering &TLI) {
13804 bool HasCopyToRegUses = false;
13805 bool isTruncFree = TLI.isTruncateFree(FromVT: VT, ToVT: N0.getValueType());
13806 for (SDUse &Use : N0->uses()) {
13807 SDNode *User = Use.getUser();
13808 if (User == N)
13809 continue;
13810 if (Use.getResNo() != N0.getResNo())
13811 continue;
13812 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
13813 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
13814 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
13815 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(Code: CC))
13816 // Sign bits will be lost after a zext.
13817 return false;
13818 bool Add = false;
13819 for (unsigned i = 0; i != 2; ++i) {
13820 SDValue UseOp = User->getOperand(Num: i);
13821 if (UseOp == N0)
13822 continue;
13823 if (!isa<ConstantSDNode>(Val: UseOp))
13824 return false;
13825 Add = true;
13826 }
13827 if (Add)
13828 ExtendNodes.push_back(Elt: User);
13829 continue;
13830 }
13831 // If truncates aren't free and there are users we can't
13832 // extend, it isn't worthwhile.
13833 if (!isTruncFree)
13834 return false;
13835 // Remember if this value is live-out.
13836 if (User->getOpcode() == ISD::CopyToReg)
13837 HasCopyToRegUses = true;
13838 }
13839
13840 if (HasCopyToRegUses) {
13841 bool BothLiveOut = false;
13842 for (SDUse &Use : N->uses()) {
13843 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
13844 BothLiveOut = true;
13845 break;
13846 }
13847 }
13848 if (BothLiveOut)
13849 // Both unextended and extended values are live out. There had better be
13850 // a good reason for the transformation.
13851 return !ExtendNodes.empty();
13852 }
13853 return true;
13854}
13855
13856void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
13857 SDValue OrigLoad, SDValue ExtLoad,
13858 ISD::NodeType ExtType) {
13859 // Extend SetCC uses if necessary.
13860 SDLoc DL(ExtLoad);
13861 for (SDNode *SetCC : SetCCs) {
13862 SmallVector<SDValue, 4> Ops;
13863
13864 for (unsigned j = 0; j != 2; ++j) {
13865 SDValue SOp = SetCC->getOperand(Num: j);
13866 if (SOp == OrigLoad)
13867 Ops.push_back(Elt: ExtLoad);
13868 else
13869 Ops.push_back(Elt: DAG.getNode(Opcode: ExtType, DL, VT: ExtLoad->getValueType(ResNo: 0), Operand: SOp));
13870 }
13871
13872 Ops.push_back(Elt: SetCC->getOperand(Num: 2));
13873 CombineTo(N: SetCC, Res: DAG.getNode(Opcode: ISD::SETCC, DL, VT: SetCC->getValueType(ResNo: 0), Ops));
13874 }
13875}
13876
13877// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
13878SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
13879 SDValue N0 = N->getOperand(Num: 0);
13880 EVT DstVT = N->getValueType(ResNo: 0);
13881 EVT SrcVT = N0.getValueType();
13882
13883 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13884 N->getOpcode() == ISD::ZERO_EXTEND) &&
13885 "Unexpected node type (not an extend)!");
13886
13887 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
13888 // For example, on a target with legal v4i32, but illegal v8i32, turn:
13889 // (v8i32 (sext (v8i16 (load x))))
13890 // into:
13891 // (v8i32 (concat_vectors (v4i32 (sextload x)),
13892 // (v4i32 (sextload (x + 16)))))
13893 // Where uses of the original load, i.e.:
13894 // (v8i16 (load x))
13895 // are replaced with:
13896 // (v8i16 (truncate
13897 // (v8i32 (concat_vectors (v4i32 (sextload x)),
13898 // (v4i32 (sextload (x + 16)))))))
13899 //
13900 // This combine is only applicable to illegal, but splittable, vectors.
13901 // All legal types, and illegal non-vector types, are handled elsewhere.
13902 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
13903 //
13904 if (N0->getOpcode() != ISD::LOAD)
13905 return SDValue();
13906
13907 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13908
13909 if (!ISD::isNON_EXTLoad(N: LN0) || !ISD::isUNINDEXEDLoad(N: LN0) ||
13910 !N0.hasOneUse() || !LN0->isSimple() ||
13911 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
13912 !TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
13913 return SDValue();
13914
13915 SmallVector<SDNode *, 4> SetCCs;
13916 if (!ExtendUsesToFormExtLoad(VT: DstVT, N, N0, ExtOpc: N->getOpcode(), ExtendNodes&: SetCCs, TLI))
13917 return SDValue();
13918
13919 ISD::LoadExtType ExtType =
13920 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13921
13922 // Try to split the vector types to get down to legal types.
13923 EVT SplitSrcVT = SrcVT;
13924 EVT SplitDstVT = DstVT;
13925 while (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT) &&
13926 SplitSrcVT.getVectorNumElements() > 1) {
13927 SplitDstVT = DAG.GetSplitDestVTs(VT: SplitDstVT).first;
13928 SplitSrcVT = DAG.GetSplitDestVTs(VT: SplitSrcVT).first;
13929 }
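  // For example (illustrative types): with DstVT = v8i32 and SrcVT = v8i16 on
  // a target where only the v4i32 <- v4i16 sextload is legal, a single split
  // step leaves SplitDstVT = v4i32 and SplitSrcVT = v4i16.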
13930
13931 if (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT))
13932 return SDValue();
13933
13934 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
13935
13936 SDLoc DL(N);
13937 const unsigned NumSplits =
13938 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
13939 const unsigned Stride = SplitSrcVT.getStoreSize();
13940 SmallVector<SDValue, 4> Loads;
13941 SmallVector<SDValue, 4> Chains;
13942
13943 SDValue BasePtr = LN0->getBasePtr();
13944 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
13945 const unsigned Offset = Idx * Stride;
13946
13947 SDValue SplitLoad =
13948 DAG.getExtLoad(ExtType, dl: SDLoc(LN0), VT: SplitDstVT, Chain: LN0->getChain(),
13949 Ptr: BasePtr, PtrInfo: LN0->getPointerInfo().getWithOffset(O: Offset),
13950 MemVT: SplitSrcVT, Alignment: LN0->getBaseAlign(),
13951 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
13952
13953 BasePtr = DAG.getMemBasePlusOffset(Base: BasePtr, Offset: TypeSize::getFixed(ExactSize: Stride), DL);
13954
13955 Loads.push_back(Elt: SplitLoad.getValue(R: 0));
13956 Chains.push_back(Elt: SplitLoad.getValue(R: 1));
13957 }
13958
13959 SDValue NewChain = DAG.getNode(Opcode: ISD::TokenFactor, DL, VT: MVT::Other, Ops: Chains);
13960 SDValue NewValue = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: DstVT, Ops: Loads);
13961
13962 // Simplify TF.
13963 AddToWorklist(N: NewChain.getNode());
13964
13965 CombineTo(N, Res: NewValue);
13966
13967 // Replace uses of the original load (before extension)
13968 // with a truncate of the concatenated sextloaded vectors.
13969 SDValue Trunc =
13970 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: NewValue);
13971 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad: NewValue, ExtType: (ISD::NodeType)N->getOpcode());
13972 CombineTo(N: N0.getNode(), Res0: Trunc, Res1: NewChain);
13973 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13974}
13975
13976// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13977// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
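//
// Illustrative example (widths chosen for exposition): zero-extending
//   (i16 and (i16 srl (i16 load x), 4), 255)
// to i32 becomes
//   (i32 and (i32 srl (i32 zextload i16 x), 4), 255)
// provided the i16 -> i32 zextload is legal for the target.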
13978SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
13979 assert(N->getOpcode() == ISD::ZERO_EXTEND);
13980 EVT VT = N->getValueType(ResNo: 0);
13981 EVT OrigVT = N->getOperand(Num: 0).getValueType();
13982 if (TLI.isZExtFree(FromTy: OrigVT, ToTy: VT))
13983 return SDValue();
13984
13985 // and/or/xor
13986 SDValue N0 = N->getOperand(Num: 0);
13987 if (!ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) ||
13988 N0.getOperand(i: 1).getOpcode() != ISD::Constant ||
13989 (LegalOperations && !TLI.isOperationLegal(Op: N0.getOpcode(), VT)))
13990 return SDValue();
13991
13992 // shl/shr
13993 SDValue N1 = N0->getOperand(Num: 0);
13994 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
13995 N1.getOperand(i: 1).getOpcode() != ISD::Constant ||
13996 (LegalOperations && !TLI.isOperationLegal(Op: N1.getOpcode(), VT)))
13997 return SDValue();
13998
13999 // load
14000 if (!isa<LoadSDNode>(Val: N1.getOperand(i: 0)))
14001 return SDValue();
14002 LoadSDNode *Load = cast<LoadSDNode>(Val: N1.getOperand(i: 0));
14003 EVT MemVT = Load->getMemoryVT();
14004 if (!TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) ||
14005 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
14006 return SDValue();
14007
14009 // If the shift op is SHL, the logic op must be AND, otherwise the result
14010 // will be wrong.
14011 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
14012 return SDValue();
14013
14014 if (!N0.hasOneUse() || !N1.hasOneUse())
14015 return SDValue();
14016
14017 SmallVector<SDNode*, 4> SetCCs;
14018 if (!ExtendUsesToFormExtLoad(VT, N: N1.getNode(), N0: N1.getOperand(i: 0),
14019 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI))
14020 return SDValue();
14021
14022 // Actually do the transformation.
14023 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Load), VT,
14024 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
14025 MemVT: Load->getMemoryVT(), MMO: Load->getMemOperand());
14026
14027 SDLoc DL1(N1);
14028 SDValue Shift = DAG.getNode(Opcode: N1.getOpcode(), DL: DL1, VT, N1: ExtLoad,
14029 N2: N1.getOperand(i: 1));
14030
14031 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
14032 SDLoc DL0(N0);
14033 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL: DL0, VT, N1: Shift,
14034 N2: DAG.getConstant(Val: Mask, DL: DL0, VT));
14035
14036 ExtendSetCCUses(SetCCs, OrigLoad: N1.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
14037 CombineTo(N, Res: And);
14038 if (SDValue(Load, 0).hasOneUse()) {
14039 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
14040 } else {
14041 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Load),
14042 VT: Load->getValueType(ResNo: 0), Operand: ExtLoad);
14043 CombineTo(N: Load, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14044 }
14045
14046 // N0 is dead at this point.
14047 recursivelyDeleteUnusedNodes(N: N0.getNode());
14048
14049 return SDValue(N,0); // Return N so it doesn't get rechecked!
14050}
14051
14052/// If we're narrowing or widening the result of a vector select and the final
14053/// size is the same size as a setcc (compare) feeding the select, then try to
14054/// apply the cast operation to the select's operands because matching vector
14055/// sizes for a select condition and other operands should be more efficient.
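///
/// Illustrative example (types are hypothetical):
///   (v4i32 trunc (vselect (setcc v4i32 X, Y), A:v4i64, B:v4i64))
/// becomes
///   (vselect (setcc v4i32 X, Y), (v4i32 trunc A), (v4i32 trunc B))
/// so the select operands match the bit width of the compare result.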
14056SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
14057 unsigned CastOpcode = Cast->getOpcode();
14058 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
14059 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
14060 CastOpcode == ISD::FP_ROUND) &&
14061 "Unexpected opcode for vector select narrowing/widening");
14062
14063 // We only do this transform before legal ops because the pattern may be
14064 // obfuscated by target-specific operations after legalization. Do not create
14065 // an illegal select op, however, because that may be difficult to lower.
14066 EVT VT = Cast->getValueType(ResNo: 0);
14067 if (LegalOperations || !TLI.isOperationLegalOrCustom(Op: ISD::VSELECT, VT))
14068 return SDValue();
14069
14070 SDValue VSel = Cast->getOperand(Num: 0);
14071 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
14072 VSel.getOperand(i: 0).getOpcode() != ISD::SETCC)
14073 return SDValue();
14074
14075 // Does the setcc have the same vector size as the casted select?
14076 SDValue SetCC = VSel.getOperand(i: 0);
14077 EVT SetCCVT = getSetCCResultType(VT: SetCC.getOperand(i: 0).getValueType());
14078 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
14079 return SDValue();
14080
14081 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
14082 SDValue A = VSel.getOperand(i: 1);
14083 SDValue B = VSel.getOperand(i: 2);
14084 SDValue CastA, CastB;
14085 SDLoc DL(Cast);
14086 if (CastOpcode == ISD::FP_ROUND) {
14087 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
14088 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: A, N2: Cast->getOperand(Num: 1));
14089 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: B, N2: Cast->getOperand(Num: 1));
14090 } else {
14091 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: A);
14092 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: B);
14093 }
14094 return DAG.getNode(Opcode: ISD::VSELECT, DL, VT, N1: SetCC, N2: CastA, N3: CastB);
14095}
14096
14097// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14098// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14099static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
14100 const TargetLowering &TLI, EVT VT,
14101 bool LegalOperations, SDNode *N,
14102 SDValue N0, ISD::LoadExtType ExtLoadType) {
14103 SDNode *N0Node = N0.getNode();
14104 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N: N0Node)
14105 : ISD::isZEXTLoad(N: N0Node);
14106 if ((!isAExtLoad && !ISD::isEXTLoad(N: N0Node)) ||
14107 !ISD::isUNINDEXEDLoad(N: N0Node) || !N0.hasOneUse())
14108 return SDValue();
14109
14110 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14111 EVT MemVT = LN0->getMemoryVT();
14112 if ((LegalOperations || !LN0->isSimple() ||
14113 VT.isVector()) &&
14114 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT))
14115 return SDValue();
14116
14117 SDValue ExtLoad =
14118 DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
14119 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
14120 Combiner.CombineTo(N, Res: ExtLoad);
14121 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14122 if (LN0->use_empty())
14123 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
14124 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14125}
14126
14127// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14128// Only generate vector extloads when 1) they're legal, and 2) they are
14129// deemed desirable by the target. NonNegZExt can be set to true if a zero
14130// extend has the nonneg flag to allow use of sextload if profitable.
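//
// Illustrative example (hypothetical): for (zext nneg (i8 load x)) where
// another user of the load is a signed compare (setcc ... setlt), the code
// below switches to an i8 -> iN sextload so the compare users can also be
// extended without losing sign information.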
14131static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
14132 const TargetLowering &TLI, EVT VT,
14133 bool LegalOperations, SDNode *N, SDValue N0,
14134 ISD::LoadExtType ExtLoadType,
14135 ISD::NodeType ExtOpc,
14136 bool NonNegZExt = false) {
14137 if (!ISD::isNON_EXTLoad(N: N0.getNode()) || !ISD::isUNINDEXEDLoad(N: N0.getNode()))
14138 return {};
14139
14140 // If this is zext nneg, see if it would make sense to treat it as a sext.
14141 if (NonNegZExt) {
14142 assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
14143 "Unexpected load type or opcode");
14144 for (SDNode *User : N0->users()) {
14145 if (User->getOpcode() == ISD::SETCC) {
14146 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
14147 if (ISD::isSignedIntSetCC(Code: CC)) {
14148 ExtLoadType = ISD::SEXTLOAD;
14149 ExtOpc = ISD::SIGN_EXTEND;
14150 break;
14151 }
14152 }
14153 }
14154 }
14155
  // TODO: The isFixedLengthVector() check should be removed; any negative
  // effects on code generation should then be addressed via the target's
  // implementation of isVectorLoadExtDesirable().
14159 if ((LegalOperations || VT.isFixedLengthVector() ||
14160 !cast<LoadSDNode>(Val&: N0)->isSimple()) &&
14161 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: N0.getValueType()))
14162 return {};
14163
14164 bool DoXform = true;
14165 SmallVector<SDNode *, 4> SetCCs;
14166 if (!N0.hasOneUse())
14167 DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, ExtendNodes&: SetCCs, TLI);
14168 if (VT.isVector())
14169 DoXform &= TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0));
14170 if (!DoXform)
14171 return {};
14172
14173 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14174 SDValue ExtLoad = DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
14175 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
14176 MMO: LN0->getMemOperand());
14177 Combiner.ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ExtOpc);
14178 // If the load value is used only by N, replace it via CombineTo N.
14179 bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
14180 Combiner.CombineTo(N, Res: ExtLoad);
14181 if (NoReplaceTrunc) {
14182 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14183 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
14184 } else {
14185 SDValue Trunc =
14186 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
14187 Combiner.CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14188 }
14189 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14190}
14191
14192static SDValue
14193tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
14194 bool LegalOperations, SDNode *N, SDValue N0,
14195 ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
14196 if (!N0.hasOneUse())
14197 return SDValue();
14198
14199 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0);
14200 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
14201 return SDValue();
14202
14203 if ((LegalOperations || !cast<MaskedLoadSDNode>(Val&: N0)->isSimple()) &&
14204 !TLI.isLoadExtLegalOrCustom(ExtType: ExtLoadType, ValVT: VT, MemVT: Ld->getValueType(ResNo: 0)))
14205 return SDValue();
14206
14207 if (!TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
14208 return SDValue();
14209
14210 SDLoc dl(Ld);
14211 SDValue PassThru = DAG.getNode(Opcode: ExtOpc, DL: dl, VT, Operand: Ld->getPassThru());
14212 SDValue NewLoad = DAG.getMaskedLoad(
14213 VT, dl, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(), Mask: Ld->getMask(),
14214 Src0: PassThru, MemVT: Ld->getMemoryVT(), MMO: Ld->getMemOperand(), AM: Ld->getAddressingMode(),
14215 ExtLoadType, IsExpanding: Ld->isExpandingLoad());
14216 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1), To: SDValue(NewLoad.getNode(), 1));
14217 return NewLoad;
14218}
14219
14220// fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
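//
// Illustrative example (hypothetical types): (i64 zext (i32 atomic_load p))
// becomes an extending i64 atomic_load of the i32 memory value when the
// target reports that atomic load extension as legal; remaining users of the
// old i32 value are fed by a truncate of the new result.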
14221static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
14222 const TargetLowering &TLI, EVT VT,
14223 SDValue N0,
14224 ISD::LoadExtType ExtLoadType) {
14225 auto *ALoad = dyn_cast<AtomicSDNode>(Val&: N0);
14226 if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
14227 return {};
14228 EVT MemoryVT = ALoad->getMemoryVT();
14229 if (!TLI.isAtomicLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: MemoryVT))
14230 return {};
14231 // Can't fold into ALoad if it is already extending differently.
14232 ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
14233 if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
14234 (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
14235 return {};
14236
14237 EVT OrigVT = ALoad->getValueType(ResNo: 0);
14238 assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
14239 auto *NewALoad = cast<AtomicSDNode>(Val: DAG.getAtomicLoad(
14240 ExtType: ExtLoadType, dl: SDLoc(ALoad), MemVT: MemoryVT, VT, Chain: ALoad->getChain(),
14241 Ptr: ALoad->getBasePtr(), MMO: ALoad->getMemOperand()));
14242 DAG.ReplaceAllUsesOfValueWith(
14243 From: SDValue(ALoad, 0),
14244 To: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ALoad), VT: OrigVT, Operand: SDValue(NewALoad, 0)));
14245 // Update the chain uses.
14246 DAG.ReplaceAllUsesOfValueWith(From: SDValue(ALoad, 1), To: SDValue(NewALoad, 1));
14247 return SDValue(NewALoad, 0);
14248}
14249
14250static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
14251 bool LegalOperations) {
14252 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
14253 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
14254
14255 SDValue SetCC = N->getOperand(Num: 0);
14256 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
14257 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
14258 return SDValue();
14259
14260 SDValue X = SetCC.getOperand(i: 0);
14261 SDValue Ones = SetCC.getOperand(i: 1);
14262 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC.getOperand(i: 2))->get();
14263 EVT VT = N->getValueType(ResNo: 0);
14264 EVT XVT = X.getValueType();
14265 // setge X, C is canonicalized to setgt, so we do not need to match that
14266 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
14267 // not require the 'not' op.
14268 if (CC == ISD::SETGT && isAllOnesConstant(V: Ones) && VT == XVT) {
14269 // Invert and smear/shift the sign bit:
14270 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
14271 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
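    // For example (illustrative, N = 32):
    //   sext i1 (setgt i32 X, -1) to i32 --> sra (xor X, -1), 31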
14272 SDLoc DL(N);
14273 unsigned ShCt = VT.getSizeInBits() - 1;
14274 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14275 if (!TLI.shouldAvoidTransformToShift(VT, Amount: ShCt)) {
14276 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
14277 SDValue ShiftAmount = DAG.getConstant(Val: ShCt, DL, VT);
14278 auto ShiftOpcode =
14279 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
14280 return DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: NotX, N2: ShiftAmount);
14281 }
14282 }
14283 return SDValue();
14284}
14285
14286SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
14287 SDValue N0 = N->getOperand(Num: 0);
14288 if (N0.getOpcode() != ISD::SETCC)
14289 return SDValue();
14290
14291 SDValue N00 = N0.getOperand(i: 0);
14292 SDValue N01 = N0.getOperand(i: 1);
14293 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
14294 EVT VT = N->getValueType(ResNo: 0);
14295 EVT N00VT = N00.getValueType();
14296 SDLoc DL(N);
14297
14298 // Propagate fast-math-flags.
14299 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14300
14301 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
14302 // the same size as the compared operands. Try to optimize sext(setcc())
14303 // if this is the case.
14304 if (VT.isVector() && !LegalOperations &&
14305 TLI.getBooleanContents(Type: N00VT) ==
14306 TargetLowering::ZeroOrNegativeOneBooleanContent) {
14307 EVT SVT = getSetCCResultType(VT: N00VT);
14308
14309 // If we already have the desired type, don't change it.
14310 if (SVT != N0.getValueType()) {
14311 // We know that the # elements of the results is the same as the
14312 // # elements of the compare (and the # elements of the compare result
14313 // for that matter). Check to see that they are the same size. If so,
14314 // we know that the element size of the sext'd result matches the
14315 // element size of the compare operands.
14316 if (VT.getSizeInBits() == SVT.getSizeInBits())
14317 return DAG.getSetCC(DL, VT, LHS: N00, RHS: N01, Cond: CC);
14318
14319 // If the desired elements are smaller or larger than the source
14320 // elements, we can use a matching integer vector type and then
14321 // truncate/sign extend.
14322 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
14323 if (SVT == MatchingVecType) {
14324 SDValue VsetCC = DAG.getSetCC(DL, VT: MatchingVecType, LHS: N00, RHS: N01, Cond: CC);
14325 return DAG.getSExtOrTrunc(Op: VsetCC, DL, VT);
14326 }
14327 }
14328
14329 // Try to eliminate the sext of a setcc by zexting the compare operands.
14330 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT) &&
14331 !TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: SVT)) {
14332 bool IsSignedCmp = ISD::isSignedIntSetCC(Code: CC);
14333 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
14334 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
14335
14336 // We have an unsupported narrow vector compare op that would be legal
14337 // if extended to the destination type. See if the compare operands
14338 // can be freely extended to the destination type.
14339 auto IsFreeToExtend = [&](SDValue V) {
14340 if (isConstantOrConstantVector(N: V, /*NoOpaques*/ true))
14341 return true;
14342 // Match a simple, non-extended load that can be converted to a
14343 // legal {z/s}ext-load.
14344 // TODO: Allow widening of an existing {z/s}ext-load?
14345 if (!(ISD::isNON_EXTLoad(N: V.getNode()) &&
14346 ISD::isUNINDEXEDLoad(N: V.getNode()) &&
14347 cast<LoadSDNode>(Val&: V)->isSimple() &&
14348 TLI.isLoadExtLegal(ExtType: LoadOpcode, ValVT: VT, MemVT: V.getValueType())))
14349 return false;
14350
14351 // Non-chain users of this value must either be the setcc in this
14352 // sequence or extends that can be folded into the new {z/s}ext-load.
14353 for (SDUse &Use : V->uses()) {
14354 // Skip uses of the chain and the setcc.
14355 SDNode *User = Use.getUser();
14356 if (Use.getResNo() != 0 || User == N0.getNode())
14357 continue;
14358 // Extra users must have exactly the same cast we are about to create.
14359 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
14360 // is enhanced similarly.
14361 if (User->getOpcode() != ExtOpcode || User->getValueType(ResNo: 0) != VT)
14362 return false;
14363 }
14364 return true;
14365 };
14366
14367 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
14368 SDValue Ext0 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N00);
14369 SDValue Ext1 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N01);
14370 return DAG.getSetCC(DL, VT, LHS: Ext0, RHS: Ext1, Cond: CC);
14371 }
14372 }
14373 }
14374
14375 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
14376 // Here, T can be 1 or -1, depending on the type of the setcc and
14377 // getBooleanContents().
14378 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
14379
14380 // To determine the "true" side of the select, we need to know the high bit
14381 // of the value returned by the setcc if it evaluates to true.
14382 // If the type of the setcc is i1, then the true case of the select is just
14383 // sext(i1 1), that is, -1.
14384 // If the type of the setcc is larger (say, i8) then the value of the high
14385 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
14386 // of the appropriate width.
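  // For example (illustrative): with ZeroOrNegativeOneBooleanContent for the
  // setcc operand type, ExtTrueVal is the all-ones value of VT; with
  // ZeroOrOneBooleanContent it would be 1.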
14387 SDValue ExtTrueVal = (SetCCWidth == 1)
14388 ? DAG.getAllOnesConstant(DL, VT)
14389 : DAG.getBoolConstant(V: true, DL, VT, OpVT: N00VT);
14390 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
14391 if (SDValue SCC = SimplifySelectCC(DL, N0: N00, N1: N01, N2: ExtTrueVal, N3: Zero, CC, NotExtCompare: true))
14392 return SCC;
14393
14394 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(Cond: N0, VT, TLI)) {
14395 EVT SetCCVT = getSetCCResultType(VT: N00VT);
14396 // Don't do this transform for i1 because there's a select transform
14397 // that would reverse it.
14398 // TODO: We should not do this transform at all without a target hook
14399 // because a sext is likely cheaper than a select?
14400 if (SetCCVT.getScalarSizeInBits() != 1 &&
14401 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: N00VT))) {
14402 SDValue SetCC = DAG.getSetCC(DL, VT: SetCCVT, LHS: N00, RHS: N01, Cond: CC);
14403 return DAG.getSelect(DL, VT, Cond: SetCC, LHS: ExtTrueVal, RHS: Zero);
14404 }
14405 }
14406
14407 return SDValue();
14408}
14409
14410SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
14411 SDValue N0 = N->getOperand(Num: 0);
14412 EVT VT = N->getValueType(ResNo: 0);
14413 SDLoc DL(N);
14414
14415 if (VT.isVector())
14416 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14417 return FoldedVOp;
14418
  // sext(undef) = 0 because the top bits will all be the same.
14420 if (N0.isUndef())
14421 return DAG.getConstant(Val: 0, DL, VT);
14422
14423 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14424 return Res;
14425
14426 // fold (sext (sext x)) -> (sext x)
14427 // fold (sext (aext x)) -> (sext x)
14428 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
14429 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
14430
14431 // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14432 // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14433 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14434 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
14435 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
14436 Operand: N0.getOperand(i: 0));
14437
14438 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
14439 SDValue N00 = N0.getOperand(i: 0);
14440 EVT ExtVT = cast<VTSDNode>(Val: N0->getOperand(Num: 1))->getVT();
14441 if (N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(Val: N00, VT2: ExtVT)) {
14442 // fold (sext (sext_inreg x)) -> (sext (trunc x))
14443 if ((!LegalTypes || TLI.isTypeLegal(VT: ExtVT))) {
14444 SDValue T = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N00);
14445 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: T);
14446 }
14447
14448 // If the trunc wasn't legal, try to fold to (sext_inreg (anyext x))
14449 if (!LegalTypes || TLI.isTypeLegal(VT)) {
14450 SDValue ExtSrc = DAG.getAnyExtOrTrunc(Op: N00, DL, VT);
14451 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: ExtSrc,
14452 N2: N0->getOperand(Num: 1));
14453 }
14454 }
14455 }
14456
14457 if (N0.getOpcode() == ISD::TRUNCATE) {
14458 // fold (sext (truncate (load x))) -> (sext (smaller load x))
14459 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
14460 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
14461 SDNode *oye = N0.getOperand(i: 0).getNode();
14462 if (NarrowLoad.getNode() != N0.getNode()) {
14463 CombineTo(N: N0.getNode(), Res: NarrowLoad);
14464 // CombineTo deleted the truncate, if needed, but not what's under it.
14465 AddToWorklist(N: oye);
14466 }
14467 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14468 }
14469
14470 // See if the value being truncated is already sign extended. If so, just
14471 // eliminate the trunc/sext pair.
14472 SDValue Op = N0.getOperand(i: 0);
14473 unsigned OpBits = Op.getScalarValueSizeInBits();
14474 unsigned MidBits = N0.getScalarValueSizeInBits();
14475 unsigned DestBits = VT.getScalarSizeInBits();
14476
14477 if (N0->getFlags().hasNoSignedWrap() ||
14478 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
14479 if (OpBits == DestBits) {
        // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
        // bits, it already has the desired value; use it directly.
14482 return Op;
14483 }
14484
14485 if (OpBits < DestBits) {
14486 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
14487 // bits, just sext from i32.
14488 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
14489 }
14490
14491 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
14492 // bits, just truncate to i32.
14493 SDNodeFlags Flags;
14494 Flags.setNoSignedWrap(true);
14495 Flags.setNoUnsignedWrap(N0->getFlags().hasNoUnsignedWrap());
14496 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op, Flags);
14497 }
14498
14499 // fold (sext (truncate x)) -> (sextinreg x).
14500 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG,
14501 VT: N0.getValueType())) {
14502 if (OpBits < DestBits)
14503 Op = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N0), VT, Operand: Op);
14504 else if (OpBits > DestBits)
14505 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT, Operand: Op);
14506 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: Op,
14507 N2: DAG.getValueType(N0.getValueType()));
14508 }
14509 }
14510
14511 // Try to simplify (sext (load x)).
14512 if (SDValue foldedExt =
14513 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
14514 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
14515 return foldedExt;
14516
14517 if (SDValue foldedExt =
14518 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
14519 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
14520 return foldedExt;
14521
14522 // fold (sext (load x)) to multiple smaller sextloads.
14523 // Only on illegal but splittable vectors.
14524 if (SDValue ExtLoad = CombineExtLoad(N))
14525 return ExtLoad;
14526
14527 // Try to simplify (sext (sextload x)).
14528 if (SDValue foldedExt = tryToFoldExtOfExtload(
14529 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::SEXTLOAD))
14530 return foldedExt;
14531
14532 // Try to simplify (sext (atomic_load x)).
14533 if (SDValue foldedExt =
14534 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::SEXTLOAD))
14535 return foldedExt;
14536
14537 // fold (sext (and/or/xor (load x), cst)) ->
14538 // (and/or/xor (sextload x), (sext cst))
14539 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) &&
14540 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
14541 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
14542 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
14543 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
14544 EVT MemVT = LN00->getMemoryVT();
14545 if (TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT) &&
14546 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
14547 SmallVector<SDNode*, 4> SetCCs;
14548 bool DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
14549 ExtOpc: ISD::SIGN_EXTEND, ExtendNodes&: SetCCs, TLI);
14550 if (DoXform) {
14551 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(LN00), VT,
14552 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
14553 MemVT: LN00->getMemoryVT(),
14554 MMO: LN00->getMemOperand());
14555 APInt Mask = N0.getConstantOperandAPInt(i: 1).sext(width: VT.getSizeInBits());
14556 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
14557 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
14558 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::SIGN_EXTEND);
14559 bool NoReplaceTruncAnd = !N0.hasOneUse();
14560 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14561 CombineTo(N, Res: And);
14562 // If N0 has multiple uses, change other uses as well.
14563 if (NoReplaceTruncAnd) {
14564 SDValue TruncAnd =
14565 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
14566 CombineTo(N: N0.getNode(), Res: TruncAnd);
14567 }
14568 if (NoReplaceTrunc) {
14569 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
14570 } else {
14571 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
14572 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
14573 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14574 }
14575 return SDValue(N,0); // Return N so it doesn't get rechecked!
14576 }
14577 }
14578 }
14579
14580 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14581 return V;
14582
14583 if (SDValue V = foldSextSetcc(N))
14584 return V;
14585
14586 // fold (sext x) -> (zext x) if the sign bit is known zero.
14587 if (!TLI.isSExtCheaperThanZExt(FromTy: N0.getValueType(), ToTy: VT) &&
14588 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT)) &&
14589 DAG.SignBitIsZero(Op: N0))
14590 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0, Flags: SDNodeFlags::NonNeg);
14591
14592 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
14593 return NewVSel;
14594
14595 // Eliminate this sign extend by doing a negation in the destination type:
14596 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
14597 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
14598 isNullOrNullSplat(V: N0.getOperand(i: 0)) &&
14599 N0.getOperand(i: 1).getOpcode() == ISD::ZERO_EXTEND &&
14600 TLI.isOperationLegalOrCustom(Op: ISD::SUB, VT)) {
14601 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1).getOperand(i: 0), DL, VT);
14602 return DAG.getNegative(Val: Zext, DL, VT);
14603 }
14604 // Eliminate this sign extend by doing a decrement in the destination type:
14605 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
14606 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
14607 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1)) &&
14608 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
14609 TLI.isOperationLegalOrCustom(Op: ISD::ADD, VT)) {
14610 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
14611 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
14612 }
14613
14614 // fold sext (not i1 X) -> add (zext i1 X), -1
14615 // TODO: This could be extended to handle bool vectors.
14616 if (N0.getValueType() == MVT::i1 && isBitwiseNot(V: N0) && N0.hasOneUse() &&
14617 (!LegalOperations || (TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT) &&
14618 TLI.isOperationLegal(Op: ISD::ADD, VT)))) {
14619 // If we can eliminate the 'not', the sext form should be better
14620 if (SDValue NewXor = visitXOR(N: N0.getNode())) {
14621 // Returning N0 is a form of in-visit replacement that may have
14622 // invalidated N0.
14623 if (NewXor.getNode() == N0.getNode()) {
14624 // Return SDValue here as the xor should have already been replaced in
14625 // this sext.
14626 return SDValue();
14627 }
14628
14629 // Return a new sext with the new xor.
14630 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: NewXor);
14631 }
14632
14633 SDValue Zext = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
14634 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
14635 }
14636
14637 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14638 return Res;
14639
14640 return SDValue();
14641}
14642
14643/// Given an extending node with a pop-count operand, if the target does not
14644/// support a pop-count in the narrow source type but does support it in the
14645/// destination type, widen the pop-count to the destination type.
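///
/// Illustrative example (hypothetical target): if CTPOP is not legal for i16
/// but is legal or custom for i64, then
///   (i64 zext (i16 ctpop X)) --> (i64 ctpop (i64 zext X))
/// The count is unchanged because zero extension only adds zero bits.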
14646static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
14647 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
14648 Extend->getOpcode() == ISD::ANY_EXTEND) &&
14649 "Expected extend op");
14650
14651 SDValue CtPop = Extend->getOperand(Num: 0);
14652 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
14653 return SDValue();
14654
14655 EVT VT = Extend->getValueType(ResNo: 0);
14656 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14657 if (TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT: CtPop.getValueType()) ||
14658 !TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT))
14659 return SDValue();
14660
14661 // zext (ctpop X) --> ctpop (zext X)
14662 SDValue NewZext = DAG.getZExtOrTrunc(Op: CtPop.getOperand(i: 0), DL, VT);
14663 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: NewZext);
14664}
14665
14666// If we have (zext (abs X)) where X is a type that will be promoted by type
14667// legalization, convert to (abs (sext X)). But don't extend past a legal type.
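//
// Illustrative example (hypothetical types, with i8 promoted to i32):
//   (i32 zext (i8 abs X)) --> (i32 abs (i32 sext X))
// Sign-extending first preserves the value, so the absolute value computed in
// the promoted type matches the zero-extended narrow result (including the
// wrapping abs of -128).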
14668static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
14669 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
14670
14671 EVT VT = Extend->getValueType(ResNo: 0);
14672 if (VT.isVector())
14673 return SDValue();
14674
14675 SDValue Abs = Extend->getOperand(Num: 0);
14676 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
14677 return SDValue();
14678
14679 EVT AbsVT = Abs.getValueType();
14680 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14681 if (TLI.getTypeAction(Context&: *DAG.getContext(), VT: AbsVT) !=
14682 TargetLowering::TypePromoteInteger)
14683 return SDValue();
14684
14685 EVT LegalVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: AbsVT);
14686
14687 SDValue SExt =
14688 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(Abs), VT: LegalVT, Operand: Abs.getOperand(i: 0));
14689 SDValue NewAbs = DAG.getNode(Opcode: ISD::ABS, DL: SDLoc(Abs), VT: LegalVT, Operand: SExt);
14690 return DAG.getZExtOrTrunc(Op: NewAbs, DL: SDLoc(Extend), VT);
14691}
14692
14693SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
14694 SDValue N0 = N->getOperand(Num: 0);
14695 EVT VT = N->getValueType(ResNo: 0);
14696 SDLoc DL(N);
14697
14698 if (VT.isVector())
14699 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14700 return FoldedVOp;
14701
14702 // zext(undef) = 0
14703 if (N0.isUndef())
14704 return DAG.getConstant(Val: 0, DL, VT);
14705
14706 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14707 return Res;
14708
14709 // fold (zext (zext x)) -> (zext x)
14710 // fold (zext (aext x)) -> (zext x)
14711 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14712 SDNodeFlags Flags;
14713 if (N0.getOpcode() == ISD::ZERO_EXTEND)
14714 Flags.setNonNeg(N0->getFlags().hasNonNeg());
14715 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0), Flags);
14716 }
14717
14718 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14719 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14720 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14721 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
14722 return DAG.getNode(Opcode: ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, Operand: N0.getOperand(i: 0));
14723
14724 // fold (zext (truncate x)) -> (zext x) or
14725 // (zext (truncate x)) -> (truncate x)
14726 // This is valid when the truncated bits of x are already zero.
14727 SDValue Op;
14728 KnownBits Known;
14729 if (isTruncateOf(DAG, N: N0, Op, Known)) {
14730 APInt TruncatedBits =
14731 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
14732 APInt(Op.getScalarValueSizeInBits(), 0) :
14733 APInt::getBitsSet(numBits: Op.getScalarValueSizeInBits(),
14734 loBit: N0.getScalarValueSizeInBits(),
14735 hiBit: std::min(a: Op.getScalarValueSizeInBits(),
14736 b: VT.getScalarSizeInBits()));
14737 if (TruncatedBits.isSubsetOf(RHS: Known.Zero)) {
14738 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
14739 DAG.salvageDebugInfo(N&: *N0.getNode());
14740
14741 return ZExtOrTrunc;
14742 }
14743 }
14744
14745 // fold (zext (truncate x)) -> (and x, mask)
14746 if (N0.getOpcode() == ISD::TRUNCATE) {
14747 // fold (zext (truncate (load x))) -> (zext (smaller load x))
14748 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
14749 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
14750 SDNode *oye = N0.getOperand(i: 0).getNode();
14751 if (NarrowLoad.getNode() != N0.getNode()) {
14752 CombineTo(N: N0.getNode(), Res: NarrowLoad);
14753 // CombineTo deleted the truncate, if needed, but not what's under it.
14754 AddToWorklist(N: oye);
14755 }
14756 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14757 }
14758
14759 EVT SrcVT = N0.getOperand(i: 0).getValueType();
14760 EVT MinVT = N0.getValueType();
14761
14762 if (N->getFlags().hasNonNeg()) {
14763 SDValue Op = N0.getOperand(i: 0);
14764 unsigned OpBits = SrcVT.getScalarSizeInBits();
14765 unsigned MidBits = MinVT.getScalarSizeInBits();
14766 unsigned DestBits = VT.getScalarSizeInBits();
14767
14768 if (N0->getFlags().hasNoSignedWrap() ||
14769 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
14770 if (OpBits == DestBits) {
          // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
          // bits, it already has the desired value; use it directly.
14773 return Op;
14774 }
14775
14776 if (OpBits < DestBits) {
14777 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
14778 // bits, just sext from i32.
14779 // FIXME: This can probably be ZERO_EXTEND nneg?
14780 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
14781 }
14782
14783 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
14784 // bits, just truncate to i32.
14785 SDNodeFlags Flags;
14786 Flags.setNoSignedWrap(true);
14787 Flags.setNoUnsignedWrap(true);
14788 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op, Flags);
14789 }
14790 }
14791
14792 // Try to mask before the extension to avoid having to generate a larger mask,
14793 // possibly over several sub-vectors.
14794 if (SrcVT.bitsLT(VT) && VT.isVector()) {
14795 if (!LegalOperations || (TLI.isOperationLegal(Op: ISD::AND, VT: SrcVT) &&
14796 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) {
14797 SDValue Op = N0.getOperand(i: 0);
14798 Op = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
14799 AddToWorklist(N: Op.getNode());
14800 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
14801 // Transfer the debug info; the new node is equivalent to N0.
14802 DAG.transferDbgValues(From: N0, To: ZExtOrTrunc);
14803 return ZExtOrTrunc;
14804 }
14805 }
14806
14807 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::AND, VT)) {
14808 SDValue Op = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
14809 AddToWorklist(N: Op.getNode());
14810 SDValue And = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
14811 // We may safely transfer the debug info describing the truncate node over
14812 // to the equivalent and operation.
14813 DAG.transferDbgValues(From: N0, To: And);
14814 return And;
14815 }
14816 }
14817
14818 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
14819 // if either of the casts is not free.
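  //
  // Illustrative example (hypothetical types):
  //   (i64 zext (i32 and (i32 trunc X:i64), 255)) --> (i64 and X, 255)
  // The mask already clears every bit that the zext would have cleared.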
14820 if (N0.getOpcode() == ISD::AND &&
14821 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
14822 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
14823 (!TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType()) ||
14824 !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
14825 SDValue X = N0.getOperand(i: 0).getOperand(i: 0);
14826 X = DAG.getAnyExtOrTrunc(Op: X, DL: SDLoc(X), VT);
14827 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
14828 return DAG.getNode(Opcode: ISD::AND, DL, VT,
14829 N1: X, N2: DAG.getConstant(Val: Mask, DL, VT));
14830 }
14831
14832 // Try to simplify (zext (load x)).
14833 if (SDValue foldedExt = tryToFoldExtOfLoad(
14834 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD,
14835 ExtOpc: ISD::ZERO_EXTEND, NonNegZExt: N->getFlags().hasNonNeg()))
14836 return foldedExt;
14837
14838 if (SDValue foldedExt =
14839 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
14840 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
14841 return foldedExt;
14842
14843 // fold (zext (load x)) to multiple smaller zextloads.
14844 // Only on illegal but splittable vectors.
14845 if (SDValue ExtLoad = CombineExtLoad(N))
14846 return ExtLoad;
14847
14848 // Try to simplify (zext (atomic_load x)).
14849 if (SDValue foldedExt =
14850 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::ZEXTLOAD))
14851 return foldedExt;
14852
14853 // fold (zext (and/or/xor (load x), cst)) ->
14854 // (and/or/xor (zextload x), (zext cst))
14855 // Unless (and (load x) cst) will match as a zextload already and has
14856 // additional users, or the zext is already free.
14857 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && !TLI.isZExtFree(Val: N0, VT2: VT) &&
14858 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
14859 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
14860 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
14861 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
14862 EVT MemVT = LN00->getMemoryVT();
14863 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) &&
14864 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
14865 bool DoXform = true;
14866 SmallVector<SDNode*, 4> SetCCs;
14867 if (!N0.hasOneUse()) {
14868 if (N0.getOpcode() == ISD::AND) {
14869 auto *AndC = cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
14870 EVT LoadResultTy = AndC->getValueType(ResNo: 0);
14871 EVT ExtVT;
14872 if (isAndLoadExtLoad(AndC, LoadN: LN00, LoadResultTy, ExtVT))
14873 DoXform = false;
14874 }
14875 }
14876 if (DoXform)
14877 DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
14878 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI);
14879 if (DoXform) {
14880 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(LN00), VT,
14881 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
14882 MemVT: LN00->getMemoryVT(),
14883 MMO: LN00->getMemOperand());
14884 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
14885 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
14886 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
14887 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
14888 bool NoReplaceTruncAnd = !N0.hasOneUse();
14889 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14890 CombineTo(N, Res: And);
14891 // If N0 has multiple uses, change other uses as well.
14892 if (NoReplaceTruncAnd) {
14893 SDValue TruncAnd =
14894 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
14895 CombineTo(N: N0.getNode(), Res: TruncAnd);
14896 }
14897 if (NoReplaceTrunc) {
14898 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
14899 } else {
14900 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
14901 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
14902 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14903 }
14904 return SDValue(N,0); // Return N so it doesn't get rechecked!
14905 }
14906 }
14907 }
14908
14909 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
14910 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
14911 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
14912 return ZExtLoad;
14913
14914 // Try to simplify (zext (zextload x)).
14915 if (SDValue foldedExt = tryToFoldExtOfExtload(
14916 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD))
14917 return foldedExt;
14918
14919 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14920 return V;
14921
14922 if (N0.getOpcode() == ISD::SETCC) {
14923 // Propagate fast-math-flags.
14924 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14925
14926 // Only do this before legalize for now.
14927 if (!LegalOperations && VT.isVector() &&
14928 N0.getValueType().getVectorElementType() == MVT::i1) {
14929 EVT N00VT = N0.getOperand(i: 0).getValueType();
14930 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
14931 return SDValue();
14932
14933 // We know that the # elements of the results is the same as the #
14934 // elements of the compare (and the # elements of the compare result for
14935 // that matter). Check to see that they are the same size. If so, we know
14936 // that the element size of the sext'd result matches the element size of
14937 // the compare operands.
14938 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
14939 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
14940 SDValue VSetCC = DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: N0.getOperand(i: 0),
14941 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
14942 return DAG.getZeroExtendInReg(Op: VSetCC, DL, VT: N0.getValueType());
14943 }
14944
14945 // If the desired elements are smaller or larger than the source
14946 // elements we can use a matching integer vector type and then
14947 // truncate/any extend followed by zext_in_reg.
14948 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14949 SDValue VsetCC =
14950 DAG.getNode(Opcode: ISD::SETCC, DL, VT: MatchingVectorType, N1: N0.getOperand(i: 0),
14951 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
14952 return DAG.getZeroExtendInReg(Op: DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT), DL,
14953 VT: N0.getValueType());
14954 }
14955
14956 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
14957 EVT N0VT = N0.getValueType();
14958 EVT N00VT = N0.getOperand(i: 0).getValueType();
14959 if (SDValue SCC = SimplifySelectCC(
14960 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1),
14961 N2: DAG.getBoolConstant(V: true, DL, VT: N0VT, OpVT: N00VT),
14962 N3: DAG.getBoolConstant(V: false, DL, VT: N0VT, OpVT: N00VT),
14963 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
14964 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SCC);
14965 }
14966
14967  // (zext (shl/srl (zext x), cst)) -> (shl/srl (zext x), cst)
14968 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
14969 !TLI.isZExtFree(Val: N0, VT2: VT)) {
14970 SDValue ShVal = N0.getOperand(i: 0);
14971 SDValue ShAmt = N0.getOperand(i: 1);
14972 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val&: ShAmt)) {
14973 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
14974 if (N0.getOpcode() == ISD::SHL) {
14975 // If the original shl may be shifting out bits, do not perform this
14976 // transformation.
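            // For example, in (zext i64 (shl (zext x:i8 to i32), 25)) the i32
            // shl drops a bit of x that an i64 shl would keep, so widening the
            // shift is unsafe unless that bit is known zero.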
14977 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
14978 ShVal.getOperand(i: 0).getValueSizeInBits();
14979 if (ShAmtC->getAPIntValue().ugt(RHS: KnownZeroBits)) {
14980 // If the shift is too large, then see if we can deduce that the
14981 // shift is safe anyway.
14982
14983 // Check if the bits being shifted out are known to be zero.
14984 KnownBits KnownShVal = DAG.computeKnownBits(Op: ShVal);
14985 if (ShAmtC->getAPIntValue().ugt(RHS: KnownShVal.countMinLeadingZeros()))
14986 return SDValue();
14987 }
14988 }
14989
14990 // Ensure that the shift amount is wide enough for the shifted value.
14991 if (Log2_32_Ceil(Value: VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
14992 ShAmt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i32, Operand: ShAmt);
14993
14994 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
14995 N1: DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ShVal), N2: ShAmt);
14996 }
14997 }
14998 }
14999
15000 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
15001 return NewVSel;
15002
15003 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG, DL))
15004 return NewCtPop;
15005
15006 if (SDValue V = widenAbs(Extend: N, DAG))
15007 return V;
15008
15009 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15010 return Res;
15011
15012 // CSE zext nneg with sext if the zext is not free.
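  // A zero-extend carrying the nneg flag yields the same value as a
  // sign-extend of the same operand, so if an equivalent sign-extend node
  // already exists, reuse it instead of keeping both around.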
15013 if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT)) {
15014 SDNode *CSENode = DAG.getNodeIfExists(Opcode: ISD::SIGN_EXTEND, VTList: N->getVTList(), Ops: N0);
15015 if (CSENode)
15016 return SDValue(CSENode, 0);
15017 }
15018
15019 return SDValue();
15020}
15021
15022SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
15023 SDValue N0 = N->getOperand(Num: 0);
15024 EVT VT = N->getValueType(ResNo: 0);
15025 SDLoc DL(N);
15026
15027 // aext(undef) = undef
15028 if (N0.isUndef())
15029 return DAG.getUNDEF(VT);
15030
15031 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15032 return Res;
15033
15034 // fold (aext (aext x)) -> (aext x)
15035 // fold (aext (zext x)) -> (zext x)
15036 // fold (aext (sext x)) -> (sext x)
15037 if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
15038 N0.getOpcode() == ISD::SIGN_EXTEND) {
15039 SDNodeFlags Flags;
15040 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15041 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15042 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0), Flags);
15043 }
15044
15045 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
15046 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15047 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
15048 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
15049 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
15050 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
15051 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
15052
15053 // fold (aext (truncate (load x))) -> (aext (smaller load x))
15054 // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
15055 if (N0.getOpcode() == ISD::TRUNCATE) {
15056 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
15057 SDNode *oye = N0.getOperand(i: 0).getNode();
15058 if (NarrowLoad.getNode() != N0.getNode()) {
15059 CombineTo(N: N0.getNode(), Res: NarrowLoad);
15060 // CombineTo deleted the truncate, if needed, but not what's under it.
15061 AddToWorklist(N: oye);
15062 }
15063 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15064 }
15065 }
15066
15067 // fold (aext (truncate x))
15068 if (N0.getOpcode() == ISD::TRUNCATE)
15069 return DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
15070
15071 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
15072 // if the trunc is not free.
15073 if (N0.getOpcode() == ISD::AND &&
15074 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
15075 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
15076 !TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType())) {
15077 SDValue X = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
15078 SDValue Y = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: N0.getOperand(i: 1));
15079 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
15080 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: Y);
15081 }
15082
15083 // fold (aext (load x)) -> (aext (truncate (extload x)))
15084 // None of the supported targets knows how to perform load and any_ext
15085 // on vectors in one instruction, so attempt to fold to zext instead.
15086 if (VT.isVector()) {
15087 // Try to simplify (zext (load x)).
15088 if (SDValue foldedExt =
15089 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
15090 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
15091 return foldedExt;
15092 } else if (ISD::isNON_EXTLoad(N: N0.getNode()) &&
15093 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
15094 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
15095 bool DoXform = true;
15096 SmallVector<SDNode *, 4> SetCCs;
15097 if (!N0.hasOneUse())
15098 DoXform =
15099 ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc: ISD::ANY_EXTEND, ExtendNodes&: SetCCs, TLI);
15100 if (DoXform) {
15101 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15102 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
15103 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
15104 MMO: LN0->getMemOperand());
15105 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ISD::ANY_EXTEND);
15106 // If the load value is used only by N, replace it via CombineTo N.
15107 bool NoReplaceTrunc = N0.hasOneUse();
15108 CombineTo(N, Res: ExtLoad);
15109 if (NoReplaceTrunc) {
15110 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
15111 recursivelyDeleteUnusedNodes(N: LN0);
15112 } else {
15113 SDValue Trunc =
15114 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
15115 CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
15116 }
15117 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15118 }
15119 }
15120
15121 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
15122 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
15123 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
15124 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N: N0.getNode()) &&
15125 ISD::isUNINDEXEDLoad(N: N0.getNode()) && N0.hasOneUse()) {
15126 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15127 ISD::LoadExtType ExtType = LN0->getExtensionType();
15128 EVT MemVT = LN0->getMemoryVT();
15129 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, ValVT: VT, MemVT)) {
15130 SDValue ExtLoad =
15131 DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
15132 MemVT, MMO: LN0->getMemOperand());
15133 CombineTo(N, Res: ExtLoad);
15134 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
15135 recursivelyDeleteUnusedNodes(N: LN0);
15136 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15137 }
15138 }
15139
15140 if (N0.getOpcode() == ISD::SETCC) {
15141 // Propagate fast-math-flags.
15142 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
15143
15144 // For vectors:
15145 // aext(setcc) -> vsetcc
15146 // aext(setcc) -> truncate(vsetcc)
15147 // aext(setcc) -> aext(vsetcc)
15148 // Only do this before legalize for now.
15149 if (VT.isVector() && !LegalOperations) {
15150 EVT N00VT = N0.getOperand(i: 0).getValueType();
15151 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
15152 return SDValue();
15153
15154      // We know that the # elements of the result is the same as the
15155      // # elements of the compare (and the # elements of the compare result
15156      // for that matter). Check to see that they are the same size. If so,
15157      // we know that the element size of the extended result matches the
15158      // element size of the compare operands.
15159 if (VT.getSizeInBits() == N00VT.getSizeInBits())
15160 return DAG.getSetCC(DL, VT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15161 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
15162
15163 // If the desired elements are smaller or larger than the source
15164 // elements we can use a matching integer vector type and then
15165 // truncate/any extend
15166 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15167 SDValue VsetCC = DAG.getSetCC(
15168 DL, VT: MatchingVectorType, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15169 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
15170 return DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT);
15171 }
15172
15173 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
15174 if (SDValue SCC = SimplifySelectCC(
15175 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: DAG.getConstant(Val: 1, DL, VT),
15176 N3: DAG.getConstant(Val: 0, DL, VT),
15177 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
15178 return SCC;
15179 }
15180
15181 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG, DL))
15182 return NewCtPop;
15183
15184 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15185 return Res;
15186
15187 return SDValue();
15188}
15189
15190SDValue DAGCombiner::visitAssertExt(SDNode *N) {
15191 unsigned Opcode = N->getOpcode();
15192 SDValue N0 = N->getOperand(Num: 0);
15193 SDValue N1 = N->getOperand(Num: 1);
15194 EVT AssertVT = cast<VTSDNode>(Val&: N1)->getVT();
15195
15196 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
15197 if (N0.getOpcode() == Opcode &&
15198 AssertVT == cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT())
15199 return N0;
15200
15201 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15202 N0.getOperand(i: 0).getOpcode() == Opcode) {
15203    // We have an assert, truncate, assert sandwich. Make one stronger assert
15204    // by asserting the smaller asserted type on the larger source type.
15205 // This eliminates the later assert:
15206 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
15207 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
15208 SDLoc DL(N);
15209 SDValue BigA = N0.getOperand(i: 0);
15210 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
15211 EVT MinAssertVT = AssertVT.bitsLT(VT: BigA_AssertVT) ? AssertVT : BigA_AssertVT;
15212 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
15213 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
15214 N1: BigA.getOperand(i: 0), N2: MinAssertVTVal);
15215 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
15216 }
15217
15218  // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
15219  // than X, just move the AssertZext in front of the truncate and drop the
15220  // AssertSext.
15221 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15222 N0.getOperand(i: 0).getOpcode() == ISD::AssertSext &&
15223 Opcode == ISD::AssertZext) {
15224 SDValue BigA = N0.getOperand(i: 0);
15225 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
15226 if (AssertVT.bitsLT(VT: BigA_AssertVT)) {
15227 SDLoc DL(N);
15228 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
15229 N1: BigA.getOperand(i: 0), N2: N1);
15230 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
15231 }
15232 }
15233
15234 // If we have (AssertZext (and (AssertSext X, iX), M), iY) and Y is smaller
15235 // than X, and the And doesn't change the lower iX bits, we can move the
15236 // AssertZext in front of the And and drop the AssertSext.
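  // For example, (AssertZext (and (AssertSext X, i32), 0xffffffff), i16):
  // the mask leaves the low 32 bits untouched and the zero-extension
  // assertion from i16 subsumes the sign-extension assertion from i32, so
  // this can become (and (AssertZext X, i16), 0xffffffff).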
15237 if (Opcode == ISD::AssertZext && N0.getOpcode() == ISD::AND &&
15238 N0.hasOneUse() && N0.getOperand(i: 0).getOpcode() == ISD::AssertSext &&
15239 isa<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
15240 SDValue BigA = N0.getOperand(i: 0);
15241 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
15242 const APInt &Mask = N0.getConstantOperandAPInt(i: 1);
15243 if (AssertVT.bitsLT(VT: BigA_AssertVT) &&
15244 Mask.countr_one() >= BigA_AssertVT.getScalarSizeInBits()) {
15245 SDLoc DL(N);
15246 SDValue NewAssert =
15247 DAG.getNode(Opcode, DL, VT: N->getValueType(ResNo: 0), N1: BigA.getOperand(i: 0), N2: N1);
15248 return DAG.getNode(Opcode: ISD::AND, DL, VT: N->getValueType(ResNo: 0), N1: NewAssert,
15249 N2: N0.getOperand(i: 1));
15250 }
15251 }
15252
15253 return SDValue();
15254}
15255
15256SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
15257 SDLoc DL(N);
15258
15259 Align AL = cast<AssertAlignSDNode>(Val: N)->getAlign();
15260 SDValue N0 = N->getOperand(Num: 0);
15261
15262 // Fold (assertalign (assertalign x, AL0), AL1) ->
15263 // (assertalign x, max(AL0, AL1))
15264 if (auto *AAN = dyn_cast<AssertAlignSDNode>(Val&: N0))
15265 return DAG.getAssertAlign(DL, V: N0.getOperand(i: 0),
15266 A: std::max(a: AL, b: AAN->getAlign()));
15267
15268  // In rare cases, there are trivial arithmetic ops in source operands. Sink
15269  // this assert down to the source operands so that those arithmetic ops can
15270  // be exposed to DAG combining.
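  // For example, (assertalign (add x, 32), 16) can become
  // (add (assertalign x, 16), 32), since the constant 32 is already known to
  // be 16-byte aligned.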
15271 switch (N0.getOpcode()) {
15272 default:
15273 break;
15274 case ISD::ADD:
15275 case ISD::PTRADD:
15276 case ISD::SUB: {
15277 unsigned AlignShift = Log2(A: AL);
15278 SDValue LHS = N0.getOperand(i: 0);
15279 SDValue RHS = N0.getOperand(i: 1);
15280 unsigned LHSAlignShift = DAG.computeKnownBits(Op: LHS).countMinTrailingZeros();
15281 unsigned RHSAlignShift = DAG.computeKnownBits(Op: RHS).countMinTrailingZeros();
15282 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
15283 if (LHSAlignShift < AlignShift)
15284 LHS = DAG.getAssertAlign(DL, V: LHS, A: AL);
15285 if (RHSAlignShift < AlignShift)
15286 RHS = DAG.getAssertAlign(DL, V: RHS, A: AL);
15287 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT: N0.getValueType(), N1: LHS, N2: RHS);
15288 }
15289 break;
15290 }
15291 }
15292
15293 return SDValue();
15294}
15295
15296/// If the result of a load is shifted/masked/truncated to an effectively
15297/// narrower type, try to transform the load to a narrower type and/or
15298/// use an extending load.
15299SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
15300 unsigned Opc = N->getOpcode();
15301
15302 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
15303 SDValue N0 = N->getOperand(Num: 0);
15304 EVT VT = N->getValueType(ResNo: 0);
15305 EVT ExtVT = VT;
15306
15307 // This transformation isn't valid for vector loads.
15308 if (VT.isVector())
15309 return SDValue();
15310
15311  // The ShAmt variable is used to indicate that we've consumed a right
15312  // shift, i.e. we want to narrow the width of the load by skipping over the
15313  // ShAmt least significant bits.
15314 unsigned ShAmt = 0;
15315  // A special case is when the least significant bits from the load are masked
15316  // away, but using an AND rather than a right shift. ShiftedOffset is used to
15317  // indicate that the narrowed load should be left-shifted by that many bits to
15318  // get the result.
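  // For example, (and (load i32 p), 0xff00) can be narrowed to a zextload of
  // i8 from p+1 on a little-endian target, followed by a shl by 8 to put the
  // loaded byte back in its original position.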
15319 unsigned ShiftedOffset = 0;
15320 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
15321  // sign-extending to VT.
15322 if (Opc == ISD::SIGN_EXTEND_INREG) {
15323 ExtType = ISD::SEXTLOAD;
15324 ExtVT = cast<VTSDNode>(Val: N->getOperand(Num: 1))->getVT();
15325 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
15326 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
15327 // value, or it may be shifting a higher subword, half or byte into the
15328 // lowest bits.
15329
15330 // Only handle shift with constant shift amount, and the shiftee must be a
15331 // load.
15332 auto *LN = dyn_cast<LoadSDNode>(Val&: N0);
15333 auto *N1C = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
15334 if (!N1C || !LN)
15335 return SDValue();
15336 // If the shift amount is larger than the memory type then we're not
15337 // accessing any of the loaded bytes.
15338 ShAmt = N1C->getZExtValue();
15339 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
15340 if (MemoryWidth <= ShAmt)
15341 return SDValue();
15342 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
15343 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
15344 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
15345    // If the original load is a SEXTLOAD then we can't simply replace it by a
15346    // ZEXTLOAD (we could potentially replace it by a narrower SEXTLOAD
15347    // followed by a ZEXT, but that is not handled at the moment). Similarly if
15348 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
15349 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
15350 LN->getExtensionType() == ISD::ZEXTLOAD) &&
15351 LN->getExtensionType() != ExtType)
15352 return SDValue();
15353 } else if (Opc == ISD::AND) {
15354 // An AND with a constant mask is the same as a truncate + zero-extend.
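    // For example, (and (load i32 p), 255) only needs the low 8 bits of the
    // load, so it can become (zextload i8 p) when that is legal.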
15355 auto AndC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
15356 if (!AndC)
15357 return SDValue();
15358
15359 const APInt &Mask = AndC->getAPIntValue();
15360 unsigned ActiveBits = 0;
15361 if (Mask.isMask()) {
15362 ActiveBits = Mask.countr_one();
15363 } else if (Mask.isShiftedMask(MaskIdx&: ShAmt, MaskLen&: ActiveBits)) {
15364 ShiftedOffset = ShAmt;
15365 } else {
15366 return SDValue();
15367 }
15368
15369 ExtType = ISD::ZEXTLOAD;
15370 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
15371 }
15372
15373 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
15374 // a right shift. Here we redo some of those checks, to possibly adjust the
15375 // ExtVT even further based on "a masking AND". We could also end up here for
15376 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
15377 // need to be done here as well.
15378 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
15379 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
15380    // Bail out when the SRL has more than one use. This is done for historical
15381    // (undocumented) reasons. Maybe the intent was to guard the AND-masking
15382    // check below? And maybe it could be non-profitable to do the transform
15383    // when the SRL has multiple uses and we get here with Opc != ISD::SRL?
15384    // FIXME: Can't we just skip this check for the Opc == ISD::SRL case?
15385 if (!SRL.hasOneUse())
15386 return SDValue();
15387
15388 // Only handle shift with constant shift amount, and the shiftee must be a
15389 // load.
15390 auto *LN = dyn_cast<LoadSDNode>(Val: SRL.getOperand(i: 0));
15391 auto *SRL1C = dyn_cast<ConstantSDNode>(Val: SRL.getOperand(i: 1));
15392 if (!SRL1C || !LN)
15393 return SDValue();
15394
15395 // If the shift amount is larger than the input type then we're not
15396 // accessing any of the loaded bytes. If the load was a zextload/extload
15397 // then the result of the shift+trunc is zero/undef (handled elsewhere).
15398 ShAmt = SRL1C->getZExtValue();
15399 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
15400 if (ShAmt >= MemoryWidth)
15401 return SDValue();
15402
15403 // Because a SRL must be assumed to *need* to zero-extend the high bits
15404 // (as opposed to anyext the high bits), we can't combine the zextload
15405 // lowering of SRL and an sextload.
15406 if (LN->getExtensionType() == ISD::SEXTLOAD)
15407 return SDValue();
15408
15409    // Avoid reading outside the memory accessed by the original load (could
15410    // happen if we only adjust the load base pointer by ShAmt). Instead we
15411    // try to narrow the load even further. The typical scenario here is:
15412 // (i64 (truncate (i96 (srl (load x), 64)))) ->
15413 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
15414 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
15415 // Don't replace sextload by zextload.
15416 if (ExtType == ISD::SEXTLOAD)
15417 return SDValue();
15418 // Narrow the load.
15419 ExtType = ISD::ZEXTLOAD;
15420 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
15421 }
15422
15423 // If the SRL is only used by a masking AND, we may be able to adjust
15424 // the ExtVT to make the AND redundant.
15425 SDNode *Mask = *(SRL->user_begin());
15426 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
15427 isa<ConstantSDNode>(Val: Mask->getOperand(Num: 1))) {
15428 unsigned Offset, ActiveBits;
15429 const APInt& ShiftMask = Mask->getConstantOperandAPInt(Num: 1);
15430 if (ShiftMask.isMask()) {
15431 EVT MaskedVT =
15432 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ShiftMask.countr_one());
15433 // If the mask is smaller, recompute the type.
15434 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
15435 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT))
15436 ExtVT = MaskedVT;
15437 } else if (ExtType == ISD::ZEXTLOAD &&
15438 ShiftMask.isShiftedMask(MaskIdx&: Offset, MaskLen&: ActiveBits) &&
15439 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
15440 EVT MaskedVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
15441 // If the mask is shifted we can use a narrower load and a shl to insert
15442 // the trailing zeros.
15443 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
15444 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT)) {
15445 ExtVT = MaskedVT;
15446 ShAmt = Offset + ShAmt;
15447 ShiftedOffset = Offset;
15448 }
15449 }
15450 }
15451
15452 N0 = SRL.getOperand(i: 0);
15453 }
15454
15455 // If the load is shifted left (and the result isn't shifted back right), we
15456 // can fold a truncate through the shift. The typical scenario is that N
15457 // points at a TRUNCATE here so the attempted fold is:
15458 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
15459 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
15460 unsigned ShLeftAmt = 0;
15461 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
15462 ExtVT == VT && TLI.isNarrowingProfitable(N, SrcVT: N0.getValueType(), DestVT: VT)) {
15463 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
15464 ShLeftAmt = N01->getZExtValue();
15465 N0 = N0.getOperand(i: 0);
15466 }
15467 }
15468
15469 // If we haven't found a load, we can't narrow it.
15470 if (!isa<LoadSDNode>(Val: N0))
15471 return SDValue();
15472
15473 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15474 // Reducing the width of a volatile load is illegal. For atomics, we may be
15475 // able to reduce the width provided we never widen again. (see D66309)
15476 if (!LN0->isSimple() ||
15477 !isLegalNarrowLdSt(LDST: LN0, ExtType, MemVT&: ExtVT, ShAmt))
15478 return SDValue();
15479
15480 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
15481 unsigned LVTStoreBits =
15482 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
15483 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
15484 return LVTStoreBits - EVTStoreBits - ShAmt;
15485 };
15486
15487 // We need to adjust the pointer to the load by ShAmt bits in order to load
15488 // the correct bytes.
15489 unsigned PtrAdjustmentInBits =
15490 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
15491
15492 uint64_t PtrOff = PtrAdjustmentInBits / 8;
15493 SDLoc DL(LN0);
15494 // The original load itself didn't wrap, so an offset within it doesn't.
15495 SDValue NewPtr =
15496 DAG.getMemBasePlusOffset(Base: LN0->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff),
15497 DL, Flags: SDNodeFlags::NoUnsignedWrap);
15498 AddToWorklist(N: NewPtr.getNode());
15499
15500 SDValue Load;
15501 if (ExtType == ISD::NON_EXTLOAD) {
15502 const MDNode *OldRanges = LN0->getRanges();
15503 const MDNode *NewRanges = nullptr;
15504    // If LSBs are loaded and the truncated ConstantRange for the OldRanges
15505    // metadata is not the full set for the new width, then create NewRanges
15506    // metadata for the truncated load.
15507 if (ShAmt == 0 && OldRanges) {
15508 ConstantRange CR = getConstantRangeFromMetadata(RangeMD: *OldRanges);
15509 unsigned BitSize = VT.getScalarSizeInBits();
15510
15511 // It is possible for an 8-bit extending load with 8-bit range
15512 // metadata to be narrowed to an 8-bit load. This guard is necessary to
15513 // ensure that truncation is strictly smaller.
15514 if (CR.getBitWidth() > BitSize) {
15515 ConstantRange TruncatedCR = CR.truncate(BitWidth: BitSize);
15516 if (!TruncatedCR.isFullSet()) {
15517 Metadata *Bounds[2] = {
15518 ConstantAsMetadata::get(
15519 C: ConstantInt::get(Context&: *DAG.getContext(), V: TruncatedCR.getLower())),
15520 ConstantAsMetadata::get(
15521 C: ConstantInt::get(Context&: *DAG.getContext(), V: TruncatedCR.getUpper()))};
15522 NewRanges = MDNode::get(Context&: *DAG.getContext(), MDs: Bounds);
15523 }
15524 } else if (CR.getBitWidth() == BitSize)
15525 NewRanges = OldRanges;
15526 }
15527 Load = DAG.getLoad(VT, dl: DL, Chain: LN0->getChain(), Ptr: NewPtr,
15528 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff),
15529 Alignment: LN0->getBaseAlign(), MMOFlags: LN0->getMemOperand()->getFlags(),
15530 AAInfo: LN0->getAAInfo(), Ranges: NewRanges);
15531 } else
15532 Load = DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: NewPtr,
15533 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff), MemVT: ExtVT,
15534 Alignment: LN0->getBaseAlign(), MMOFlags: LN0->getMemOperand()->getFlags(),
15535 AAInfo: LN0->getAAInfo());
15536
15537 // Replace the old load's chain with the new load's chain.
15538 WorklistRemover DeadNodes(*this);
15539 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
15540
15541 // Shift the result left, if we've swallowed a left shift.
15542 SDValue Result = Load;
15543 if (ShLeftAmt != 0) {
15544 // If the shift amount is as large as the result size (but, presumably,
15545 // no larger than the source) then the useful bits of the result are
15546 // zero; we can't simply return the shortened shift, because the result
15547 // of that operation is undefined.
15548 if (ShLeftAmt >= VT.getScalarSizeInBits())
15549 Result = DAG.getConstant(Val: 0, DL, VT);
15550 else
15551 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result,
15552 N2: DAG.getShiftAmountConstant(Val: ShLeftAmt, VT, DL));
15553 }
15554
15555 if (ShiftedOffset != 0) {
15556    // We're using a shifted mask, so the load now has an offset. This means
15557    // that the data has been loaded into lower bytes than it would have been
15558    // before, so we need to shl the loaded data into the correct position in
15559    // the register.
15560 SDValue ShiftC = DAG.getConstant(Val: ShiftedOffset, DL, VT);
15561 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result, N2: ShiftC);
15562 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
15563 }
15564
15565 // Return the new loaded value.
15566 return Result;
15567}
15568
15569SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
15570 SDValue N0 = N->getOperand(Num: 0);
15571 SDValue N1 = N->getOperand(Num: 1);
15572 EVT VT = N->getValueType(ResNo: 0);
15573 EVT ExtVT = cast<VTSDNode>(Val&: N1)->getVT();
15574 unsigned VTBits = VT.getScalarSizeInBits();
15575 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
15576 SDLoc DL(N);
15577
15578  // sext_in_reg(undef) = 0 because the top bits will all be the same.
15579 if (N0.isUndef())
15580 return DAG.getConstant(Val: 0, DL, VT);
15581
15582 // fold (sext_in_reg c1) -> c1
15583 if (SDValue C =
15584 DAG.FoldConstantArithmetic(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, Ops: {N0, N1}))
15585 return C;
15586
15587 // If the input is already sign extended, just drop the extension.
15588 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(Op: N0))
15589 return N0;
15590
15591 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
15592 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
15593 ExtVT.bitsLT(VT: cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT()))
15594 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
15595
15596 // fold (sext_in_reg (sext x)) -> (sext x)
15597 // fold (sext_in_reg (aext x)) -> (sext x)
15598 // if x is small enough or if we know that x has more than 1 sign bit and the
15599 // sign_extend_inreg is extending from one of them.
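  // For example, (sext_in_reg (anyext i8 x to i32), i16) can be simplified to
  // (sext i8 x to i32), since x already fits in the i16 being extended from.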
15600 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
15601 SDValue N00 = N0.getOperand(i: 0);
15602 unsigned N00Bits = N00.getScalarValueSizeInBits();
15603 if ((N00Bits <= ExtVTBits ||
15604 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits) &&
15605 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
15606 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N00);
15607 }
15608
15609 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
15610 // if x is small enough or if we know that x has more than 1 sign bit and the
15611 // sign_extend_inreg is extending from one of them.
15612 if (ISD::isExtVecInRegOpcode(Opcode: N0.getOpcode())) {
15613 SDValue N00 = N0.getOperand(i: 0);
15614 unsigned N00Bits = N00.getScalarValueSizeInBits();
15615 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
15616 if ((N00Bits == ExtVTBits ||
15617 (!IsZext && (N00Bits < ExtVTBits ||
15618 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits))) &&
15619 (!LegalOperations ||
15620 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
15621 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL, VT, Operand: N00);
15622 }
15623
15624 // fold (sext_in_reg (zext x)) -> (sext x)
15625 // iff we are extending the source sign bit.
15626 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
15627 SDValue N00 = N0.getOperand(i: 0);
15628 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
15629 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
15630 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N00);
15631 }
15632
15633 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
15634 if (DAG.MaskedValueIsZero(Op: N0, Mask: APInt::getOneBitSet(numBits: VTBits, BitNo: ExtVTBits - 1)))
15635 return DAG.getZeroExtendInReg(Op: N0, DL, VT: ExtVT);
15636
15637 // fold operands of sext_in_reg based on knowledge that the top bits are not
15638 // demanded.
15639 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
15640 return SDValue(N, 0);
15641
15642 // fold (sext_in_reg (load x)) -> (smaller sextload x)
15643 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
15644 if (SDValue NarrowLoad = reduceLoadWidth(N))
15645 return NarrowLoad;
15646
15647 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
15648 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
15649 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
15650 if (N0.getOpcode() == ISD::SRL) {
15651 if (auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1)))
15652 if (ShAmt->getAPIntValue().ule(RHS: VTBits - ExtVTBits)) {
15653 // We can turn this into an SRA iff the input to the SRL is already sign
15654 // extended enough.
15655 unsigned InSignBits = DAG.ComputeNumSignBits(Op: N0.getOperand(i: 0));
15656 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
15657 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0),
15658 N2: N0.getOperand(i: 1));
15659 }
15660 }
15661
15662 // fold (sext_inreg (extload x)) -> (sextload x)
15663  // If sextload is not supported by the target, we can only do the combine
15664  // when the load has one use. Doing otherwise can block folding the extload
15665  // with other extends that the target does support.
15666 if (ISD::isEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
15667 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
15668 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple() &&
15669 N0.hasOneUse()) ||
15670 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
15671 auto *LN0 = cast<LoadSDNode>(Val&: N0);
15672 SDValue ExtLoad =
15673 DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
15674 Ptr: LN0->getBasePtr(), MemVT: ExtVT, MMO: LN0->getMemOperand());
15675 CombineTo(N, Res: ExtLoad);
15676 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
15677 AddToWorklist(N: ExtLoad.getNode());
15678 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15679 }
15680
15681 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
15682 if (ISD::isZEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
15683 N0.hasOneUse() && ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
15684 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) &&
15685 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
15686 auto *LN0 = cast<LoadSDNode>(Val&: N0);
15687 SDValue ExtLoad =
15688 DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
15689 Ptr: LN0->getBasePtr(), MemVT: ExtVT, MMO: LN0->getMemOperand());
15690 CombineTo(N, Res: ExtLoad);
15691 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
15692 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15693 }
15694
15695 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
15696  // Ignore it if the masked load is already sign extended.
15697 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0)) {
15698 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
15699 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
15700 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT)) {
15701 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
15702 VT, dl: DL, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(),
15703 Mask: Ld->getMask(), Src0: Ld->getPassThru(), MemVT: ExtVT, MMO: Ld->getMemOperand(),
15704 AM: Ld->getAddressingMode(), ISD::SEXTLOAD, IsExpanding: Ld->isExpandingLoad());
15705 CombineTo(N, Res: ExtMaskedLoad);
15706 CombineTo(N: N0.getNode(), Res0: ExtMaskedLoad, Res1: ExtMaskedLoad.getValue(R: 1));
15707 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15708 }
15709 }
15710
15711 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
15712 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
15713 if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() &&
15714 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(SDValue(GN0, 0)))) {
15715 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
15716 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
15717
15718 SDValue ExtLoad = DAG.getMaskedGather(
15719 VTs: DAG.getVTList(VT1: VT, VT2: MVT::Other), MemVT: ExtVT, dl: DL, Ops, MMO: GN0->getMemOperand(),
15720 IndexType: GN0->getIndexType(), ExtTy: ISD::SEXTLOAD);
15721
15722 CombineTo(N, Res: ExtLoad);
15723 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
15724 AddToWorklist(N: ExtLoad.getNode());
15725 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15726 }
15727 }
15728
15729 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
15730 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
15731 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
15732 N1: N0.getOperand(i: 1), DemandHighBits: false))
15733 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: BSwap, N2: N1);
15734 }
15735
15736 // Fold (iM_signext_inreg
15737 // (extract_subvector (zext|anyext|sext iN_v to _) _)
15738 // from iN)
15739 // -> (extract_subvector (signext iN_v to iM))
15740 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
15741 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
15742 SDValue InnerExt = N0.getOperand(i: 0);
15743 EVT InnerExtVT = InnerExt->getValueType(ResNo: 0);
15744 SDValue Extendee = InnerExt->getOperand(Num: 0);
15745
15746 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
15747 (!LegalOperations ||
15748 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT: InnerExtVT))) {
15749 SDValue SignExtExtendee =
15750 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: InnerExtVT, Operand: Extendee);
15751 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: SignExtExtendee,
15752 N2: N0.getOperand(i: 1));
15753 }
15754 }
15755
15756 return SDValue();
15757}
15758
15759static SDValue foldExtendVectorInregToExtendOfSubvector(
15760 SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
15761 bool LegalOperations) {
15762 unsigned InregOpcode = N->getOpcode();
15763 unsigned Opcode = DAG.getOpcode_EXTEND(Opcode: InregOpcode);
15764
15765 SDValue Src = N->getOperand(Num: 0);
15766 EVT VT = N->getValueType(ResNo: 0);
15767 EVT SrcVT = EVT::getVectorVT(Context&: *DAG.getContext(),
15768 VT: Src.getValueType().getVectorElementType(),
15769 EC: VT.getVectorElementCount());
15770
15771 assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
15772 "Expected EXTEND_VECTOR_INREG dag node in input!");
15773
15774  // Profitability check: our operand must be a one-use CONCAT_VECTORS.
15775 // FIXME: one-use check may be overly restrictive
15776 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
15777 return SDValue();
15778
15779  // Profitability check: we must be extending exactly one of its operands.
15780 // FIXME: this is probably overly restrictive.
15781 Src = Src.getOperand(i: 0);
15782 if (Src.getValueType() != SrcVT)
15783 return SDValue();
15784
15785 if (LegalOperations && !TLI.isOperationLegal(Op: Opcode, VT))
15786 return SDValue();
15787
15788 return DAG.getNode(Opcode, DL, VT, Operand: Src);
15789}
15790
15791SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
15792 SDValue N0 = N->getOperand(Num: 0);
15793 EVT VT = N->getValueType(ResNo: 0);
15794 SDLoc DL(N);
15795
15796 if (N0.isUndef()) {
15797 // aext_vector_inreg(undef) = undef because the top bits are undefined.
15798 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
15799 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
15800 ? DAG.getUNDEF(VT)
15801 : DAG.getConstant(Val: 0, DL, VT);
15802 }
15803
15804 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15805 return Res;
15806
15807 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
15808 return SDValue(N, 0);
15809
15810 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
15811 LegalOperations))
15812 return R;
15813
15814 return SDValue();
15815}
15816
15817SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) {
15818 EVT VT = N->getValueType(ResNo: 0);
15819 SDValue N0 = N->getOperand(Num: 0);
15820
15821 SDValue FPVal;
15822 if (sd_match(N: N0, P: m_FPToUI(Op: m_Value(N&: FPVal))) &&
15823 DAG.getTargetLoweringInfo().shouldConvertFpToSat(
15824 Op: ISD::FP_TO_UINT_SAT, FPVT: FPVal.getValueType(), VT))
15825 return DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT, N1: FPVal,
15826 N2: DAG.getValueType(VT.getScalarType()));
15827
15828 return SDValue();
15829}
15830
15831/// Detect patterns of truncation with unsigned saturation:
15832///
15833/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
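/// For example, for an i8 destination this matches (truncate (umin x, 255)).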
15834/// Return the source value x to be truncated or SDValue() if the pattern was
15835/// not matched.
15836///
15837static SDValue detectUSatUPattern(SDValue In, EVT VT) {
15838 unsigned NumDstBits = VT.getScalarSizeInBits();
15839 unsigned NumSrcBits = In.getScalarValueSizeInBits();
15840 // Saturation with truncation. We truncate from InVT to VT.
15841 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15842
15843 SDValue Min;
15844 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
15845 if (sd_match(N: In, P: m_UMin(L: m_Value(N&: Min), R: m_SpecificInt(V: UnsignedMax))))
15846 return Min;
15847
15848 return SDValue();
15849}
15850
15851/// Detect patterns of truncation with signed saturation:
15852/// (truncate (smin (smax (x, signed_min_of_dest_type),
15853/// signed_max_of_dest_type)) to dest_type)
15854/// or:
15855/// (truncate (smax (smin (x, signed_max_of_dest_type),
15856/// signed_min_of_dest_type)) to dest_type).
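/// For example, for an i8 destination this matches
/// (truncate (smin (smax x, -128), 127)).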
15857///
15858/// Return the source value to be truncated or SDValue() if the pattern was not
15859/// matched.
15860static SDValue detectSSatSPattern(SDValue In, EVT VT) {
15861 unsigned NumDstBits = VT.getScalarSizeInBits();
15862 unsigned NumSrcBits = In.getScalarValueSizeInBits();
15863 // Saturation with truncation. We truncate from InVT to VT.
15864 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15865
15866 SDValue Val;
15867 APInt SignedMax = APInt::getSignedMaxValue(numBits: NumDstBits).sext(width: NumSrcBits);
15868 APInt SignedMin = APInt::getSignedMinValue(numBits: NumDstBits).sext(width: NumSrcBits);
15869
15870 if (sd_match(N: In, P: m_SMin(L: m_SMax(L: m_Value(N&: Val), R: m_SpecificInt(V: SignedMin)),
15871 R: m_SpecificInt(V: SignedMax))))
15872 return Val;
15873
15874 if (sd_match(N: In, P: m_SMax(L: m_SMin(L: m_Value(N&: Val), R: m_SpecificInt(V: SignedMax)),
15875 R: m_SpecificInt(V: SignedMin))))
15876 return Val;
15877
15878 return SDValue();
15879}
15880
15881/// Detect patterns of truncation with signed saturation to an unsigned range,
/// i.e. the source value is clamped between 0 and the unsigned maximum of the
/// destination type (in either clamp order, using smin/smax or umin/smax).
/// Return the source value to be truncated or SDValue() if the pattern was
/// not matched.
15882static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
15883 const SDLoc &DL) {
15884 unsigned NumDstBits = VT.getScalarSizeInBits();
15885 unsigned NumSrcBits = In.getScalarValueSizeInBits();
15886 // Saturation with truncation. We truncate from InVT to VT.
15887 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15888
15889 SDValue Val;
15890 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
15891 // Min == 0, Max is unsigned max of destination type.
15892 if (sd_match(N: In, P: m_SMax(L: m_SMin(L: m_Value(N&: Val), R: m_SpecificInt(V: UnsignedMax)),
15893 R: m_Zero())))
15894 return Val;
15895
15896 if (sd_match(N: In, P: m_SMin(L: m_SMax(L: m_Value(N&: Val), R: m_Zero()),
15897 R: m_SpecificInt(V: UnsignedMax))))
15898 return Val;
15899
15900 if (sd_match(N: In, P: m_UMin(L: m_SMax(L: m_Value(N&: Val), R: m_Zero()),
15901 R: m_SpecificInt(V: UnsignedMax))))
15902 return Val;
15903
15904 return SDValue();
15905}
15906
15907static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
15908 SDLoc &DL, const TargetLowering &TLI,
15909 SelectionDAG &DAG) {
15910 auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
15911 return (TLI.isOperationLegalOrCustom(Op: Opc, VT: SrcVT) &&
15912 TLI.isTypeDesirableForOp(Opc, VT));
15913 };
15914
15915 if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15916 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
15917 if (SDValue SSatVal = detectSSatSPattern(In: Src, VT))
15918 return DAG.getNode(Opcode: ISD::TRUNCATE_SSAT_S, DL, VT, Operand: SSatVal);
15919 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15920 if (SDValue SSatVal = detectSSatUPattern(In: Src, VT, DAG, DL))
15921 return DAG.getNode(Opcode: ISD::TRUNCATE_SSAT_U, DL, VT, Operand: SSatVal);
15922 } else if (Src.getOpcode() == ISD::UMIN) {
15923 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15924 if (SDValue SSatVal = detectSSatUPattern(In: Src, VT, DAG, DL))
15925 return DAG.getNode(Opcode: ISD::TRUNCATE_SSAT_U, DL, VT, Operand: SSatVal);
15926 if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
15927 if (SDValue USatVal = detectUSatUPattern(In: Src, VT))
15928 return DAG.getNode(Opcode: ISD::TRUNCATE_USAT_U, DL, VT, Operand: USatVal);
15929 }
15930
15931 return SDValue();
15932}
15933
15934SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
15935 SDValue N0 = N->getOperand(Num: 0);
15936 EVT VT = N->getValueType(ResNo: 0);
15937 EVT SrcVT = N0.getValueType();
15938 bool isLE = DAG.getDataLayout().isLittleEndian();
15939 SDLoc DL(N);
15940
15941 // trunc(undef) = undef
15942 if (N0.isUndef())
15943 return DAG.getUNDEF(VT);
15944
15945 // fold (truncate (truncate x)) -> (truncate x)
15946 if (N0.getOpcode() == ISD::TRUNCATE)
15947 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15948
15949 // fold saturated truncate
15950 if (SDValue SaturatedTR = foldToSaturated(N, VT, Src&: N0, SrcVT, DL, TLI, DAG))
15951 return SaturatedTR;
15952
15953 // fold (truncate c1) -> c1
15954 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::TRUNCATE, DL, VT, Ops: {N0}))
15955 return C;
15956
15957 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
15958 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
15959 N0.getOpcode() == ISD::SIGN_EXTEND ||
15960 N0.getOpcode() == ISD::ANY_EXTEND) {
15961 // if the source is smaller than the dest, we still need an extend.
15962 if (N0.getOperand(i: 0).getValueType().bitsLT(VT)) {
15963 SDNodeFlags Flags;
15964 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15965 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15966 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0), Flags);
15967 }
15968    // if the source is larger than the dest, then we just need the truncate.
15969 if (N0.getOperand(i: 0).getValueType().bitsGT(VT))
15970 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15971 // if the source and dest are the same type, we can drop both the extend
15972 // and the truncate.
15973 return N0.getOperand(i: 0);
15974 }
15975
15976 // Try to narrow a truncate-of-sext_in_reg to the destination type:
15977 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
15978 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
15979 N0.hasOneUse()) {
15980 SDValue X = N0.getOperand(i: 0);
15981 SDValue ExtVal = N0.getOperand(i: 1);
15982 EVT ExtVT = cast<VTSDNode>(Val&: ExtVal)->getVT();
15983 if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(TruncVT: VT, VT: SrcVT, ExtVT)) {
15984 SDValue TrX = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: X);
15985 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: TrX, N2: ExtVal);
15986 }
15987 }
15988
15989 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
15990 if (N->hasOneUse() && (N->user_begin()->getOpcode() == ISD::ANY_EXTEND))
15991 return SDValue();
15992
15993 // Fold extract-and-trunc into a narrow extract. For example:
15994 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
15995 // i32 y = TRUNCATE(i64 x)
15996 // -- becomes --
15997 // v16i8 b = BITCAST (v2i64 val)
15998 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
15999 //
16000  // Note: We only run this optimization after type legalization (which often
16001  // creates this pattern) and before operation legalization, after which
16002  // we need to be more careful about the vector instructions that we generate.
16003 if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
16004 N0->hasOneUse()) {
16005 EVT TrTy = N->getValueType(ResNo: 0);
16006 SDValue Src = N0;
16007
16008 // Check for cases where we shift down an upper element before truncation.
16009 int EltOffset = 0;
16010 if (Src.getOpcode() == ISD::SRL && Src.getOperand(i: 0)->hasOneUse()) {
16011 if (auto ShAmt = DAG.getValidShiftAmount(V: Src)) {
16012 if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
16013 Src = Src.getOperand(i: 0);
16014 EltOffset = *ShAmt / TrTy.getSizeInBits();
16015 }
16016 }
16017 }
16018
16019 if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
16020 EVT VecTy = Src.getOperand(i: 0).getValueType();
16021 EVT ExTy = Src.getValueType();
16022
16023 auto EltCnt = VecTy.getVectorElementCount();
16024 unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
16025 auto NewEltCnt = EltCnt * SizeRatio;
16026
16027 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: TrTy, EC: NewEltCnt);
16028 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
16029
16030 SDValue EltNo = Src->getOperand(Num: 1);
16031 if (isa<ConstantSDNode>(Val: EltNo) && isTypeLegal(VT: NVT)) {
16032 int Elt = EltNo->getAsZExtVal();
16033 int Index = isLE ? (Elt * SizeRatio + EltOffset)
16034 : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
16035 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: TrTy,
16036 N1: DAG.getBitcast(VT: NVT, V: Src.getOperand(i: 0)),
16037 N2: DAG.getVectorIdxConstant(Val: Index, DL));
16038 }
16039 }
16040 }
16041
16042 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
16043 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
16044 TLI.isTruncateFree(FromVT: SrcVT, ToVT: VT)) {
16045 if (!LegalOperations ||
16046 (TLI.isOperationLegal(Op: ISD::SELECT, VT: SrcVT) &&
16047 TLI.isNarrowingProfitable(N: N0.getNode(), SrcVT, DestVT: VT))) {
16048 SDLoc SL(N0);
16049 SDValue Cond = N0.getOperand(i: 0);
16050 SDValue TruncOp0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 1));
16051 SDValue TruncOp1 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 2));
16052 return DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond, N2: TruncOp0, N3: TruncOp1);
16053 }
16054 }
16055
16056 // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
16057 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
16058 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
16059 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
16060 SDValue Amt = N0.getOperand(i: 1);
16061 KnownBits Known = DAG.computeKnownBits(Op: Amt);
16062 unsigned Size = VT.getScalarSizeInBits();
16063 if (Known.countMaxActiveBits() <= Log2_32(Value: Size)) {
16064 EVT AmtVT = TLI.getShiftAmountTy(LHSTy: VT, DL: DAG.getDataLayout());
16065 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16066 if (AmtVT != Amt.getValueType()) {
16067 Amt = DAG.getZExtOrTrunc(Op: Amt, DL, VT: AmtVT);
16068 AddToWorklist(N: Amt.getNode());
16069 }
16070 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Trunc, N2: Amt);
16071 }
16072 }
16073
16074 if (SDValue V = foldSubToUSubSat(DstVT: VT, N: N0.getNode(), DL))
16075 return V;
16076
16077 if (SDValue ABD = foldABSToABD(N, DL))
16078 return ABD;
16079
16080 // Attempt to pre-truncate BUILD_VECTOR sources.
16081 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
16082 N0.hasOneUse() &&
16083 TLI.isTruncateFree(FromVT: SrcVT.getScalarType(), ToVT: VT.getScalarType()) &&
16084 // Avoid creating illegal types if running after type legalizer.
16085 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType()))) {
16086 EVT SVT = VT.getScalarType();
16087 SmallVector<SDValue, 8> TruncOps;
16088 for (const SDValue &Op : N0->op_values()) {
16089 SDValue TruncOp = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: Op);
16090 TruncOps.push_back(Elt: TruncOp);
16091 }
16092 return DAG.getBuildVector(VT, DL, Ops: TruncOps);
16093 }
16094
16095 // trunc (splat_vector x) -> splat_vector (trunc x)
16096 if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
16097 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType())) &&
16098 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT))) {
16099 EVT SVT = VT.getScalarType();
16100 return DAG.getSplatVector(
16101 VT, DL, Op: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: N0->getOperand(Num: 0)));
16102 }
16103
16104 // Fold a series of buildvector, bitcast, and truncate if possible.
16105 // For example fold
16106 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
16107 // (2xi32 (buildvector x, y)).
16108 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
16109 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
16110 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR &&
16111 N0.getOperand(i: 0).hasOneUse()) {
16112 SDValue BuildVect = N0.getOperand(i: 0);
16113 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
16114 EVT TruncVecEltTy = VT.getVectorElementType();
16115
16116 // Check that the element types match.
16117 if (BuildVectEltTy == TruncVecEltTy) {
16118 // Now we only need to compute the offset of the truncated elements.
16119 unsigned BuildVecNumElts = BuildVect.getNumOperands();
16120 unsigned TruncVecNumElts = VT.getVectorNumElements();
16121 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
16122 unsigned FirstElt = isLE ? 0 : (TruncEltOffset - 1);
16123
16124 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
16125 "Invalid number of elements");
16126
16127 SmallVector<SDValue, 8> Opnds;
16128 for (unsigned i = FirstElt, e = BuildVecNumElts; i < e;
16129 i += TruncEltOffset)
16130 Opnds.push_back(Elt: BuildVect.getOperand(i));
16131
16132 return DAG.getBuildVector(VT, DL, Ops: Opnds);
16133 }
16134 }
16135
16136 // fold (truncate (load x)) -> (smaller load x)
16137 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
16138 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
16139 if (SDValue Reduced = reduceLoadWidth(N))
16140 return Reduced;
16141
16142 // Handle the case where the truncated result is at least as wide as the
16143 // loaded type.
16144 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N: N0.getNode())) {
16145 auto *LN0 = cast<LoadSDNode>(Val&: N0);
16146 if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
16147 SDValue NewLoad = DAG.getExtLoad(
16148 ExtType: LN0->getExtensionType(), dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
16149 Ptr: LN0->getBasePtr(), MemVT: LN0->getMemoryVT(), MMO: LN0->getMemOperand());
16150 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLoad.getValue(R: 1));
16151 return NewLoad;
16152 }
16153 }
16154 }
16155
16156  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
16157  // where ... are all 'undef'.
16158 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
16159 SmallVector<EVT, 8> VTs;
16160 SDValue V;
16161 unsigned Idx = 0;
16162 unsigned NumDefs = 0;
16163
16164 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
16165 SDValue X = N0.getOperand(i);
16166 if (!X.isUndef()) {
16167 V = X;
16168 Idx = i;
16169 NumDefs++;
16170 }
16171      // Stop if more than one member is non-undef.
16172 if (NumDefs > 1)
16173 break;
16174
16175 VTs.push_back(Elt: EVT::getVectorVT(Context&: *DAG.getContext(),
16176 VT: VT.getVectorElementType(),
16177 EC: X.getValueType().getVectorElementCount()));
16178 }
16179
16180 if (NumDefs == 0)
16181 return DAG.getUNDEF(VT);
16182
16183 if (NumDefs == 1) {
16184 assert(V.getNode() && "The single defined operand is empty!");
16185 SmallVector<SDValue, 8> Opnds;
16186 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
16187 if (i != Idx) {
16188 Opnds.push_back(Elt: DAG.getUNDEF(VT: VTs[i]));
16189 continue;
16190 }
16191 SDValue NV = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(V), VT: VTs[i], Operand: V);
16192 AddToWorklist(N: NV.getNode());
16193 Opnds.push_back(Elt: NV);
16194 }
16195 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: Opnds);
16196 }
16197 }
16198
16199 // Fold truncate of a bitcast of a vector to an extract of the low vector
16200 // element.
16201 //
16202 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
16203 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
16204 SDValue VecSrc = N0.getOperand(i: 0);
16205 EVT VecSrcVT = VecSrc.getValueType();
16206 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
16207 (!LegalOperations ||
16208 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecSrcVT))) {
16209 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
16210 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: VecSrc,
16211 N2: DAG.getVectorIdxConstant(Val: Idx, DL));
16212 }
16213 }
16214
16215 // Simplify the operands using demanded-bits information.
16216 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
16217 return SDValue(N, 0);
16218
16219 // fold (truncate (extract_subvector(ext x))) ->
16220 // (extract_subvector x)
16221 // TODO: This can be generalized to cover cases where the truncate and extract
16222 // do not fully cancel each other out.
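  // For illustration only (hypothetical types): if x : v8i16, then
  //   (v4i16 (trunc (v4i32 (extract_subvector (v8i32 (zext x)), 0))))
  // is simply (v4i16 (extract_subvector x, 0)), because the zext and trunc
  // cancel out element-wise.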
16223 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
16224 SDValue N00 = N0.getOperand(i: 0);
16225 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
16226 N00.getOpcode() == ISD::ZERO_EXTEND ||
16227 N00.getOpcode() == ISD::ANY_EXTEND) {
16228 if (N00.getOperand(i: 0)->getValueType(ResNo: 0).getVectorElementType() ==
16229 VT.getVectorElementType())
16230 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N0->getOperand(Num: 0)), VT,
16231 N1: N00.getOperand(i: 0), N2: N0.getOperand(i: 1));
16232 }
16233 }
16234
16235 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
16236 return NewVSel;
16237
16238 // Narrow a suitable binary operation with a non-opaque constant operand by
16239 // moving it ahead of the truncate. This is limited to pre-legalization
16240 // because targets may prefer a wider type during later combines and invert
16241 // this transform.
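  // For illustration only (hypothetical types and constants):
  //   (i8 (trunc (add i32:X, 300))) -> (add (i8 (trunc X)), 44)
  // since only the low 8 bits of the sum are observable and 300 truncates to
  // 44 in i8.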
16242 switch (N0.getOpcode()) {
16243 case ISD::ADD:
16244 case ISD::SUB:
16245 case ISD::MUL:
16246 case ISD::AND:
16247 case ISD::OR:
16248 case ISD::XOR:
16249 if (!LegalOperations && N0.hasOneUse() &&
16250 (isConstantOrConstantVector(N: N0.getOperand(i: 0), NoOpaques: true) ||
16251 isConstantOrConstantVector(N: N0.getOperand(i: 1), NoOpaques: true))) {
16252 // TODO: We already restricted this to pre-legalization, but for vectors
16253 // we are extra cautious to not create an unsupported operation.
16254 // Target-specific changes are likely needed to avoid regressions here.
16255 if (VT.isScalarInteger() || TLI.isOperationLegal(Op: N0.getOpcode(), VT)) {
16256 SDValue NarrowL = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16257 SDValue NarrowR = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
16258 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NarrowL, N2: NarrowR);
16259 }
16260 }
16261 break;
16262 case ISD::ADDE:
16263 case ISD::UADDO_CARRY:
    // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
    // (trunc uaddo_carry(X, Y, Carry)) ->
    //   (uaddo_carry trunc(X), trunc(Y), Carry)
    // These folds apply only when the adde/uaddo_carry's carry output is
    // unused. For uaddo_carry we only do this before operation legalization.
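    // For illustration only (hypothetical types): if only value 0 of the
    // uaddo_carry is used,
    //   (i8 (trunc (i32 (uaddo_carry X, Y, C))))
    // can become (i8 (uaddo_carry (trunc X), (trunc Y), C)), since the low 8
    // bits of the sum depend only on the low 8 bits of X and Y plus C.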
16269 if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
16270 TLI.isOperationLegal(Op: N0.getOpcode(), VT)) &&
16271 N0.hasOneUse() && !N0->hasAnyUseOfValue(Value: 1)) {
16272 SDValue X = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16273 SDValue Y = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
16274 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: N0->getValueType(ResNo: 1));
16275 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: VTs, N1: X, N2: Y, N3: N0.getOperand(i: 2));
16276 }
16277 break;
16278 case ISD::USUBSAT:
    // Truncate the USUBSAT only if its LHS is a known zero-extension; it's
    // not enough to know that the upper bits are zero, we must also ensure
    // that we don't introduce an extra truncate.
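    // For illustration only (hypothetical types): with A : i8,
    //   (i8 (trunc (usubsat (zext A to i32), B)))
    // can be rewritten as an i8 usubsat of A and a suitably clamped/truncated
    // B, because the zero-extended LHS never needs the upper bits.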
16282 if (!LegalOperations && N0.hasOneUse() &&
16283 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
16284 N0.getOperand(i: 0).getOperand(i: 0).getScalarValueSizeInBits() <=
16285 VT.getScalarSizeInBits() &&
16286 hasOperation(Opcode: N0.getOpcode(), VT)) {
16287 return getTruncatedUSUBSAT(DstVT: VT, SrcVT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
16288 DAG, DL);
16289 }
16290 break;
16291 }
16292
16293 return SDValue();
16294}
16295
16296static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
16297 SDValue Elt = N->getOperand(Num: i);
16298 if (Elt.getOpcode() != ISD::MERGE_VALUES)
16299 return Elt.getNode();
16300 return Elt.getOperand(i: Elt.getResNo()).getNode();
16301}
16302
16303/// build_pair (load, load) -> load
16304/// if load locations are consecutive.
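/// For illustration only (hypothetical layout, little-endian): if L0 is an
/// i32 load from p and L1 is an i32 load from p+4, then
///   (i64 (build_pair L0, L1))
/// can become a single i64 load from p, provided both loads are simple,
/// single-use, non-extending and the wide access is fast on the target.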
16305SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
16306 assert(N->getOpcode() == ISD::BUILD_PAIR);
16307
16308 auto *LD1 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 0));
16309 auto *LD2 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 1));
16310
  // A BUILD_PAIR always has the least significant part in elt 0 and the
  // most significant part in elt 1, so when combining into one large load we
  // need to consider the endianness.
16314 if (DAG.getDataLayout().isBigEndian())
16315 std::swap(a&: LD1, b&: LD2);
16316
16317 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(N: LD1) || !ISD::isNON_EXTLoad(N: LD2) ||
16318 !LD1->hasOneUse() || !LD2->hasOneUse() ||
16319 LD1->getAddressSpace() != LD2->getAddressSpace())
16320 return SDValue();
16321
16322 unsigned LD1Fast = 0;
16323 EVT LD1VT = LD1->getValueType(ResNo: 0);
16324 unsigned LD1Bytes = LD1VT.getStoreSize();
16325 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::LOAD, VT)) &&
16326 DAG.areNonVolatileConsecutiveLoads(LD: LD2, Base: LD1, Bytes: LD1Bytes, Dist: 1) &&
16327 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
16328 MMO: *LD1->getMemOperand(), Fast: &LD1Fast) && LD1Fast)
16329 return DAG.getLoad(VT, dl: SDLoc(N), Chain: LD1->getChain(), Ptr: LD1->getBasePtr(),
16330 PtrInfo: LD1->getPointerInfo(), Alignment: LD1->getAlign());
16331
16332 return SDValue();
16333}
16334
16335static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
  // On little-endian machines, bitcasting from ppcf128 to i128 swaps the Hi
  // and Lo parts; on big-endian machines it doesn't.
16338 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
16339}
16340
16341SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
16342 const TargetLowering &TLI) {
16343 // If this is not a bitcast to an FP type or if the target doesn't have
16344 // IEEE754-compliant FP logic, we're done.
16345 EVT VT = N->getValueType(ResNo: 0);
16346 SDValue N0 = N->getOperand(Num: 0);
16347 EVT SourceVT = N0.getValueType();
16348
16349 if (!VT.isFloatingPoint())
16350 return SDValue();
16351
16352 // TODO: Handle cases where the integer constant is a different scalar
16353 // bitwidth to the FP.
16354 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
16355 return SDValue();
16356
16357 unsigned FPOpcode;
16358 APInt SignMask;
16359 switch (N0.getOpcode()) {
16360 case ISD::AND:
16361 FPOpcode = ISD::FABS;
16362 SignMask = ~APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
16363 break;
16364 case ISD::XOR:
16365 FPOpcode = ISD::FNEG;
16366 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
16367 break;
16368 case ISD::OR:
16369 FPOpcode = ISD::FABS;
16370 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
16371 break;
16372 default:
16373 return SDValue();
16374 }
16375
16376 if (LegalOperations && !TLI.isOperationLegal(Op: FPOpcode, VT))
16377 return SDValue();
16378
16379 // This needs to be the inverse of logic in foldSignChangeInBitcast.
16380 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
16381 // removing this would require more changes.
16382 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
16383 if (sd_match(N: Op, P: m_BitCast(Op: m_SpecificVT(RefVT: VT))))
16384 return true;
16385
16386 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
16387 };
16388
16389 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
16390 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
16391 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
16392 // fneg (fabs X)
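  // For illustration only (hypothetical f32/i32 case): with SourceVT == i32
  // the constant is 0x7fffffff for the AND (fabs) case, and 0x80000000 for
  // the XOR (fneg) and OR (fneg (fabs)) cases.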
16393 SDValue LogicOp0 = N0.getOperand(i: 0);
16394 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N: N0.getOperand(i: 1), AllowUndefs: true);
16395 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
16396 IsBitCastOrFree(LogicOp0, VT)) {
16397 SDValue CastOp0 = DAG.getNode(Opcode: ISD::BITCAST, DL: SDLoc(N), VT, Operand: LogicOp0);
16398 SDValue FPOp = DAG.getNode(Opcode: FPOpcode, DL: SDLoc(N), VT, Operand: CastOp0);
16399 NumFPLogicOpsConv++;
16400 if (N0.getOpcode() == ISD::OR)
16401 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: FPOp);
16402 return FPOp;
16403 }
16404
16405 return SDValue();
16406}
16407
16408SDValue DAGCombiner::visitBITCAST(SDNode *N) {
16409 SDValue N0 = N->getOperand(Num: 0);
16410 EVT VT = N->getValueType(ResNo: 0);
16411
16412 if (N0.isUndef())
16413 return DAG.getUNDEF(VT);
16414
16415 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
16416 // Only do this before legalize types, unless both types are integer and the
16417 // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
16419 // First check to see if this is all constant.
16420 // TODO: Support FP bitcasts after legalize types.
16421 if (VT.isVector() &&
16422 (!LegalTypes ||
16423 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
16424 TLI.isTypeLegal(VT: VT.getVectorElementType()))) &&
16425 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
16426 cast<BuildVectorSDNode>(Val&: N0)->isConstant())
16427 return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
16428 VT.getVectorElementType());
16429
16430 // If the input is a constant, let getNode fold it.
16431 if (isIntOrFPConstant(V: N0)) {
    // If we can't allow illegal operations, we need to check that this is
    // just an fp -> int or int -> fp conversion and that the resulting
    // operation will be legal.
16435 if (!LegalOperations ||
16436 (isa<ConstantSDNode>(Val: N0) && VT.isFloatingPoint() && !VT.isVector() &&
16437 TLI.isOperationLegal(Op: ISD::ConstantFP, VT)) ||
16438 (isa<ConstantFPSDNode>(Val: N0) && VT.isInteger() && !VT.isVector() &&
16439 TLI.isOperationLegal(Op: ISD::Constant, VT))) {
16440 SDValue C = DAG.getBitcast(VT, V: N0);
16441 if (C.getNode() != N)
16442 return C;
16443 }
16444 }
16445
16446 // (conv (conv x, t1), t2) -> (conv x, t2)
16447 if (N0.getOpcode() == ISD::BITCAST)
16448 return DAG.getBitcast(VT, V: N0.getOperand(i: 0));
16449
16450 // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
16451 // iff the current bitwise logicop type isn't legal
16452 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && VT.isInteger() &&
16453 !TLI.isTypeLegal(VT: N0.getOperand(i: 0).getValueType())) {
16454 auto IsFreeBitcast = [VT](SDValue V) {
16455 return (V.getOpcode() == ISD::BITCAST &&
16456 V.getOperand(i: 0).getValueType() == VT) ||
16457 (ISD::isBuildVectorOfConstantSDNodes(N: V.getNode()) &&
16458 V->hasOneUse());
16459 };
16460 if (IsFreeBitcast(N0.getOperand(i: 0)) && IsFreeBitcast(N0.getOperand(i: 1)))
16461 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT,
16462 N1: DAG.getBitcast(VT, V: N0.getOperand(i: 0)),
16463 N2: DAG.getBitcast(VT, V: N0.getOperand(i: 1)));
16464 }
16465
16466 // fold (conv (load x)) -> (load (conv*)x)
  // This is only done if the resultant load doesn't need a higher alignment
  // than the original.
16468 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
16469 // Do not remove the cast if the types differ in endian layout.
16470 TLI.hasBigEndianPartOrdering(VT: N0.getValueType(), DL: DAG.getDataLayout()) ==
16471 TLI.hasBigEndianPartOrdering(VT, DL: DAG.getDataLayout()) &&
16472 // If the load is volatile, we only want to change the load type if the
16473 // resulting load is legal. Otherwise we might increase the number of
16474 // memory accesses. We don't care if the original type was legal or not
16475 // as we assume software couldn't rely on the number of accesses of an
16476 // illegal type.
16477 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) ||
16478 TLI.isOperationLegal(Op: ISD::LOAD, VT))) {
16479 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
16480
16481 if (TLI.isLoadBitCastBeneficial(LoadVT: N0.getValueType(), BitcastVT: VT, DAG,
16482 MMO: *LN0->getMemOperand())) {
16483 // If the range metadata type does not match the new memory
16484 // operation type, remove the range metadata.
16485 if (const MDNode *MD = LN0->getRanges()) {
16486 ConstantInt *Lower = mdconst::extract<ConstantInt>(MD: MD->getOperand(I: 0));
16487 if (Lower->getBitWidth() != VT.getScalarSizeInBits() ||
16488 !VT.isInteger()) {
16489 LN0->getMemOperand()->clearRanges();
16490 }
16491 }
16492 SDValue Load =
16493 DAG.getLoad(VT, dl: SDLoc(N), Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
16494 MMO: LN0->getMemOperand());
16495 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
16496 return Load;
16497 }
16498 }
16499
16500 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
16501 return V;
16502
16503 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
16504 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
16505 //
16506 // For ppc_fp128:
16507 // fold (bitcast (fneg x)) ->
16508 // flipbit = signbit
16509 // (xor (bitcast x) (build_pair flipbit, flipbit))
16510 //
16511 // fold (bitcast (fabs x)) ->
16512 // flipbit = (and (extract_element (bitcast x), 0), signbit)
16513 // (xor (bitcast x) (build_pair flipbit, flipbit))
16514 // This often reduces constant pool loads.
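  // For illustration only (hypothetical f32 value X):
  //   (i32 (bitcast (fneg X))) -> (xor (i32 (bitcast X)), 0x80000000)
  //   (i32 (bitcast (fabs X))) -> (and (i32 (bitcast X)), 0x7fffffff)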
16515 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(VT: N0.getValueType())) ||
16516 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(VT: N0.getValueType()))) &&
16517 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
16518 !N0.getValueType().isVector()) {
16519 SDValue NewConv = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
16520 AddToWorklist(N: NewConv.getNode());
16521
16522 SDLoc DL(N);
16523 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
16524 assert(VT.getSizeInBits() == 128);
16525 SDValue SignBit = DAG.getConstant(
16526 Val: APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2), DL: SDLoc(N0), VT: MVT::i64);
16527 SDValue FlipBit;
16528 if (N0.getOpcode() == ISD::FNEG) {
16529 FlipBit = SignBit;
16530 AddToWorklist(N: FlipBit.getNode());
16531 } else {
16532 assert(N0.getOpcode() == ISD::FABS);
16533 SDValue Hi =
16534 DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: SDLoc(NewConv), VT: MVT::i64, N1: NewConv,
16535 N2: DAG.getIntPtrConstant(Val: getPPCf128HiElementSelector(DAG),
16536 DL: SDLoc(NewConv)));
16537 AddToWorklist(N: Hi.getNode());
16538 FlipBit = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: MVT::i64, N1: Hi, N2: SignBit);
16539 AddToWorklist(N: FlipBit.getNode());
16540 }
16541 SDValue FlipBits =
16542 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
16543 AddToWorklist(N: FlipBits.getNode());
16544 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: NewConv, N2: FlipBits);
16545 }
16546 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
16547 if (N0.getOpcode() == ISD::FNEG)
16548 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
16549 N1: NewConv, N2: DAG.getConstant(Val: SignBit, DL, VT));
16550 assert(N0.getOpcode() == ISD::FABS);
16551 return DAG.getNode(Opcode: ISD::AND, DL, VT,
16552 N1: NewConv, N2: DAG.getConstant(Val: ~SignBit, DL, VT));
16553 }
16554
16555 // fold (bitconvert (fcopysign cst, x)) ->
16556 // (or (and (bitconvert x), sign), (and cst, (not sign)))
16557 // Note that we don't handle (copysign x, cst) because this can always be
16558 // folded to an fneg or fabs.
16559 //
16560 // For ppc_fp128:
16561 // fold (bitcast (fcopysign cst, x)) ->
16562 // flipbit = (and (extract_element
16563 // (xor (bitcast cst), (bitcast x)), 0),
16564 // signbit)
16565 // (xor (bitcast cst) (build_pair flipbit, flipbit))
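  // For illustration only (hypothetical f32 values): for
  // (i32 (bitcast (fcopysign 2.0, X))) the sign mask is 0x80000000, so the
  // result is (or (and (i32 (bitcast X)), 0x80000000),
  //               (and 0x40000000, 0x7fffffff)),
  // i.e. the sign bit of X merged with the magnitude bits of 2.0.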
16566 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
16567 isa<ConstantFPSDNode>(Val: N0.getOperand(i: 0)) && VT.isInteger() &&
16568 !VT.isVector()) {
16569 unsigned OrigXWidth = N0.getOperand(i: 1).getValueSizeInBits();
16570 EVT IntXVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OrigXWidth);
16571 if (isTypeLegal(VT: IntXVT)) {
16572 SDValue X = DAG.getBitcast(VT: IntXVT, V: N0.getOperand(i: 1));
16573 AddToWorklist(N: X.getNode());
16574
16575 // If X has a different width than the result/lhs, sext it or truncate it.
16576 unsigned VTWidth = VT.getSizeInBits();
16577 if (OrigXWidth < VTWidth) {
16578 X = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: X);
16579 AddToWorklist(N: X.getNode());
16580 } else if (OrigXWidth > VTWidth) {
16581 // To get the sign bit in the right place, we have to shift it right
16582 // before truncating.
16583 SDLoc DL(X);
16584 X = DAG.getNode(Opcode: ISD::SRL, DL,
16585 VT: X.getValueType(), N1: X,
16586 N2: DAG.getConstant(Val: OrigXWidth-VTWidth, DL,
16587 VT: X.getValueType()));
16588 AddToWorklist(N: X.getNode());
16589 X = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(X), VT, Operand: X);
16590 AddToWorklist(N: X.getNode());
16591 }
16592
16593 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
16594 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2);
16595 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
16596 AddToWorklist(N: Cst.getNode());
16597 SDValue X = DAG.getBitcast(VT, V: N0.getOperand(i: 1));
16598 AddToWorklist(N: X.getNode());
16599 SDValue XorResult = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT, N1: Cst, N2: X);
16600 AddToWorklist(N: XorResult.getNode());
16601 SDValue XorResult64 = DAG.getNode(
16602 Opcode: ISD::EXTRACT_ELEMENT, DL: SDLoc(XorResult), VT: MVT::i64, N1: XorResult,
16603 N2: DAG.getIntPtrConstant(Val: getPPCf128HiElementSelector(DAG),
16604 DL: SDLoc(XorResult)));
16605 AddToWorklist(N: XorResult64.getNode());
16606 SDValue FlipBit =
16607 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(XorResult64), VT: MVT::i64, N1: XorResult64,
16608 N2: DAG.getConstant(Val: SignBit, DL: SDLoc(XorResult64), VT: MVT::i64));
16609 AddToWorklist(N: FlipBit.getNode());
16610 SDValue FlipBits =
16611 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
16612 AddToWorklist(N: FlipBits.getNode());
16613 return DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N), VT, N1: Cst, N2: FlipBits);
16614 }
16615 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
16616 X = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(X), VT,
16617 N1: X, N2: DAG.getConstant(Val: SignBit, DL: SDLoc(X), VT));
16618 AddToWorklist(N: X.getNode());
16619
16620 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
16621 Cst = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Cst), VT,
16622 N1: Cst, N2: DAG.getConstant(Val: ~SignBit, DL: SDLoc(Cst), VT));
16623 AddToWorklist(N: Cst.getNode());
16624
16625 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: Cst);
16626 }
16627 }
16628
16629 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
16630 if (N0.getOpcode() == ISD::BUILD_PAIR)
16631 if (SDValue CombineLD = CombineConsecutiveLoads(N: N0.getNode(), VT))
16632 return CombineLD;
16633
16634 // int_vt (bitcast (vec_vt (scalar_to_vector elt_vt:x)))
16635 // => int_vt (any_extend elt_vt:x)
16636 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && VT.isScalarInteger()) {
16637 SDValue SrcScalar = N0.getOperand(i: 0);
16638 if (SrcScalar.getValueType().isScalarInteger())
16639 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N), VT, Operand: SrcScalar);
16640 }
16641
16642 // Remove double bitcasts from shuffles - this is often a legacy of
16643 // XformToShuffleWithZero being used to combine bitmaskings (of
16644 // float vectors bitcast to integer vectors) into shuffles.
16645 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
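  // For illustration only (hypothetical types): bitcasting a v2i64 shuffle
  // with mask <1,0> to v4i32 scales the mask by 2, producing a v4i32 shuffle
  // of the original (un-bitcast) sources with mask <2,3,0,1>.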
16646 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
16647 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
16648 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
16649 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
16650 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val&: N0);
16651
16652 // If operands are a bitcast, peek through if it casts the original VT.
16653 // If operands are a constant, just bitcast back to original VT.
16654 auto PeekThroughBitcast = [&](SDValue Op) {
16655 if (Op.getOpcode() == ISD::BITCAST &&
16656 Op.getOperand(i: 0).getValueType() == VT)
16657 return SDValue(Op.getOperand(i: 0));
16658 if (Op.isUndef() || isAnyConstantBuildVector(V: Op))
16659 return DAG.getBitcast(VT, V: Op);
16660 return SDValue();
16661 };
16662
16663 // FIXME: If either input vector is bitcast, try to convert the shuffle to
16664 // the result type of this bitcast. This would eliminate at least one
16665 // bitcast. See the transform in InstCombine.
16666 SDValue SV0 = PeekThroughBitcast(N0->getOperand(Num: 0));
16667 SDValue SV1 = PeekThroughBitcast(N0->getOperand(Num: 1));
16668 if (!(SV0 && SV1))
16669 return SDValue();
16670
16671 int MaskScale =
16672 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
16673 SmallVector<int, 8> NewMask;
16674 for (int M : SVN->getMask())
16675 for (int i = 0; i != MaskScale; ++i)
16676 NewMask.push_back(Elt: M < 0 ? -1 : M * MaskScale + i);
16677
16678 SDValue LegalShuffle =
16679 TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: SV0, N1: SV1, Mask: NewMask, DAG);
16680 if (LegalShuffle)
16681 return LegalShuffle;
16682 }
16683
16684 return SDValue();
16685}
16686
16687SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
16688 EVT VT = N->getValueType(ResNo: 0);
16689 return CombineConsecutiveLoads(N, VT);
16690}
16691
16692SDValue DAGCombiner::visitFREEZE(SDNode *N) {
16693 SDValue N0 = N->getOperand(Num: 0);
16694
16695 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op: N0, /*PoisonOnly*/ false))
16696 return N0;
16697
16698 // We currently avoid folding freeze over SRA/SRL, due to the problems seen
16699 // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
16700 // example https://reviews.llvm.org/D136529#4120959.
16701 if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
16702 return SDValue();
16703
16704 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
16705 // Try to push freeze through instructions that propagate but don't produce
  // poison as far as possible. If an operand of freeze satisfies two
  // conditions, 1) it has one use and 2) it does not produce poison, then
  // push the freeze through to the operands that are not guaranteed
  // non-poison.
16709 // NOTE: we will strip poison-generating flags, so ignore them here.
16710 if (DAG.canCreateUndefOrPoison(Op: N0, /*PoisonOnly*/ false,
16711 /*ConsiderFlags*/ false) ||
16712 N0->getNumValues() != 1 || !N0->hasOneUse())
16713 return SDValue();
16714
16715 // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
16716 // ones" or "constant" into something that depends on FrozenUndef. We can
16717 // instead pick undef values to keep those properties, while at the same time
16718 // folding away the freeze.
16719 // If we implement a more general solution for folding away freeze(undef) in
16720 // the future, then this special handling can be removed.
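  // For illustration only (hypothetical i32 vector): freeze of
  // (build_vector 1, undef, 3, 4) is folded here to (build_vector 1, 0, 3, 4),
  // keeping the "constant" property instead of freezing the undef lane.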
16721 if (N0.getOpcode() == ISD::BUILD_VECTOR) {
16722 SDLoc DL(N0);
16723 EVT VT = N0.getValueType();
16724 if (llvm::ISD::isBuildVectorAllOnes(N: N0.getNode()) && VT.isInteger())
16725 return DAG.getAllOnesConstant(DL, VT);
16726 if (llvm::ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
16727 SmallVector<SDValue, 8> NewVecC;
16728 for (const SDValue &Op : N0->op_values())
16729 NewVecC.push_back(
16730 Elt: Op.isUndef() ? DAG.getConstant(Val: 0, DL, VT: Op.getValueType()) : Op);
16731 return DAG.getBuildVector(VT, DL, Ops: NewVecC);
16732 }
16733 }
16734
16735 SmallSet<SDValue, 8> MaybePoisonOperands;
16736 SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
16737 for (auto [OpNo, Op] : enumerate(First: N0->ops())) {
16738 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
16739 /*Depth*/ 1))
16740 continue;
16741 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
16742 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(V: Op).second;
16743 if (IsNewMaybePoisonOperand)
16744 MaybePoisonOperandNumbers.push_back(Elt: OpNo);
16745 if (!HadMaybePoisonOperands)
16746 continue;
16747 }
  // NOTE: the whole op may still not be guaranteed to be non-undef/non-poison,
  // because it could create undef or poison due to its poison-generating
  // flags. So not finding any maybe-poison operands is fine.
16751
16752 for (unsigned OpNo : MaybePoisonOperandNumbers) {
16753 // N0 can mutate during iteration, so make sure to refetch the maybe poison
16754 // operands via the operand numbers. The typical scenario is that we have
16755 // something like this
16756 // t262: i32 = freeze t181
16757 // t150: i32 = ctlz_zero_undef t262
16758 // t184: i32 = ctlz_zero_undef t181
16759 // t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
16760 // When freezing the t181 operand we get t262 back, and then the
16761 // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
16762 // also recursively replace t184 by t150.
16763 SDValue MaybePoisonOperand = N->getOperand(Num: 0).getOperand(i: OpNo);
16764 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
16765 if (MaybePoisonOperand.isUndef())
16766 continue;
16767 // First, freeze each offending operand.
16768 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(V: MaybePoisonOperand);
16769 // Then, change all other uses of unfrozen operand to use frozen operand.
16770 DAG.ReplaceAllUsesOfValueWith(From: MaybePoisonOperand, To: FrozenMaybePoisonOperand);
16771 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
16772 FrozenMaybePoisonOperand.getOperand(i: 0) == FrozenMaybePoisonOperand) {
16773 // But, that also updated the use in the freeze we just created, thus
16774 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
16775 DAG.UpdateNodeOperands(N: FrozenMaybePoisonOperand.getNode(),
16776 Op: MaybePoisonOperand);
16777 }
16778
16779 // This node has been merged with another.
16780 if (N->getOpcode() == ISD::DELETED_NODE)
16781 return SDValue(N, 0);
16782 }
16783
16784 assert(N->getOpcode() != ISD::DELETED_NODE && "Node was deleted!");
16785
16786 // The whole node may have been updated, so the value we were holding
16787 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
16788 N0 = N->getOperand(Num: 0);
16789
  // Finally, recreate the node; its operands were updated to use frozen
  // operands, so we just need to use its "original" operands.
16792 SmallVector<SDValue> Ops(N0->ops());
16793 // TODO: ISD::UNDEF and ISD::POISON should get separate handling, but best
16794 // leave for a future patch.
16795 for (SDValue &Op : Ops) {
16796 if (Op.isUndef())
16797 Op = DAG.getFreeze(V: Op);
16798 }
16799
16800 SDLoc DL(N0);
16801
16802 // Special case handling for ShuffleVectorSDNode nodes.
16803 if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: N0))
16804 return DAG.getVectorShuffle(VT: N0.getValueType(), dl: DL, N1: Ops[0], N2: Ops[1],
16805 Mask: SVN->getMask());
16806
16807 // NOTE: this strips poison generating flags.
16808 // Folding freeze(op(x, ...)) -> op(freeze(x), ...) does not require nnan,
16809 // ninf, nsz, or fast.
16810 // However, contract, reassoc, afn, and arcp should be preserved,
16811 // as these fast-math flags do not introduce poison values.
16812 SDNodeFlags SrcFlags = N0->getFlags();
16813 SDNodeFlags SafeFlags;
16814 SafeFlags.setAllowContract(SrcFlags.hasAllowContract());
16815 SafeFlags.setAllowReassociation(SrcFlags.hasAllowReassociation());
16816 SafeFlags.setApproximateFuncs(SrcFlags.hasApproximateFuncs());
16817 SafeFlags.setAllowReciprocal(SrcFlags.hasAllowReciprocal());
16818 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: N0->getVTList(), Ops, Flags: SafeFlags);
16819}
16820
16821/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
16822/// operands. DstEltVT indicates the destination element value type.
16823SDValue DAGCombiner::
16824ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
16825 EVT SrcEltVT = BV->getValueType(ResNo: 0).getVectorElementType();
16826
16827 // If this is already the right type, we're done.
16828 if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
16829
16830 unsigned SrcBitSize = SrcEltVT.getSizeInBits();
16831 unsigned DstBitSize = DstEltVT.getSizeInBits();
16832
16833 // If this is a conversion of N elements of one type to N elements of another
16834 // type, convert each element. This handles FP<->INT cases.
16835 if (SrcBitSize == DstBitSize) {
16836 SmallVector<SDValue, 8> Ops;
16837 for (SDValue Op : BV->op_values()) {
16838 // If the vector element type is not legal, the BUILD_VECTOR operands
16839 // are promoted and implicitly truncated. Make that explicit here.
16840 if (Op.getValueType() != SrcEltVT)
16841 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(BV), VT: SrcEltVT, Operand: Op);
16842 Ops.push_back(Elt: DAG.getBitcast(VT: DstEltVT, V: Op));
16843 AddToWorklist(N: Ops.back().getNode());
16844 }
16845 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT,
16846 NumElements: BV->getValueType(ResNo: 0).getVectorNumElements());
16847 return DAG.getBuildVector(VT, DL: SDLoc(BV), Ops);
16848 }
16849
16850 // Otherwise, we're growing or shrinking the elements. To avoid having to
16851 // handle annoying details of growing/shrinking FP values, we convert them to
16852 // int first.
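  // For illustration only (hypothetical types): bitcasting a constant v4i16
  // build_vector to v2i32 packs each pair of i16 elements into one i32
  // (respecting endianness), while going from v2i32 to v4i16 splits each i32
  // into two elements.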
16853 if (SrcEltVT.isFloatingPoint()) {
    // Convert the input float vector to an integer vector where the elements
    // have the same size.
16856 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SrcEltVT.getSizeInBits());
16857 BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: IntVT).getNode();
16858 SrcEltVT = IntVT;
16859 }
16860
  // Now we know the input is an integer vector. If the output is an FP type,
  // convert to integer elements of the destination size first, then bitcast
  // each element to FP.
16863 if (DstEltVT.isFloatingPoint()) {
16864 EVT TmpVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: DstEltVT.getSizeInBits());
16865 SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: TmpVT).getNode();
16866
16867 // Next, convert to FP elements of the same size.
16868 return ConstantFoldBITCASTofBUILD_VECTOR(BV: Tmp, DstEltVT);
16869 }
16870
16871 // Okay, we know the src/dst types are both integers of differing types.
16872 assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
16873
16874 // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
16875 // BuildVectorSDNode?
16876 auto *BVN = cast<BuildVectorSDNode>(Val: BV);
16877
16878 // Extract the constant raw bit data.
16879 BitVector UndefElements;
16880 SmallVector<APInt> RawBits;
16881 bool IsLE = DAG.getDataLayout().isLittleEndian();
16882 if (!BVN->getConstantRawBits(IsLittleEndian: IsLE, DstEltSizeInBits: DstBitSize, RawBitElements&: RawBits, UndefElements))
16883 return SDValue();
16884
16885 SDLoc DL(BV);
16886 SmallVector<SDValue, 8> Ops;
16887 for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
16888 if (UndefElements[I])
16889 Ops.push_back(Elt: DAG.getUNDEF(VT: DstEltVT));
16890 else
16891 Ops.push_back(Elt: DAG.getConstant(Val: RawBits[I], DL, VT: DstEltVT));
16892 }
16893
16894 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT, NumElements: Ops.size());
16895 return DAG.getBuildVector(VT, DL, Ops);
16896}
16897
16898// Returns true if floating point contraction is allowed on the FMUL-SDValue
16899// `N`
16900static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
16901 assert(N.getOpcode() == ISD::FMUL);
16902
16903 return Options.AllowFPOpFusion == FPOpFusion::Fast ||
16904 N->getFlags().hasAllowContract();
16905}
16906
16907// Returns true if `N` can assume no infinities involved in its computation.
16908static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
16909 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
16910}
16911
16912/// Try to perform FMA combining on a given FADD node.
16913template <class MatchContextClass>
16914SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
16915 SDValue N0 = N->getOperand(Num: 0);
16916 SDValue N1 = N->getOperand(Num: 1);
16917 EVT VT = N->getValueType(ResNo: 0);
16918 SDLoc SL(N);
16919 MatchContextClass matcher(DAG, TLI, N);
16920 const TargetOptions &Options = DAG.getTarget().Options;
16921
16922 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
16923
16924 // Floating-point multiply-add with intermediate rounding.
16925 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
16926 // FIXME: Add VP_FMAD opcode.
16927 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
16928
16929 // Floating-point multiply-add without intermediate rounding.
16930 bool HasFMA =
16931 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
16932 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT);
16933
16934 // No valid opcode, do not combine.
16935 if (!HasFMAD && !HasFMA)
16936 return SDValue();
16937
16938 bool AllowFusionGlobally =
16939 Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
16940 // If the addition is not contractable, do not combine.
16941 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
16942 return SDValue();
16943
16944 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
16945 // beneficial. It does not reduce latency. It increases register pressure. It
16946 // replaces an fadd with an fma which is a more complex instruction, so is
16947 // likely to have a larger encoding, use more functional units, etc.
16948 if (N0 == N1)
16949 return SDValue();
16950
16951 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
16952 return SDValue();
16953
16954 // Always prefer FMAD to FMA for precision.
16955 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16956 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16957
16958 auto isFusedOp = [&](SDValue N) {
16959 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16960 };
16961
16962 // Is the node an FMUL and contractable either due to global flags or
16963 // SDNodeFlags.
16964 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
16965 if (!matcher.match(N, ISD::FMUL))
16966 return false;
16967 return AllowFusionGlobally || N->getFlags().hasAllowContract();
16968 };
16969 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
16970 // prefer to fold the multiply with fewer uses.
16971 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
16972 if (N0->use_size() > N1->use_size())
16973 std::swap(a&: N0, b&: N1);
16974 }
16975
16976 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
16977 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
16978 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0),
16979 N0.getOperand(i: 1), N1);
16980 }
16981
16982 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
16983 // Note: Commutes FADD operands.
16984 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
16985 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(i: 0),
16986 N1.getOperand(i: 1), N0);
16987 }
16988
16989 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
16990 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
16991 // This also works with nested fma instructions:
  // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
  //   fma A, B, (fma C, D, (fma E, F, G))
  // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
  //   fma A, B, (fma C, D, (fma E, F, G)).
16996 // This requires reassociation because it changes the order of operations.
16997 bool CanReassociate =
16998 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16999 if (CanReassociate) {
17000 SDValue FMA, E;
17001 if (isFusedOp(N0) && N0.hasOneUse()) {
17002 FMA = N0;
17003 E = N1;
17004 } else if (isFusedOp(N1) && N1.hasOneUse()) {
17005 FMA = N1;
17006 E = N0;
17007 }
17008
17009 SDValue TmpFMA = FMA;
17010 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
17011 SDValue FMul = TmpFMA->getOperand(Num: 2);
17012 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
17013 SDValue C = FMul.getOperand(i: 0);
17014 SDValue D = FMul.getOperand(i: 1);
17015 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
17016 DAG.ReplaceAllUsesOfValueWith(From: FMul, To: CDE);
17017 // Replacing the inner FMul could cause the outer FMA to be simplified
17018 // away.
17019 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
17020 }
17021
17022 TmpFMA = TmpFMA->getOperand(Num: 2);
17023 }
17024 }
17025
17026 // Look through FP_EXTEND nodes to do more combining.
17027
17028 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
17029 if (matcher.match(N0, ISD::FP_EXTEND)) {
17030 SDValue N00 = N0.getOperand(i: 0);
17031 if (isContractableFMUL(N00) &&
17032 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17033 SrcVT: N00.getValueType())) {
17034 return matcher.getNode(
17035 PreferredFusedOpcode, SL, VT,
17036 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
17037 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)), N1);
17038 }
17039 }
17040
17041 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
17042 // Note: Commutes FADD operands.
17043 if (matcher.match(N1, ISD::FP_EXTEND)) {
17044 SDValue N10 = N1.getOperand(i: 0);
17045 if (isContractableFMUL(N10) &&
17046 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17047 SrcVT: N10.getValueType())) {
17048 return matcher.getNode(
17049 PreferredFusedOpcode, SL, VT,
17050 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0)),
17051 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
17052 }
17053 }
17054
17055 // More folding opportunities when target permits.
17056 if (Aggressive) {
17057 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
17058 // -> (fma x, y, (fma (fpext u), (fpext v), z))
17059 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17060 SDValue Z) {
17061 return matcher.getNode(
17062 PreferredFusedOpcode, SL, VT, X, Y,
17063 matcher.getNode(PreferredFusedOpcode, SL, VT,
17064 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17065 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17066 };
17067 if (isFusedOp(N0)) {
17068 SDValue N02 = N0.getOperand(i: 2);
17069 if (matcher.match(N02, ISD::FP_EXTEND)) {
17070 SDValue N020 = N02.getOperand(i: 0);
17071 if (isContractableFMUL(N020) &&
17072 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17073 SrcVT: N020.getValueType())) {
17074 return FoldFAddFMAFPExtFMul(N0.getOperand(i: 0), N0.getOperand(i: 1),
17075 N020.getOperand(i: 0), N020.getOperand(i: 1),
17076 N1);
17077 }
17078 }
17079 }
17080
17081 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
17082 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
17083 // FIXME: This turns two single-precision and one double-precision
17084 // operation into two double-precision operations, which might not be
17085 // interesting for all targets, especially GPUs.
17086 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17087 SDValue Z) {
17088 return matcher.getNode(
17089 PreferredFusedOpcode, SL, VT,
17090 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
17091 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
17092 matcher.getNode(PreferredFusedOpcode, SL, VT,
17093 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17094 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17095 };
17096 if (N0.getOpcode() == ISD::FP_EXTEND) {
17097 SDValue N00 = N0.getOperand(i: 0);
17098 if (isFusedOp(N00)) {
17099 SDValue N002 = N00.getOperand(i: 2);
17100 if (isContractableFMUL(N002) &&
17101 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17102 SrcVT: N00.getValueType())) {
17103 return FoldFAddFPExtFMAFMul(N00.getOperand(i: 0), N00.getOperand(i: 1),
17104 N002.getOperand(i: 0), N002.getOperand(i: 1),
17105 N1);
17106 }
17107 }
17108 }
17109
    // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
17111 // -> (fma y, z, (fma (fpext u), (fpext v), x))
17112 if (isFusedOp(N1)) {
17113 SDValue N12 = N1.getOperand(i: 2);
17114 if (N12.getOpcode() == ISD::FP_EXTEND) {
17115 SDValue N120 = N12.getOperand(i: 0);
17116 if (isContractableFMUL(N120) &&
17117 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17118 SrcVT: N120.getValueType())) {
17119 return FoldFAddFMAFPExtFMul(N1.getOperand(i: 0), N1.getOperand(i: 1),
17120 N120.getOperand(i: 0), N120.getOperand(i: 1),
17121 N0);
17122 }
17123 }
17124 }
17125
    // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
17127 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
17128 // FIXME: This turns two single-precision and one double-precision
17129 // operation into two double-precision operations, which might not be
17130 // interesting for all targets, especially GPUs.
17131 if (N1.getOpcode() == ISD::FP_EXTEND) {
17132 SDValue N10 = N1.getOperand(i: 0);
17133 if (isFusedOp(N10)) {
17134 SDValue N102 = N10.getOperand(i: 2);
17135 if (isContractableFMUL(N102) &&
17136 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17137 SrcVT: N10.getValueType())) {
17138 return FoldFAddFPExtFMAFMul(N10.getOperand(i: 0), N10.getOperand(i: 1),
17139 N102.getOperand(i: 0), N102.getOperand(i: 1),
17140 N0);
17141 }
17142 }
17143 }
17144 }
17145
17146 return SDValue();
17147}
17148
17149/// Try to perform FMA combining on a given FSUB node.
17150template <class MatchContextClass>
17151SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
17152 SDValue N0 = N->getOperand(Num: 0);
17153 SDValue N1 = N->getOperand(Num: 1);
17154 EVT VT = N->getValueType(ResNo: 0);
17155 SDLoc SL(N);
17156 MatchContextClass matcher(DAG, TLI, N);
17157 const TargetOptions &Options = DAG.getTarget().Options;
17158
17159 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17160
17161 // Floating-point multiply-add with intermediate rounding.
17162 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17163 // FIXME: Add VP_FMAD opcode.
17164 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17165
17166 // Floating-point multiply-add without intermediate rounding.
17167 bool HasFMA =
17168 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17169 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT);
17170
17171 // No valid opcode, do not combine.
17172 if (!HasFMAD && !HasFMA)
17173 return SDValue();
17174
17175 const SDNodeFlags Flags = N->getFlags();
17176 bool AllowFusionGlobally =
17177 (Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD);
17178
17179 // If the subtraction is not contractable, do not combine.
17180 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17181 return SDValue();
17182
17183 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17184 return SDValue();
17185
17186 // Always prefer FMAD to FMA for precision.
17187 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17188 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17189 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
17190
17191 // Is the node an FMUL and contractable either due to global flags or
17192 // SDNodeFlags.
17193 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17194 if (!matcher.match(N, ISD::FMUL))
17195 return false;
17196 return AllowFusionGlobally || N->getFlags().hasAllowContract();
17197 };
17198
17199 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17200 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
17201 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
17202 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(i: 0),
17203 XY.getOperand(i: 1),
17204 matcher.getNode(ISD::FNEG, SL, VT, Z));
17205 }
17206 return SDValue();
17207 };
17208
17209 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17210 // Note: Commutes FSUB operands.
17211 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
17212 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
17213 return matcher.getNode(
17214 PreferredFusedOpcode, SL, VT,
17215 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(i: 0)),
17216 YZ.getOperand(i: 1), X);
17217 }
17218 return SDValue();
17219 };
17220
17221 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
17222 // prefer to fold the multiply with fewer uses.
17223 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
17224 (N0->use_size() > N1->use_size())) {
17225 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
17226 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17227 return V;
17228 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
17229 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17230 return V;
17231 } else {
17232 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17233 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17234 return V;
17235 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17236 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17237 return V;
17238 }
17239
  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
17241 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(i: 0)) &&
17242 (Aggressive || (N0->hasOneUse() && N0.getOperand(i: 0).hasOneUse()))) {
17243 SDValue N00 = N0.getOperand(i: 0).getOperand(i: 0);
17244 SDValue N01 = N0.getOperand(i: 0).getOperand(i: 1);
17245 return matcher.getNode(PreferredFusedOpcode, SL, VT,
17246 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
17247 matcher.getNode(ISD::FNEG, SL, VT, N1));
17248 }
17249
17250 // Look through FP_EXTEND nodes to do more combining.
17251
17252 // fold (fsub (fpext (fmul x, y)), z)
17253 // -> (fma (fpext x), (fpext y), (fneg z))
17254 if (matcher.match(N0, ISD::FP_EXTEND)) {
17255 SDValue N00 = N0.getOperand(i: 0);
17256 if (isContractableFMUL(N00) &&
17257 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17258 SrcVT: N00.getValueType())) {
17259 return matcher.getNode(
17260 PreferredFusedOpcode, SL, VT,
17261 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
17262 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
17263 matcher.getNode(ISD::FNEG, SL, VT, N1));
17264 }
17265 }
17266
17267 // fold (fsub x, (fpext (fmul y, z)))
17268 // -> (fma (fneg (fpext y)), (fpext z), x)
17269 // Note: Commutes FSUB operands.
17270 if (matcher.match(N1, ISD::FP_EXTEND)) {
17271 SDValue N10 = N1.getOperand(i: 0);
17272 if (isContractableFMUL(N10) &&
17273 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17274 SrcVT: N10.getValueType())) {
17275 return matcher.getNode(
17276 PreferredFusedOpcode, SL, VT,
17277 matcher.getNode(
17278 ISD::FNEG, SL, VT,
17279 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0))),
17280 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
17281 }
17282 }
17283
  // fold (fsub (fpext (fneg (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // us from implementing the canonicalization in visitFSUB.
17290 if (matcher.match(N0, ISD::FP_EXTEND)) {
17291 SDValue N00 = N0.getOperand(i: 0);
17292 if (matcher.match(N00, ISD::FNEG)) {
17293 SDValue N000 = N00.getOperand(i: 0);
17294 if (isContractableFMUL(N000) &&
17295 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17296 SrcVT: N00.getValueType())) {
17297 return matcher.getNode(
17298 ISD::FNEG, SL, VT,
17299 matcher.getNode(
17300 PreferredFusedOpcode, SL, VT,
17301 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
17302 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
17303 N1));
17304 }
17305 }
17306 }
17307
  // fold (fsub (fneg (fpext (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // us from implementing the canonicalization in visitFSUB.
17314 if (matcher.match(N0, ISD::FNEG)) {
17315 SDValue N00 = N0.getOperand(i: 0);
17316 if (matcher.match(N00, ISD::FP_EXTEND)) {
17317 SDValue N000 = N00.getOperand(i: 0);
17318 if (isContractableFMUL(N000) &&
17319 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17320 SrcVT: N000.getValueType())) {
17321 return matcher.getNode(
17322 ISD::FNEG, SL, VT,
17323 matcher.getNode(
17324 PreferredFusedOpcode, SL, VT,
17325 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
17326 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
17327 N1));
17328 }
17329 }
17330 }
17331
17332 auto isContractableAndReassociableFMUL = [&isContractableFMUL](SDValue N) {
17333 return isContractableFMUL(N) && N->getFlags().hasAllowReassociation();
17334 };
17335
17336 auto isFusedOp = [&](SDValue N) {
17337 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
17338 };
17339
17340 // More folding opportunities when target permits.
17341 if (Aggressive && N->getFlags().hasAllowReassociation()) {
17342 bool CanFuse = N->getFlags().hasAllowContract();
17343 // fold (fsub (fma x, y, (fmul u, v)), z)
    // -> (fma x, y, (fma u, v, (fneg z)))
17345 if (CanFuse && isFusedOp(N0) &&
17346 isContractableAndReassociableFMUL(N0.getOperand(i: 2)) &&
17347 N0->hasOneUse() && N0.getOperand(i: 2)->hasOneUse()) {
17348 return matcher.getNode(
17349 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
17350 matcher.getNode(PreferredFusedOpcode, SL, VT,
17351 N0.getOperand(i: 2).getOperand(i: 0),
17352 N0.getOperand(i: 2).getOperand(i: 1),
17353 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17354 }
17355
17356 // fold (fsub x, (fma y, z, (fmul u, v)))
17357 // -> (fma (fneg y), z, (fma (fneg u), v, x))
17358 if (CanFuse && isFusedOp(N1) &&
17359 isContractableAndReassociableFMUL(N1.getOperand(i: 2)) &&
17360 N1->hasOneUse() && NoSignedZero) {
17361 SDValue N20 = N1.getOperand(i: 2).getOperand(i: 0);
17362 SDValue N21 = N1.getOperand(i: 2).getOperand(i: 1);
17363 return matcher.getNode(
17364 PreferredFusedOpcode, SL, VT,
17365 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
17366 N1.getOperand(i: 1),
17367 matcher.getNode(PreferredFusedOpcode, SL, VT,
17368 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
17369 }
17370
17371 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
    // -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
17373 if (isFusedOp(N0) && N0->hasOneUse()) {
17374 SDValue N02 = N0.getOperand(i: 2);
17375 if (matcher.match(N02, ISD::FP_EXTEND)) {
17376 SDValue N020 = N02.getOperand(i: 0);
17377 if (isContractableAndReassociableFMUL(N020) &&
17378 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17379 SrcVT: N020.getValueType())) {
17380 return matcher.getNode(
17381 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
17382 matcher.getNode(
17383 PreferredFusedOpcode, SL, VT,
17384 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 0)),
17385 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 1)),
17386 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17387 }
17388 }
17389 }
17390
17391 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
17392 // -> (fma (fpext x), (fpext y),
17393 // (fma (fpext u), (fpext v), (fneg z)))
17394 // FIXME: This turns two single-precision and one double-precision
17395 // operation into two double-precision operations, which might not be
17396 // interesting for all targets, especially GPUs.
17397 if (matcher.match(N0, ISD::FP_EXTEND)) {
17398 SDValue N00 = N0.getOperand(i: 0);
17399 if (isFusedOp(N00)) {
17400 SDValue N002 = N00.getOperand(i: 2);
17401 if (isContractableAndReassociableFMUL(N002) &&
17402 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17403 SrcVT: N00.getValueType())) {
17404 return matcher.getNode(
17405 PreferredFusedOpcode, SL, VT,
17406 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
17407 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
17408 matcher.getNode(
17409 PreferredFusedOpcode, SL, VT,
17410 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 0)),
17411 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 1)),
17412 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17413 }
17414 }
17415 }
17416
17417 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
17418 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
17419 if (isFusedOp(N1) && matcher.match(N1.getOperand(i: 2), ISD::FP_EXTEND) &&
17420 N1->hasOneUse()) {
17421 SDValue N120 = N1.getOperand(i: 2).getOperand(i: 0);
17422 if (isContractableAndReassociableFMUL(N120) &&
17423 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17424 SrcVT: N120.getValueType())) {
17425 SDValue N1200 = N120.getOperand(i: 0);
17426 SDValue N1201 = N120.getOperand(i: 1);
17427 return matcher.getNode(
17428 PreferredFusedOpcode, SL, VT,
17429 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
17430 N1.getOperand(i: 1),
17431 matcher.getNode(
17432 PreferredFusedOpcode, SL, VT,
17433 matcher.getNode(ISD::FNEG, SL, VT,
17434 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
17435 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
17436 }
17437 }
17438
17439 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
17440 // -> (fma (fneg (fpext y)), (fpext z),
17441 // (fma (fneg (fpext u)), (fpext v), x))
17442 // FIXME: This turns two single-precision and one double-precision
17443 // operation into two double-precision operations, which might not be
17444 // interesting for all targets, especially GPUs.
17445 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(i: 0))) {
17446 SDValue CvtSrc = N1.getOperand(i: 0);
17447 SDValue N100 = CvtSrc.getOperand(i: 0);
17448 SDValue N101 = CvtSrc.getOperand(i: 1);
17449 SDValue N102 = CvtSrc.getOperand(i: 2);
17450 if (isContractableAndReassociableFMUL(N102) &&
17451 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17452 SrcVT: CvtSrc.getValueType())) {
17453 SDValue N1020 = N102.getOperand(i: 0);
17454 SDValue N1021 = N102.getOperand(i: 1);
17455 return matcher.getNode(
17456 PreferredFusedOpcode, SL, VT,
17457 matcher.getNode(ISD::FNEG, SL, VT,
17458 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
17459 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
17460 matcher.getNode(
17461 PreferredFusedOpcode, SL, VT,
17462 matcher.getNode(ISD::FNEG, SL, VT,
17463 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
17464 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
17465 }
17466 }
17467 }
17468
17469 return SDValue();
17470}
17471
17472/// Try to perform FMA combining on a given FMUL node based on the distributive
17473/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
17474/// subtraction instead of addition).
17475SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
17476 SDValue N0 = N->getOperand(Num: 0);
17477 SDValue N1 = N->getOperand(Num: 1);
17478 EVT VT = N->getValueType(ResNo: 0);
17479 SDLoc SL(N);
17480
17481 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
17482
17483 const TargetOptions &Options = DAG.getTarget().Options;
17484
17485 // The transforms below are incorrect when x == 0 and y == inf, because the
17486 // intermediate multiplication produces a nan.
17487 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
17488 if (!hasNoInfs(Options, N: FAdd))
17489 return SDValue();
17490
17491 // Floating-point multiply-add without intermediate rounding.
17492 bool HasFMA =
17493 isContractableFMUL(Options, N: SDValue(N, 0)) &&
17494 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FMA, VT)) &&
17495 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT);
17496
17497 // Floating-point multiply-add with intermediate rounding. This can result
17498 // in a less precise result due to the changed rounding order.
17499 bool HasFMAD = LegalOperations && TLI.isFMADLegal(DAG, N);
17500
17501 // No valid opcode, do not combine.
17502 if (!HasFMAD && !HasFMA)
17503 return SDValue();
17504
17505 // Always prefer FMAD to FMA for precision.
17506 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17507 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17508
17509 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
17510 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
17511 auto FuseFADD = [&](SDValue X, SDValue Y) {
17512 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
17513 if (auto *C = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
17514 if (C->isExactlyValue(V: +1.0))
17515 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
17516 N3: Y);
17517 if (C->isExactlyValue(V: -1.0))
17518 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
17519 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
17520 }
17521 }
17522 return SDValue();
17523 };
17524
17525 if (SDValue FMA = FuseFADD(N0, N1))
17526 return FMA;
17527 if (SDValue FMA = FuseFADD(N1, N0))
17528 return FMA;
17529
17530 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
17531 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
17532 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
17533 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
17534 auto FuseFSUB = [&](SDValue X, SDValue Y) {
17535 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
17536 if (auto *C0 = isConstOrConstSplatFP(N: X.getOperand(i: 0), AllowUndefs: true)) {
17537 if (C0->isExactlyValue(V: +1.0))
17538 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
17539 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
17540 N3: Y);
17541 if (C0->isExactlyValue(V: -1.0))
17542 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
17543 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
17544 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
17545 }
17546 if (auto *C1 = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
17547 if (C1->isExactlyValue(V: +1.0))
17548 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
17549 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
17550 if (C1->isExactlyValue(V: -1.0))
17551 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
17552 N3: Y);
17553 }
17554 }
17555 return SDValue();
17556 };
17557
17558 if (SDValue FMA = FuseFSUB(N0, N1))
17559 return FMA;
17560 if (SDValue FMA = FuseFSUB(N1, N0))
17561 return FMA;
17562
17563 return SDValue();
17564}
17565
17566SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
17567 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17568
17569 // FADD -> FMA combines:
17570 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
17571 if (Fused.getOpcode() != ISD::DELETED_NODE)
17572 AddToWorklist(N: Fused.getNode());
17573 return Fused;
17574 }
17575 return SDValue();
17576}
17577
17578SDValue DAGCombiner::visitFADD(SDNode *N) {
17579 SDValue N0 = N->getOperand(Num: 0);
17580 SDValue N1 = N->getOperand(Num: 1);
17581 bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N0);
17582 bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N1);
17583 EVT VT = N->getValueType(ResNo: 0);
17584 SDLoc DL(N);
17585 const TargetOptions &Options = DAG.getTarget().Options;
17586 SDNodeFlags Flags = N->getFlags();
17587 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17588
17589 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17590 return R;
17591
17592 // fold (fadd c1, c2) -> c1 + c2
17593 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FADD, DL, VT, Ops: {N0, N1}))
17594 return C;
17595
17596 // canonicalize constant to RHS
17597 if (N0CFP && !N1CFP)
17598 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1, N2: N0);
17599
17600 // fold vector ops
17601 if (VT.isVector())
17602 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17603 return FoldedVOp;
17604
17605 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
17606 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
17607 if (N1C && N1C->isZero())
17608 if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
17609 return N0;
17610
17611 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17612 return NewSel;
17613
17614 // fold (fadd A, (fneg B)) -> (fsub A, B)
17615 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
17616 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
17617 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17618 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: NegN1);
17619
17620 // fold (fadd (fneg A), B) -> (fsub B, A)
17621 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
17622 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
17623 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17624 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: NegN0);
17625
17626 auto isFMulNegTwo = [](SDValue FMul) {
17627 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
17628 return false;
17629 auto *C = isConstOrConstSplatFP(N: FMul.getOperand(i: 1), AllowUndefs: true);
17630 return C && C->isExactlyValue(V: -2.0);
17631 };
17632
17633 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
17634 if (isFMulNegTwo(N0)) {
17635 SDValue B = N0.getOperand(i: 0);
17636 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
17637 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: Add);
17638 }
17639 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
17640 if (isFMulNegTwo(N1)) {
17641 SDValue B = N1.getOperand(i: 0);
17642 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
17643 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Add);
17644 }
17645
  // No FP constant should be created after legalization, as the Instruction
  // Selection pass has a hard time dealing with FP constants.
17648 bool AllowNewConst = (Level < AfterLegalizeDAG);
17649
17650 // If nnan is enabled, fold lots of things.
17651 if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
17652 // If allowed, fold (fadd (fneg x), x) -> 0.0
17653 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(i: 0) == N1)
17654 return DAG.getConstantFP(Val: 0.0, DL, VT);
17655
17656 // If allowed, fold (fadd x, (fneg x)) -> 0.0
17657 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(i: 0) == N0)
17658 return DAG.getConstantFP(Val: 0.0, DL, VT);
17659 }
17660
17661 // If 'unsafe math' or reassoc and nsz, fold lots of things.
17662 // TODO: break out portions of the transformations below for which Unsafe is
17663 // considered and which do not require both nsz and reassoc
17664 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17665 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
17666 AllowNewConst) {
17667 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
17668 if (N1CFP && N0.getOpcode() == ISD::FADD &&
17669 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
17670 SDValue NewC = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
17671 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
17672 }
17673
17674 // We can fold chains of FADD's of the same value into multiplications.
17675 // This transform is not safe in general because we are reducing the number
17676 // of rounding steps.
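    // For example, (fadd (fadd x, x), x) becomes (fmul x, 3.0): one rounding
    // step instead of two, so the result may differ in the last ULP.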
17677 if (TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) && !N0CFP && !N1CFP) {
17678 if (N0.getOpcode() == ISD::FMUL) {
17679 bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
17680 bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1));
17681
17682 // (fadd (fmul x, c), x) -> (fmul x, c+1)
17683 if (CFP01 && !CFP00 && N0.getOperand(i: 0) == N1) {
17684 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
17685 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
17686 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: NewCFP);
17687 }
17688
17689 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
17690 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
17691 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
17692 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
17693 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
17694 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
17695 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewCFP);
17696 }
17697 }
17698
17699 if (N1.getOpcode() == ISD::FMUL) {
17700 bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
17701 bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 1));
17702
17703 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
17704 if (CFP11 && !CFP10 && N1.getOperand(i: 0) == N0) {
17705 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
17706 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
17707 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: NewCFP);
17708 }
17709
17710 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
17711 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
17712 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
17713 N1.getOperand(i: 0) == N0.getOperand(i: 0)) {
17714 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
17715 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
17716 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N1.getOperand(i: 0), N2: NewCFP);
17717 }
17718 }
17719
17720 if (N0.getOpcode() == ISD::FADD) {
17721 bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
17722 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
17723 if (!CFP00 && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
17724 (N0.getOperand(i: 0) == N1)) {
17725 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1,
17726 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
17727 }
17728 }
17729
17730 if (N1.getOpcode() == ISD::FADD) {
17731 bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
17732 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
17733 if (!CFP10 && N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
17734 N1.getOperand(i: 0) == N0) {
17735 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
17736 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
17737 }
17738 }
17739
17740 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
17741 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
17742 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
17743 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
17744 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
17745 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0),
17746 N2: DAG.getConstantFP(Val: 4.0, DL, VT));
17747 }
17748 }
17749 } // enable-unsafe-fp-math && AllowNewConst
17750
17751 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17752 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))) {
17753 // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
17754 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FADD, Opc: ISD::FADD, DL,
17755 VT, N0, N1, Flags))
17756 return SD;
17757 }
17758
17759 // FADD -> FMA combines:
17760 if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
17761 if (Fused.getOpcode() != ISD::DELETED_NODE)
17762 AddToWorklist(N: Fused.getNode());
17763 return Fused;
17764 }
17765 return SDValue();
17766}
17767
17768SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
17769 SDValue Chain = N->getOperand(Num: 0);
17770 SDValue N0 = N->getOperand(Num: 1);
17771 SDValue N1 = N->getOperand(Num: 2);
17772 EVT VT = N->getValueType(ResNo: 0);
17773 EVT ChainVT = N->getValueType(ResNo: 1);
17774 SDLoc DL(N);
17775 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17776
17777 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
17778 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
17779 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
17780 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
17781 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
17782 Ops: {Chain, N0, NegN1});
17783 }
17784
17785 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
17786 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
17787 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
17788 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
17789 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
17790 Ops: {Chain, N1, NegN0});
17791 }
17792 return SDValue();
17793}
17794
17795SDValue DAGCombiner::visitFSUB(SDNode *N) {
17796 SDValue N0 = N->getOperand(Num: 0);
17797 SDValue N1 = N->getOperand(Num: 1);
17798 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, AllowUndefs: true);
17799 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
17800 EVT VT = N->getValueType(ResNo: 0);
17801 SDLoc DL(N);
17802 const TargetOptions &Options = DAG.getTarget().Options;
17803 const SDNodeFlags Flags = N->getFlags();
17804 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17805
17806 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17807 return R;
17808
17809 // fold (fsub c1, c2) -> c1-c2
17810 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FSUB, DL, VT, Ops: {N0, N1}))
17811 return C;
17812
17813 // fold vector ops
17814 if (VT.isVector())
17815 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17816 return FoldedVOp;
17817
17818 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17819 return NewSel;
17820
17821 // (fsub A, 0) -> A
17822 if (N1CFP && N1CFP->isZero()) {
17823 if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
17824 Flags.hasNoSignedZeros()) {
17825 return N0;
17826 }
17827 }
17828
17829 if (N0 == N1) {
17830 // (fsub x, x) -> 0.0
17831 if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
17832 return DAG.getConstantFP(Val: 0.0f, DL, VT);
17833 }
17834
17835 // (fsub -0.0, N1) -> -N1
17836 if (N0CFP && N0CFP->isZero()) {
17837 if (N0CFP->isNegative() ||
17838 (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
17839 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
17840 // flushed to zero, unless all users treat denorms as zero (DAZ).
17841 // FIXME: This transform will change the sign of a NaN and the behavior
17842 // of a signaling NaN. It is only valid when a NoNaN flag is present.
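      // (FNEG only flips the sign bit, so it preserves denormal inputs that a
      // real subtraction would flush to zero; hence the IEEE denormal-mode
      // check below.)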
17843 DenormalMode DenormMode = DAG.getDenormalMode(VT);
17844 if (DenormMode == DenormalMode::getIEEE()) {
17845 if (SDValue NegN1 =
17846 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17847 return NegN1;
17848 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
17849 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1);
17850 }
17851 }
17852 }
17853
17854 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17855 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
17856 N1.getOpcode() == ISD::FADD) {
17857 // X - (X + Y) -> -Y
17858 if (N0 == N1->getOperand(Num: 0))
17859 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 1));
17860 // X - (Y + X) -> -Y
17861 if (N0 == N1->getOperand(Num: 1))
17862 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 0));
17863 }
17864
17865 // fold (fsub A, (fneg B)) -> (fadd A, B)
17866 if (SDValue NegN1 =
17867 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17868 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: NegN1);
17869
17870 // FSUB -> FMA combines:
17871 if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
17872 AddToWorklist(N: Fused.getNode());
17873 return Fused;
17874 }
17875
17876 return SDValue();
17877}
17878
17879// Transform IEEE Floats:
17880// (fmul C, (uitofp Pow2))
17881// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
17882// (fdiv C, (uitofp Pow2))
17883// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
17884//
// The rationale is that fmul/fdiv by a power of 2 just changes the exponent,
// so there is no need for more than an add/sub.
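// For example (IEEE binary32): 3.0f is 0x40400000; multiplying by 4.0
// (Log2 = 2) adds 2 << 23 = 0x01000000, giving 0x41400000, i.e. 12.0f.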
17887//
17888// This is valid under the following circumstances:
17889// 1) We are dealing with IEEE floats
17890// 2) C is normal
17891// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
// TODO: Much of this could also be used for generating `ldexp` on targets that
// prefer it.
17894SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
17895 EVT VT = N->getValueType(ResNo: 0);
17896 if (!APFloat::isIEEELikeFP(VT.getFltSemantics()))
17897 return SDValue();
17898
17899 SDValue ConstOp, Pow2Op;
17900
17901 std::optional<int> Mantissa;
17902 auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
17903 if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
17904 return false;
17905
17906 ConstOp = peekThroughBitcasts(V: N->getOperand(Num: ConstOpIdx));
17907 Pow2Op = N->getOperand(Num: 1 - ConstOpIdx);
17908 if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
17909 (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
17910 !DAG.computeKnownBits(Op: Pow2Op).isNonNegative()))
17911 return false;
17912
17913 Pow2Op = Pow2Op.getOperand(i: 0);
17914
17915 // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
17916 // TODO: We could use knownbits to make this bound more precise.
17917 int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
17918
17919 auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
17920 if (CFP == nullptr)
17921 return false;
17922
17923 const APFloat &APF = CFP->getValueAPF();
17924
      // Make sure we have a normal constant.
17926 if (!APF.isNormal())
17927 return false;
17928
      // Make sure the float's exponent is within the bounds for which this
      // transform produces a bitwise-identical value.
17931 int CurExp = ilogb(Arg: APF);
17932 // FMul by pow2 will only increase exponent.
17933 int MinExp =
17934 N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
17935 // FDiv by pow2 will only decrease exponent.
17936 int MaxExp =
17937 N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
17938 if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
17939 MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
17940 return false;
17941
17942 // Finally make sure we actually know the mantissa for the float type.
17943 int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
17944 if (!Mantissa)
17945 Mantissa = ThisMantissa;
17946
17947 return *Mantissa == ThisMantissa && ThisMantissa > 0;
17948 };
17949
17950 // TODO: We may be able to include undefs.
17951 return ISD::matchUnaryFpPredicate(Op: ConstOp, Match: IsFPConstValid);
17952 };
17953
17954 if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
17955 return SDValue();
17956
17957 if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, FPConst: ConstOp, IntPow2: Pow2Op))
17958 return SDValue();
17959
17960 // Get log2 after all other checks have taken place. This is because
17961 // BuildLogBase2 may create a new node.
17962 SDLoc DL(N);
17963 // Get Log2 type with same bitwidth as the float type (VT).
17964 EVT NewIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VT.getScalarSizeInBits());
17965 if (VT.isVector())
17966 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewIntVT,
17967 EC: VT.getVectorElementCount());
17968
17969 SDValue Log2 = BuildLogBase2(V: Pow2Op, DL, KnownNeverZero: DAG.isKnownNeverZero(Op: Pow2Op),
17970 /*InexpensiveOnly*/ true, OutVT: NewIntVT);
17971 if (!Log2)
17972 return SDValue();
17973
17974 // Perform actual transform.
17975 SDValue MantissaShiftCnt =
17976 DAG.getShiftAmountConstant(Val: *Mantissa, VT: NewIntVT, DL);
  // TODO: Sometimes Log2 is of the form `(X + C)`. `(X + C) << C1` should fold
  // to `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could implement that fold here by handling the casts.
17980 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT: NewIntVT, N1: Log2, N2: MantissaShiftCnt);
17981 SDValue ResAsInt =
17982 DAG.getNode(Opcode: N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
17983 VT: NewIntVT, N1: DAG.getBitcast(VT: NewIntVT, V: ConstOp), N2: Shift);
17984 SDValue ResAsFP = DAG.getBitcast(VT, V: ResAsInt);
17985 return ResAsFP;
17986}
17987
17988SDValue DAGCombiner::visitFMUL(SDNode *N) {
17989 SDValue N0 = N->getOperand(Num: 0);
17990 SDValue N1 = N->getOperand(Num: 1);
17991 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
17992 EVT VT = N->getValueType(ResNo: 0);
17993 SDLoc DL(N);
17994 const TargetOptions &Options = DAG.getTarget().Options;
17995 const SDNodeFlags Flags = N->getFlags();
17996 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17997
17998 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17999 return R;
18000
18001 // fold (fmul c1, c2) -> c1*c2
18002 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMUL, DL, VT, Ops: {N0, N1}))
18003 return C;
18004
18005 // canonicalize constant to RHS
18006 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
18007 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
18008 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: N0);
18009
18010 // fold vector ops
18011 if (VT.isVector())
18012 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18013 return FoldedVOp;
18014
18015 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18016 return NewSel;
18017
18018 if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
18019 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
18020 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
18021 N0.getOpcode() == ISD::FMUL) {
18022 SDValue N00 = N0.getOperand(i: 0);
18023 SDValue N01 = N0.getOperand(i: 1);
18024 // Avoid an infinite loop by making sure that N00 is not a constant
18025 // (the inner multiply has not been constant folded yet).
18026 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N01) &&
18027 !DAG.isConstantFPBuildVectorOrConstantFP(N: N00)) {
18028 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N01, N2: N1);
18029 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N00, N2: MulConsts);
18030 }
18031 }
18032
18033 // Match a special-case: we convert X * 2.0 into fadd.
18034 // fmul (fadd X, X), C -> fmul X, 2.0 * C
18035 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
18036 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
18037 const SDValue Two = DAG.getConstantFP(Val: 2.0, DL, VT);
18038 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Two, N2: N1);
18039 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: MulConsts);
18040 }
18041
18042 // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
18043 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FMUL, Opc: ISD::FMUL, DL,
18044 VT, N0, N1, Flags))
18045 return SD;
18046 }
18047
18048 // fold (fmul X, 2.0) -> (fadd X, X)
18049 if (N1CFP && N1CFP->isExactlyValue(V: +2.0))
18050 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: N0);
18051
18052 // fold (fmul X, -1.0) -> (fsub -0.0, X)
18053 if (N1CFP && N1CFP->isExactlyValue(V: -1.0)) {
18054 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FSUB, VT)) {
18055 return DAG.getNode(Opcode: ISD::FSUB, DL, VT,
18056 N1: DAG.getConstantFP(Val: -0.0, DL, VT), N2: N0, Flags);
18057 }
18058 }
18059
18060 // -N0 * -N1 --> N0 * N1
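  // (Guarded by the cost checks below: only fold when at least one of the
  // negated operands is strictly cheaper than the original.)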
18061 TargetLowering::NegatibleCost CostN0 =
18062 TargetLowering::NegatibleCost::Expensive;
18063 TargetLowering::NegatibleCost CostN1 =
18064 TargetLowering::NegatibleCost::Expensive;
18065 SDValue NegN0 =
18066 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
18067 if (NegN0) {
18068 HandleSDNode NegN0Handle(NegN0);
18069 SDValue NegN1 =
18070 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
18071 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18072 CostN1 == TargetLowering::NegatibleCost::Cheaper))
18073 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: NegN0, N2: NegN1);
18074 }
18075
18076 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
18077 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
18078 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
18079 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
18080 TLI.isOperationLegal(Op: ISD::FABS, VT)) {
18081 SDValue Select = N0, X = N1;
18082 if (Select.getOpcode() != ISD::SELECT)
18083 std::swap(a&: Select, b&: X);
18084
18085 SDValue Cond = Select.getOperand(i: 0);
18086 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 1));
18087 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 2));
18088
18089 if (TrueOpnd && FalseOpnd &&
18090 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(i: 0) == X &&
18091 isa<ConstantFPSDNode>(Val: Cond.getOperand(i: 1)) &&
18092 cast<ConstantFPSDNode>(Val: Cond.getOperand(i: 1))->isExactlyValue(V: 0.0)) {
18093 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
18094 switch (CC) {
18095 default: break;
18096 case ISD::SETOLT:
18097 case ISD::SETULT:
18098 case ISD::SETOLE:
18099 case ISD::SETULE:
18100 case ISD::SETLT:
18101 case ISD::SETLE:
18102 std::swap(a&: TrueOpnd, b&: FalseOpnd);
18103 [[fallthrough]];
18104 case ISD::SETOGT:
18105 case ISD::SETUGT:
18106 case ISD::SETOGE:
18107 case ISD::SETUGE:
18108 case ISD::SETGT:
18109 case ISD::SETGE:
18110 if (TrueOpnd->isExactlyValue(V: -1.0) && FalseOpnd->isExactlyValue(V: 1.0) &&
18111 TLI.isOperationLegal(Op: ISD::FNEG, VT))
18112 return DAG.getNode(Opcode: ISD::FNEG, DL, VT,
18113 Operand: DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X));
18114 if (TrueOpnd->isExactlyValue(V: 1.0) && FalseOpnd->isExactlyValue(V: -1.0))
18115 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X);
18116
18117 break;
18118 }
18119 }
18120 }
18121
18122 // FMUL -> FMA combines:
18123 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
18124 AddToWorklist(N: Fused.getNode());
18125 return Fused;
18126 }
18127
18128 // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
18129 // able to run.
18130 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
18131 return R;
18132
18133 return SDValue();
18134}
18135
18136template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
18137 SDValue N0 = N->getOperand(Num: 0);
18138 SDValue N1 = N->getOperand(Num: 1);
18139 SDValue N2 = N->getOperand(Num: 2);
18140 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Val&: N0);
18141 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1);
18142 ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(Val&: N2);
18143 EVT VT = N->getValueType(ResNo: 0);
18144 SDLoc DL(N);
18145 const TargetOptions &Options = DAG.getTarget().Options;
18146 // FMA nodes have flags that propagate to the created nodes.
18147 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18148 MatchContextClass matcher(DAG, TLI, N);
18149
18150 // Constant fold FMA.
18151 if (SDValue C =
18152 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL, VT, Ops: {N0, N1, N2}))
18153 return C;
18154
18155 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
18156 TargetLowering::NegatibleCost CostN0 =
18157 TargetLowering::NegatibleCost::Expensive;
18158 TargetLowering::NegatibleCost CostN1 =
18159 TargetLowering::NegatibleCost::Expensive;
18160 SDValue NegN0 =
18161 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
18162 if (NegN0) {
18163 HandleSDNode NegN0Handle(NegN0);
18164 SDValue NegN1 =
18165 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
18166 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18167 CostN1 == TargetLowering::NegatibleCost::Cheaper))
18168 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
18169 }
18170
18171 // FIXME: use fast math flags instead of Options.UnsafeFPMath
18172 // TODO: Finally migrate away from global TargetOptions.
18173 if ((Options.NoNaNsFPMath && Options.NoInfsFPMath) ||
18174 (N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs())) {
18175 if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros() ||
18176 (N2CFP && !N2CFP->isExactlyValue(V: -0.0))) {
18177 if (N0CFP && N0CFP->isZero())
18178 return N2;
18179 if (N1CFP && N1CFP->isZero())
18180 return N2;
18181 }
18182 }
18183
18184 // FIXME: Support splat of constant.
18185 if (N0CFP && N0CFP->isExactlyValue(V: 1.0))
18186 return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
18187 if (N1CFP && N1CFP->isExactlyValue(V: 1.0))
18188 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18189
18190 // Canonicalize (fma c, x, y) -> (fma x, c, y)
18191 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
18192 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
18193 return matcher.getNode(ISD::FMA, DL, VT, N1, N0, N2);
18194
18195 bool CanReassociate =
18196 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
18197 if (CanReassociate) {
18198 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
18199 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(i: 0) &&
18200 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
18201 DAG.isConstantFPBuildVectorOrConstantFP(N: N2.getOperand(i: 1))) {
18202 return matcher.getNode(
18203 ISD::FMUL, DL, VT, N0,
18204 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(i: 1)));
18205 }
18206
18207 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
18208 if (matcher.match(N0, ISD::FMUL) &&
18209 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
18210 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
18211 return matcher.getNode(
18212 ISD::FMA, DL, VT, N0.getOperand(i: 0),
18213 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(i: 1)), N2);
18214 }
18215 }
18216
  // (fma x, 1, y) -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
18218 // FIXME: Support splat of constant.
18219 if (N1CFP) {
18220 if (N1CFP->isExactlyValue(V: 1.0))
18221 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18222
18223 if (N1CFP->isExactlyValue(V: -1.0) &&
18224 (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))) {
18225 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
18226 AddToWorklist(N: RHSNeg.getNode());
18227 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
18228 }
18229
    // fma (fneg x), K, y -> fma x, -K, y
18231 if (matcher.match(N0, ISD::FNEG) &&
18232 (TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
18233 (N1.hasOneUse() &&
18234 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
18235 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(i: 0),
18236 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
18237 }
18238 }
18239
18240 // FIXME: Support splat of constant.
18241 if (CanReassociate) {
18242 // (fma x, c, x) -> (fmul x, (c+1))
18243 if (N1CFP && N0 == N2) {
18244 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18245 matcher.getNode(ISD::FADD, DL, VT, N1,
18246 DAG.getConstantFP(Val: 1.0, DL, VT)));
18247 }
18248
18249 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
18250 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(i: 0) == N0) {
18251 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18252 matcher.getNode(ISD::FADD, DL, VT, N1,
18253 DAG.getConstantFP(Val: -1.0, DL, VT)));
18254 }
18255 }
18256
18257 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
18258 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
18259 if (!TLI.isFNegFree(VT))
18260 if (SDValue Neg = TLI.getCheaperNegatedExpression(
18261 Op: SDValue(N, 0), DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18262 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
18263 return SDValue();
18264}
18265
18266SDValue DAGCombiner::visitFMAD(SDNode *N) {
18267 SDValue N0 = N->getOperand(Num: 0);
18268 SDValue N1 = N->getOperand(Num: 1);
18269 SDValue N2 = N->getOperand(Num: 2);
18270 EVT VT = N->getValueType(ResNo: 0);
18271 SDLoc DL(N);
18272
18273 // Constant fold FMAD.
18274 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMAD, DL, VT, Ops: {N0, N1, N2}))
18275 return C;
18276
18277 return SDValue();
18278}
18279
18280// Combine multiple FDIVs with the same divisor into multiple FMULs by the
18281// reciprocal.
18282// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
// Notice that this is not always beneficial. One reason is that different
// targets may have different costs for FDIV and FMUL, so sometimes the cost
// of two FDIVs may be lower than the cost of one FDIV and two FMULs. Another
// reason is that the critical path is increased from one FDIV to FDIV + FMUL.
18287SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
18288 // TODO: Limit this transform based on optsize/minsize - it always creates at
18289 // least 1 extra instruction. But the perf win may be substantial enough
18290 // that only minsize should restrict this.
18291 bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
18292 const SDNodeFlags Flags = N->getFlags();
18293 if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
18294 return SDValue();
18295
18296 // Skip if current node is a reciprocal/fneg-reciprocal.
18297 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
18298 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, /* AllowUndefs */ true);
18299 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
18300 return SDValue();
18301
18302 // Exit early if the target does not want this transform or if there can't
18303 // possibly be enough uses of the divisor to make the transform worthwhile.
18304 unsigned MinUses = TLI.combineRepeatedFPDivisors();
18305
18306 // For splat vectors, scale the number of uses by the splat factor. If we can
18307 // convert the division into a scalar op, that will likely be much faster.
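  // For example, a single FDIV whose divisor is a <4 x float> splat counts as
  // four uses toward the target's minimum-use threshold.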
18308 unsigned NumElts = 1;
18309 EVT VT = N->getValueType(ResNo: 0);
18310 if (VT.isVector() && DAG.isSplatValue(V: N1))
18311 NumElts = VT.getVectorMinNumElements();
18312
18313 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
18314 return SDValue();
18315
18316 // Find all FDIV users of the same divisor.
18317 // Use a set because duplicates may be present in the user list.
18318 SetVector<SDNode *> Users;
18319 for (auto *U : N1->users()) {
18320 if (U->getOpcode() == ISD::FDIV && U->getOperand(Num: 1) == N1) {
18321 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
18322 if (U->getOperand(Num: 1).getOpcode() == ISD::FSQRT &&
18323 U->getOperand(Num: 0) == U->getOperand(Num: 1).getOperand(i: 0) &&
18324 U->getFlags().hasAllowReassociation() &&
18325 U->getFlags().hasNoSignedZeros())
18326 continue;
18327
18328 // This division is eligible for optimization only if global unsafe math
18329 // is enabled or if this division allows reciprocal formation.
18330 if (UnsafeMath || U->getFlags().hasAllowReciprocal())
18331 Users.insert(X: U);
18332 }
18333 }
18334
18335 // Now that we have the actual number of divisor uses, make sure it meets
18336 // the minimum threshold specified by the target.
18337 if ((Users.size() * NumElts) < MinUses)
18338 return SDValue();
18339
18340 SDLoc DL(N);
18341 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
18342 SDValue Reciprocal = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: FPOne, N2: N1, Flags);
18343
18344 // Dividend / Divisor -> Dividend * Reciprocal
18345 for (auto *U : Users) {
18346 SDValue Dividend = U->getOperand(Num: 0);
18347 if (Dividend != FPOne) {
18348 SDValue NewNode = DAG.getNode(Opcode: ISD::FMUL, DL: SDLoc(U), VT, N1: Dividend,
18349 N2: Reciprocal, Flags);
18350 CombineTo(N: U, Res: NewNode);
18351 } else if (U != Reciprocal.getNode()) {
18352 // In the absence of fast-math-flags, this user node is always the
18353 // same node as Reciprocal, but with FMF they may be different nodes.
18354 CombineTo(N: U, Res: Reciprocal);
18355 }
18356 }
18357 return SDValue(N, 0); // N was replaced.
18358}
18359
18360SDValue DAGCombiner::visitFDIV(SDNode *N) {
18361 SDValue N0 = N->getOperand(Num: 0);
18362 SDValue N1 = N->getOperand(Num: 1);
18363 EVT VT = N->getValueType(ResNo: 0);
18364 SDLoc DL(N);
18365 const TargetOptions &Options = DAG.getTarget().Options;
18366 SDNodeFlags Flags = N->getFlags();
18367 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18368
18369 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
18370 return R;
18371
18372 // fold (fdiv c1, c2) -> c1/c2
18373 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FDIV, DL, VT, Ops: {N0, N1}))
18374 return C;
18375
18376 // fold vector ops
18377 if (VT.isVector())
18378 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18379 return FoldedVOp;
18380
18381 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18382 return NewSel;
18383
18384 if (SDValue V = combineRepeatedFPDivisors(N))
18385 return V;
18386
18387 // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
18388 // the loss is acceptable with AllowReciprocal.
18389 if (auto *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true)) {
18390 // Compute the reciprocal 1.0 / c2.
18391 const APFloat &N1APF = N1CFP->getValueAPF();
18392 APFloat Recip = APFloat::getOne(Sem: N1APF.getSemantics());
18393 APFloat::opStatus st = Recip.divide(RHS: N1APF, RM: APFloat::rmNearestTiesToEven);
18394 // Only do the transform if the reciprocal is a legal fp immediate that
18395 // isn't too nasty (eg NaN, denormal, ...).
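    // For example, X / 4.0 can always become X * 0.25 (the reciprocal is
    // exact), while X / 3.0 needs an inexact reciprocal and is only folded
    // when reciprocals are allowed ('arcp').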
18396 if (((st == APFloat::opOK && !Recip.isDenormal()) ||
18397 (st == APFloat::opInexact && Flags.hasAllowReciprocal())) &&
18398 (!LegalOperations ||
18399 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
18400 // backend)... we should handle this gracefully after Legalize.
18401 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
18402 TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
18403 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
18404 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
18405 N2: DAG.getConstantFP(Val: Recip, DL, VT));
18406 }
18407
18408 if (Flags.hasAllowReciprocal()) {
18409 // If this FDIV is part of a reciprocal square root, it may be folded
18410 // into a target-specific square root estimate instruction.
18411 if (N1.getOpcode() == ISD::FSQRT) {
18412 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0), Flags))
18413 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
18414 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
18415 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
18416 if (SDValue RV =
18417 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
18418 RV = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N1), VT, Operand: RV);
18419 AddToWorklist(N: RV.getNode());
18420 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
18421 }
18422 } else if (N1.getOpcode() == ISD::FP_ROUND &&
18423 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
18424 if (SDValue RV =
18425 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
18426 RV = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N1), VT, N1: RV, N2: N1.getOperand(i: 1));
18427 AddToWorklist(N: RV.getNode());
18428 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
18429 }
18430 } else if (N1.getOpcode() == ISD::FMUL) {
18431 // Look through an FMUL. Even though this won't remove the FDIV directly,
18432 // it's still worthwhile to get rid of the FSQRT if possible.
18433 SDValue Sqrt, Y;
18434 if (N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
18435 Sqrt = N1.getOperand(i: 0);
18436 Y = N1.getOperand(i: 1);
18437 } else if (N1.getOperand(i: 1).getOpcode() == ISD::FSQRT) {
18438 Sqrt = N1.getOperand(i: 1);
18439 Y = N1.getOperand(i: 0);
18440 }
18441 if (Sqrt.getNode()) {
18442 // If the other multiply operand is known positive, pull it into the
18443 // sqrt. That will eliminate the division if we convert to an estimate.
18444 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
18445 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
18446 SDValue A;
18447 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
18448 A = Y.getOperand(i: 0);
18449 else if (Y == Sqrt.getOperand(i: 0))
18450 A = Y;
18451 if (A) {
18452 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
18453 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
18454 SDValue AA = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: A, N2: A);
18455 SDValue AAZ =
18456 DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AA, N2: Sqrt.getOperand(i: 0));
18457 if (SDValue Rsqrt = buildRsqrtEstimate(Op: AAZ, Flags))
18458 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Rsqrt);
18459
18460 // Estimate creation failed. Clean up speculatively created nodes.
18461 recursivelyDeleteUnusedNodes(N: AAZ.getNode());
18462 }
18463 }
18464
18465 // We found a FSQRT, so try to make this fold:
18466 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
18467 if (SDValue Rsqrt = buildRsqrtEstimate(Op: Sqrt.getOperand(i: 0), Flags)) {
18468 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N1), VT, N1: Rsqrt, N2: Y);
18469 AddToWorklist(N: Div.getNode());
18470 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Div);
18471 }
18472 }
18473 }
18474
18475 // Fold into a reciprocal estimate and multiply instead of a real divide.
18476 if (Options.NoInfsFPMath || Flags.hasNoInfs())
18477 if (SDValue RV = BuildDivEstimate(N: N0, Op: N1, Flags))
18478 return RV;
18479 }
18480
18481 // Fold X/Sqrt(X) -> Sqrt(X)
18482 if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
18483 Flags.hasAllowReassociation())
18484 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(i: 0))
18485 return N1;
18486
18487 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
18488 TargetLowering::NegatibleCost CostN0 =
18489 TargetLowering::NegatibleCost::Expensive;
18490 TargetLowering::NegatibleCost CostN1 =
18491 TargetLowering::NegatibleCost::Expensive;
18492 SDValue NegN0 =
18493 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
18494 if (NegN0) {
18495 HandleSDNode NegN0Handle(NegN0);
18496 SDValue NegN1 =
18497 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
18498 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18499 CostN1 == TargetLowering::NegatibleCost::Cheaper))
18500 return DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: NegN0, N2: NegN1);
18501 }
18502
18503 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
18504 return R;
18505
18506 return SDValue();
18507}
18508
18509SDValue DAGCombiner::visitFREM(SDNode *N) {
18510 SDValue N0 = N->getOperand(Num: 0);
18511 SDValue N1 = N->getOperand(Num: 1);
18512 EVT VT = N->getValueType(ResNo: 0);
18513 SDNodeFlags Flags = N->getFlags();
18514 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18515 SDLoc DL(N);
18516
18517 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
18518 return R;
18519
18520 // fold (frem c1, c2) -> fmod(c1,c2)
18521 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FREM, DL, VT, Ops: {N0, N1}))
18522 return C;
18523
18524 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18525 return NewSel;
18526
  // Lower frem N0, N1 => N0 - trunc(N0 / N1) * N1, provided N1 is an integer
  // power of 2.
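  // For example, frem 7.5, 2.0: trunc(7.5 / 2.0) = 3.0 and 7.5 - 3.0 * 2.0 =
  // 1.5. Restricting N1 to a power of two keeps the fdiv exact (it only
  // adjusts the exponent), which is presumably why the expansion is limited
  // to that case.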
18529 if (!TLI.isOperationLegal(Op: ISD::FREM, VT) &&
18530 TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) &&
18531 TLI.isOperationLegalOrCustom(Op: ISD::FDIV, VT) &&
18532 TLI.isOperationLegalOrCustom(Op: ISD::FTRUNC, VT) &&
18533 DAG.isKnownToBeAPowerOfTwoFP(Val: N1)) {
18534 bool NeedsCopySign =
18535 !Flags.hasNoSignedZeros() && !DAG.cannotBeOrderedNegativeFP(Op: N0);
18536 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: N0, N2: N1);
18537 SDValue Rnd = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT, Operand: Div);
18538 SDValue MLA;
18539 if (TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT)) {
18540 MLA = DAG.getNode(Opcode: ISD::FMA, DL, VT, N1: DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Rnd),
18541 N2: N1, N3: N0);
18542 } else {
18543 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Rnd, N2: N1);
18544 MLA = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Mul);
18545 }
18546 return NeedsCopySign ? DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: MLA, N2: N0) : MLA;
18547 }
18548
18549 return SDValue();
18550}
18551
18552SDValue DAGCombiner::visitFSQRT(SDNode *N) {
18553 SDNodeFlags Flags = N->getFlags();
18554 const TargetOptions &Options = DAG.getTarget().Options;
18555
18556 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
18557 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
18558 if (!Flags.hasApproximateFuncs() ||
18559 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
18560 return SDValue();
18561
18562 SDValue N0 = N->getOperand(Num: 0);
18563 if (TLI.isFsqrtCheap(X: N0, DAG))
18564 return SDValue();
18565
18566 // FSQRT nodes have flags that propagate to the created nodes.
18567 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
18568 // transform the fdiv, we may produce a sub-optimal estimate sequence
18569 // because the reciprocal calculation may not have to filter out a
18570 // 0.0 input.
18571 return buildSqrtEstimate(Op: N0, Flags);
18572}
18573
18574/// copysign(x, fp_extend(y)) -> copysign(x, y)
18575/// copysign(x, fp_round(y)) -> copysign(x, y)
/// Operands to the functions are the types of X and Y, respectively.
18577static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
18578 // Always fold no-op FP casts.
18579 if (XTy == YTy)
18580 return true;
18581
18582 // Do not optimize out type conversion of f128 type yet.
18583 // For some targets like x86_64, configuration is changed to keep one f128
18584 // value in one SSE register, but instruction selection cannot handle
18585 // FCOPYSIGN on SSE registers yet.
18586 if (YTy == MVT::f128)
18587 return false;
18588
18589 // Avoid mismatched vector operand types, for better instruction selection.
18590 return !YTy.isVector();
18591}
18592
18593static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
18594 SDValue N1 = N->getOperand(Num: 1);
18595 if (N1.getOpcode() != ISD::FP_EXTEND &&
18596 N1.getOpcode() != ISD::FP_ROUND)
18597 return false;
18598 EVT N1VT = N1->getValueType(ResNo: 0);
18599 EVT N1Op0VT = N1->getOperand(Num: 0).getValueType();
18600 return CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: N1VT, YTy: N1Op0VT);
18601}
18602
18603SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
18604 SDValue N0 = N->getOperand(Num: 0);
18605 SDValue N1 = N->getOperand(Num: 1);
18606 EVT VT = N->getValueType(ResNo: 0);
18607 SDLoc DL(N);
18608
18609 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
18610 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FCOPYSIGN, DL, VT, Ops: {N0, N1}))
18611 return C;
18612
18613 // copysign(x, fp_extend(y)) -> copysign(x, y)
18614 // copysign(x, fp_round(y)) -> copysign(x, y)
18615 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
18616 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
18617
18618 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
18619 return SDValue(N, 0);
18620
18621 return SDValue();
18622}
18623
18624SDValue DAGCombiner::visitFPOW(SDNode *N) {
18625 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N: N->getOperand(Num: 1));
18626 if (!ExponentC)
18627 return SDValue();
18628 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18629
18630 // Try to convert x ** (1/3) into cube root.
18631 // TODO: Handle the various flavors of long double.
18632 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
18633 // Some range near 1/3 should be fine.
18634 EVT VT = N->getValueType(ResNo: 0);
18635 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(V: 1.0f/3.0f)) ||
18636 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(V: 1.0/3.0))) {
18637 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
18638 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
18639 // pow(-val, 1/3) = nan; cbrt(-val) = -num.
18640 // For regular numbers, rounding may cause the results to differ.
18641 // Therefore, we require { nsz ninf nnan afn } for this transform.
18642 // TODO: We could select out the special cases if we don't have nsz/ninf.
18643 SDNodeFlags Flags = N->getFlags();
18644 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
18645 !Flags.hasApproximateFuncs())
18646 return SDValue();
18647
18648 // Do not create a cbrt() libcall if the target does not have it, and do not
18649 // turn a pow that has lowering support into a cbrt() libcall.
18650 if (!DAG.getLibInfo().has(F: LibFunc_cbrt) ||
18651 (!DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FPOW, VT) &&
18652 DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FCBRT, VT)))
18653 return SDValue();
18654
18655 return DAG.getNode(Opcode: ISD::FCBRT, DL: SDLoc(N), VT, Operand: N->getOperand(Num: 0));
18656 }
18657
18658 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
18659 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
18660 // TODO: This could be extended (using a target hook) to handle smaller
18661 // power-of-2 fractional exponents.
18662 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(V: 0.25);
18663 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(V: 0.75);
18664 if (ExponentIs025 || ExponentIs075) {
18665 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
18666 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
18667 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
18668 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
18669 // For regular numbers, rounding may cause the results to differ.
18670 // Therefore, we require { nsz ninf afn } for this transform.
18671 // TODO: We could select out the special cases if we don't have nsz/ninf.
18672 SDNodeFlags Flags = N->getFlags();
18673
18674 // We only need no signed zeros for the 0.25 case.
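    // (For 0.75, sqrt(-0.0) * sqrt(sqrt(-0.0)) = (-0.0) * (-0.0) = +0.0, which
    // already matches pow(-0.0, 0.75) per the table above.)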
18675 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
18676 !Flags.hasApproximateFuncs())
18677 return SDValue();
18678
18679 // Don't double the number of libcalls. We are trying to inline fast code.
18680 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(Op: ISD::FSQRT, VT))
18681 return SDValue();
18682
18683 // Assume that libcalls are the smallest code.
18684 // TODO: This restriction should probably be lifted for vectors.
18685 if (ForCodeSize)
18686 return SDValue();
18687
18688 // pow(X, 0.25) --> sqrt(sqrt(X))
18689 SDLoc DL(N);
18690 SDValue Sqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: N->getOperand(Num: 0));
18691 SDValue SqrtSqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: Sqrt);
18692 if (ExponentIs025)
18693 return SqrtSqrt;
18694 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
18695 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Sqrt, N2: SqrtSqrt);
18696 }
18697
18698 return SDValue();
18699}
18700
18701static SDValue foldFPToIntToFP(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
18702 const TargetLowering &TLI) {
18703 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
18704 // replacing casts with a libcall. We also must be allowed to ignore -0.0
18705 // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
18706 // conversions would return +0.0.
18707 // FIXME: We should be able to use node-level FMF here.
18708 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
18709 EVT VT = N->getValueType(ResNo: 0);
18710 if (!TLI.isOperationLegal(Op: ISD::FTRUNC, VT) ||
18711 !DAG.getTarget().Options.NoSignedZerosFPMath)
18712 return SDValue();
18713
18714 // fptosi/fptoui round towards zero, so converting from FP to integer and
18715 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
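  // For example, sitofp(fptosi -3.7) = sitofp(-3) = -3.0 = ftrunc(-3.7).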
18716 SDValue N0 = N->getOperand(Num: 0);
18717 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
18718 N0.getOperand(i: 0).getValueType() == VT)
18719 return DAG.getNode(Opcode: ISD::FTRUNC, DL, VT, Operand: N0.getOperand(i: 0));
18720
18721 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
18722 N0.getOperand(i: 0).getValueType() == VT)
18723 return DAG.getNode(Opcode: ISD::FTRUNC, DL, VT, Operand: N0.getOperand(i: 0));
18724
18725 return SDValue();
18726}
18727
18728SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
18729 SDValue N0 = N->getOperand(Num: 0);
18730 EVT VT = N->getValueType(ResNo: 0);
18731 EVT OpVT = N0.getValueType();
18732 SDLoc DL(N);
18733
18734 // [us]itofp(undef) = 0, because the result value is bounded.
18735 if (N0.isUndef())
18736 return DAG.getConstantFP(Val: 0.0, DL, VT);
18737
18738 // fold (sint_to_fp c1) -> c1fp
18739 // ...but only if the target supports immediate floating-point values
18740 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
18741 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SINT_TO_FP, DL, VT, Ops: {N0}))
18742 return C;
18743
18744 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
18745 // but UINT_TO_FP is legal on this target, try to convert.
18746 if (!hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT) &&
18747 hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT)) {
18748 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
18749 if (DAG.SignBitIsZero(Op: N0))
18750 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL, VT, Operand: N0);
18751 }
18752
18753 // The next optimizations are desirable only if SELECT_CC can be lowered.
18754 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
18755 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
18756 !VT.isVector() &&
18757 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
18758 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: -1.0, DL, VT),
18759 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
18760
18761 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
18762 // (select (setcc x, y, cc), 1.0, 0.0)
18763 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
18764 N0.getOperand(i: 0).getOpcode() == ISD::SETCC && !VT.isVector() &&
18765 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
18766 return DAG.getSelect(DL, VT, Cond: N0.getOperand(i: 0),
18767 LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
18768 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
18769
18770 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
18771 return FTrunc;
18772
18773 return SDValue();
18774}
18775
18776SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
18777 SDValue N0 = N->getOperand(Num: 0);
18778 EVT VT = N->getValueType(ResNo: 0);
18779 EVT OpVT = N0.getValueType();
18780 SDLoc DL(N);
18781
18782 // [us]itofp(undef) = 0, because the result value is bounded.
18783 if (N0.isUndef())
18784 return DAG.getConstantFP(Val: 0.0, DL, VT);
18785
18786 // fold (uint_to_fp c1) -> c1fp
18787 // ...but only if the target supports immediate floating-point values
18788 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
18789 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UINT_TO_FP, DL, VT, Ops: {N0}))
18790 return C;
18791
18792 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
18793 // but SINT_TO_FP is legal on this target, try to convert.
18794 if (!hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT) &&
18795 hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT)) {
18796 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
18797 if (DAG.SignBitIsZero(Op: N0))
18798 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL, VT, Operand: N0);
18799 }
18800
18801 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
18802 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
18803 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
18804 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
18805 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
18806
18807 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
18808 return FTrunc;
18809
18810 return SDValue();
18811}
18812
// Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
18814static SDValue FoldIntToFPToInt(SDNode *N, const SDLoc &DL, SelectionDAG &DAG) {
18815 SDValue N0 = N->getOperand(Num: 0);
18816 EVT VT = N->getValueType(ResNo: 0);
18817
18818 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
18819 return SDValue();
18820
18821 SDValue Src = N0.getOperand(i: 0);
18822 EVT SrcVT = Src.getValueType();
18823 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
18824 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
18825
18826 // We can safely assume the conversion won't overflow the output range,
18827 // because (for example) (uint8_t)18293.f is undefined behavior.
18828
18829 // Since we can assume the conversion won't overflow, our decision as to
18830 // whether the input will fit in the float should depend on the minimum
18831 // of the input range and output range.
18832
18833 // This means this is also safe for a signed input and unsigned output, since
18834 // a negative input would lead to undefined behavior.
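// Worked example: for an i32 (fp_to_sint (sint_to_fp X)) where X is i8 and
// the intermediate type is f32, InputSize = 8 - 1 = 7, OutputSize = 32, so
// ActualSize = 7. f32 carries 24 bits of precision, 24 >= 7, and i32 is wider
// than i8, so the whole expression folds to (sign_extend X).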
18835 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
18836 unsigned OutputSize = (int)VT.getScalarSizeInBits();
18837 unsigned ActualSize = std::min(a: InputSize, b: OutputSize);
18838 const fltSemantics &Sem = N0.getValueType().getFltSemantics();
18839
18840 // We can only fold away the float conversion if the input range can be
18841 // represented exactly in the float range.
18842 if (APFloat::semanticsPrecision(Sem) >= ActualSize) {
18843 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
18844 unsigned ExtOp =
18845 IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
18846 return DAG.getNode(Opcode: ExtOp, DL, VT, Operand: Src);
18847 }
18848 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
18849 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Src);
18850 return DAG.getBitcast(VT, V: Src);
18851 }
18852 return SDValue();
18853}
18854
18855SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
18856 SDValue N0 = N->getOperand(Num: 0);
18857 EVT VT = N->getValueType(ResNo: 0);
18858 SDLoc DL(N);
18859
18860 // fold (fp_to_sint undef) -> undef
18861 if (N0.isUndef())
18862 return DAG.getUNDEF(VT);
18863
18864 // fold (fp_to_sint c1fp) -> c1
18865 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_TO_SINT, DL, VT, Ops: {N0}))
18866 return C;
18867
18868 return FoldIntToFPToInt(N, DL, DAG);
18869}
18870
18871SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
18872 SDValue N0 = N->getOperand(Num: 0);
18873 EVT VT = N->getValueType(ResNo: 0);
18874 SDLoc DL(N);
18875
18876 // fold (fp_to_uint undef) -> undef
18877 if (N0.isUndef())
18878 return DAG.getUNDEF(VT);
18879
18880 // fold (fp_to_uint c1fp) -> c1
18881 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_TO_UINT, DL, VT, Ops: {N0}))
18882 return C;
18883
18884 return FoldIntToFPToInt(N, DL, DAG);
18885}
18886
18887SDValue DAGCombiner::visitXROUND(SDNode *N) {
18888 SDValue N0 = N->getOperand(Num: 0);
18889 EVT VT = N->getValueType(ResNo: 0);
18890
18891 // fold (lrint|llrint undef) -> undef
18892 // fold (lround|llround undef) -> undef
18893 if (N0.isUndef())
18894 return DAG.getUNDEF(VT);
18895
18896 // fold (lrint|llrint c1fp) -> c1
18897 // fold (lround|llround c1fp) -> c1
18898 if (SDValue C =
18899 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Ops: {N0}))
18900 return C;
18901
18902 return SDValue();
18903}
18904
18905SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
18906 SDValue N0 = N->getOperand(Num: 0);
18907 SDValue N1 = N->getOperand(Num: 1);
18908 EVT VT = N->getValueType(ResNo: 0);
18909 SDLoc DL(N);
18910
18911 // fold (fp_round c1fp) -> c1fp
18912 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_ROUND, DL, VT, Ops: {N0, N1}))
18913 return C;
18914
18915 // fold (fp_round (fp_extend x)) -> x
18916 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(i: 0).getValueType())
18917 return N0.getOperand(i: 0);
18918
18919 // fold (fp_round (fp_round x)) -> (fp_round x)
18920 if (N0.getOpcode() == ISD::FP_ROUND) {
18921 const bool NIsTrunc = N->getConstantOperandVal(Num: 1) == 1;
18922 const bool N0IsTrunc = N0.getConstantOperandVal(i: 1) == 1;
18923
18924 // Avoid folding legal fp_rounds into non-legal ones.
18925 if (!hasOperation(Opcode: ISD::FP_ROUND, VT))
18926 return SDValue();
18927
18928 // Skip this folding if it results in an fp_round from f80 to f16.
18929 //
18930 // f80 to f16 always generates an expensive (and as yet, unimplemented)
18931 // libcall to __truncxfhf2 instead of selecting native f16 conversion
18932 // instructions from f32 or f64. Moreover, the first (value-preserving)
18933 // fp_round from f80 to either f32 or f64 may become a NOP on platforms like
18934 // x86.
18935 if (N0.getOperand(i: 0).getValueType() == MVT::f80 && VT == MVT::f16)
18936 return SDValue();
18937
18938 // If the first fp_round isn't a value-preserving truncation, it might
18939 // introduce a tie in the second fp_round that wouldn't occur in the
18940 // single-step fp_round we want to fold to.
18941 // In other words, double rounding isn't the same as rounding once.
18942 // Also, this is a value-preserving truncation iff both fp_rounds are.
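// For example, an inner fp_round from f64 to f32 that is flagged as exact
// (operand 1 == 1) followed by an outer fp_round to f16 collapses into a
// single fp_round from f64 to f16 whose flag is NIsTrunc && N0IsTrunc.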
18943 if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc)
18944 return DAG.getNode(
18945 Opcode: ISD::FP_ROUND, DL, VT, N1: N0.getOperand(i: 0),
18946 N2: DAG.getIntPtrConstant(Val: NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
18947 }
18948
18949 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
18950 // Note: From a legality perspective, this is a two step transform. First,
18951 // we duplicate the fp_round to the arguments of the copysign, then we
18952 // eliminate the fp_round on Y. The second step requires an additional
18953 // predicate to match the implementation above.
18954 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
18955 CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: VT,
18956 YTy: N0.getValueType())) {
18957 SDValue Tmp = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT,
18958 N1: N0.getOperand(i: 0), N2: N1);
18959 AddToWorklist(N: Tmp.getNode());
18960 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: Tmp, N2: N0.getOperand(i: 1));
18961 }
18962
18963 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
18964 return NewVSel;
18965
18966 return SDValue();
18967}
18968
18969// Eliminate a floating-point widening of a narrowed value if the fast math
18970// flags allow it.
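// A minimal sketch of the matched pattern, assuming the listed fast-math
// flags are present on both nodes:
// (fp_extend nnan ninf contract (fp_round contract X)) --> X
// where X already has the wide result type.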
18971static SDValue eliminateFPCastPair(SDNode *N) {
18972 SDValue N0 = N->getOperand(Num: 0);
18973 EVT VT = N->getValueType(ResNo: 0);
18974
18975 unsigned NarrowingOp;
18976 switch (N->getOpcode()) {
18977 case ISD::FP16_TO_FP:
18978 NarrowingOp = ISD::FP_TO_FP16;
18979 break;
18980 case ISD::BF16_TO_FP:
18981 NarrowingOp = ISD::FP_TO_BF16;
18982 break;
18983 case ISD::FP_EXTEND:
18984 NarrowingOp = ISD::FP_ROUND;
18985 break;
18986 default:
18987 llvm_unreachable("Expected widening FP cast");
18988 }
18989
18990 if (N0.getOpcode() == NarrowingOp && N0.getOperand(i: 0).getValueType() == VT) {
18991 const SDNodeFlags NarrowFlags = N0->getFlags();
18992 const SDNodeFlags WidenFlags = N->getFlags();
18993 // Narrowing can introduce inf and change the encoding of a nan, so the
18994 // widen must have the nnan and ninf flags to indicate that we don't need to
18995 // care about that. We are also removing a rounding step, and that requires
18996 // both the narrow and widen to allow contraction.
18997 if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
18998 NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
18999 return N0.getOperand(i: 0);
19000 }
19001 }
19002
19003 return SDValue();
19004}
19005
19006SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
19007 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19008 SDValue N0 = N->getOperand(Num: 0);
19009 EVT VT = N->getValueType(ResNo: 0);
19010 SDLoc DL(N);
19011
19012 if (VT.isVector())
19013 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
19014 return FoldedVOp;
19015
19016 // If this is fp_round(fp_extend), don't fold it; allow ourselves to be folded.
19017 if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::FP_ROUND)
19018 return SDValue();
19019
19020 // fold (fp_extend c1fp) -> c1fp
19021 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_EXTEND, DL, VT, Ops: {N0}))
19022 return C;
19023
19024 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
19025 if (N0.getOpcode() == ISD::FP16_TO_FP &&
19026 TLI.getOperationAction(Op: ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
19027 return DAG.getNode(Opcode: ISD::FP16_TO_FP, DL, VT, Operand: N0.getOperand(i: 0));
19028
19029 // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
19030 // value of X.
19031 if (N0.getOpcode() == ISD::FP_ROUND && N0.getConstantOperandVal(i: 1) == 1) {
19032 SDValue In = N0.getOperand(i: 0);
19033 if (In.getValueType() == VT) return In;
19034 if (VT.bitsLT(VT: In.getValueType()))
19035 return DAG.getNode(Opcode: ISD::FP_ROUND, DL, VT, N1: In, N2: N0.getOperand(i: 1));
19036 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL, VT, Operand: In);
19037 }
19038
19039 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
19040 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
19041 TLI.isLoadExtLegalOrCustom(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
19042 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
19043 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT,
19044 Chain: LN0->getChain(),
19045 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
19046 MMO: LN0->getMemOperand());
19047 CombineTo(N, Res: ExtLoad);
19048 CombineTo(
19049 N: N0.getNode(),
19050 Res0: DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT: N0.getValueType(), N1: ExtLoad,
19051 N2: DAG.getIntPtrConstant(Val: 1, DL: SDLoc(N0), /*isTarget=*/true)),
19052 Res1: ExtLoad.getValue(R: 1));
19053 return SDValue(N, 0); // Return N so it doesn't get rechecked!
19054 }
19055
19056 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
19057 return NewVSel;
19058
19059 if (SDValue CastEliminated = eliminateFPCastPair(N))
19060 return CastEliminated;
19061
19062 return SDValue();
19063}
19064
19065SDValue DAGCombiner::visitFCEIL(SDNode *N) {
19066 SDValue N0 = N->getOperand(Num: 0);
19067 EVT VT = N->getValueType(ResNo: 0);
19068
19069 // fold (fceil c1) -> fceil(c1)
19070 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FCEIL, DL: SDLoc(N), VT, Ops: {N0}))
19071 return C;
19072
19073 return SDValue();
19074}
19075
19076SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
19077 SDValue N0 = N->getOperand(Num: 0);
19078 EVT VT = N->getValueType(ResNo: 0);
19079
19080 // fold (ftrunc c1) -> ftrunc(c1)
19081 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Ops: {N0}))
19082 return C;
19083
19084 // fold ftrunc (known rounded int x) -> x
19085 // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
19086 // likely to be generated to extract integer from a rounded floating value.
19087 switch (N0.getOpcode()) {
19088 default: break;
19089 case ISD::FRINT:
19090 case ISD::FTRUNC:
19091 case ISD::FNEARBYINT:
19092 case ISD::FROUNDEVEN:
19093 case ISD::FFLOOR:
19094 case ISD::FCEIL:
19095 return N0;
19096 }
19097
19098 return SDValue();
19099}
19100
19101SDValue DAGCombiner::visitFFREXP(SDNode *N) {
19102 SDValue N0 = N->getOperand(Num: 0);
19103
19104 // fold (ffrexp c1) -> ffrexp(c1)
19105 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
19106 return DAG.getNode(Opcode: ISD::FFREXP, DL: SDLoc(N), VTList: N->getVTList(), N: N0);
19107 return SDValue();
19108}
19109
19110SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
19111 SDValue N0 = N->getOperand(Num: 0);
19112 EVT VT = N->getValueType(ResNo: 0);
19113
19114 // fold (ffloor c1) -> ffloor(c1)
19115 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FFLOOR, DL: SDLoc(N), VT, Ops: {N0}))
19116 return C;
19117
19118 return SDValue();
19119}
19120
19121SDValue DAGCombiner::visitFNEG(SDNode *N) {
19122 SDValue N0 = N->getOperand(Num: 0);
19123 EVT VT = N->getValueType(ResNo: 0);
19124 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19125
19126 // Constant fold FNEG.
19127 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Ops: {N0}))
19128 return C;
19129
19130 if (SDValue NegN0 =
19131 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
19132 return NegN0;
19133
19134 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
19135 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
19136 // know it was called from a context with a nsz flag if the input fsub does
19137 // not.
19138 if (N0.getOpcode() == ISD::FSUB &&
19139 (DAG.getTarget().Options.NoSignedZerosFPMath ||
19140 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
19141 return DAG.getNode(Opcode: ISD::FSUB, DL: SDLoc(N), VT, N1: N0.getOperand(i: 1),
19142 N2: N0.getOperand(i: 0));
19143 }
19144
19145 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
19146 return SDValue(N, 0);
19147
19148 if (SDValue Cast = foldSignChangeInBitcast(N))
19149 return Cast;
19150
19151 return SDValue();
19152}
19153
19154SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19155 SDValue N0 = N->getOperand(Num: 0);
19156 SDValue N1 = N->getOperand(Num: 1);
19157 EVT VT = N->getValueType(ResNo: 0);
19158 const SDNodeFlags Flags = N->getFlags();
19159 unsigned Opc = N->getOpcode();
19160 bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
19161 bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
19162 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19163
19164 // Constant fold.
19165 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: Opc, DL: SDLoc(N), VT, Ops: {N0, N1}))
19166 return C;
19167
19168 // Canonicalize to constant on RHS.
19169 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
19170 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
19171 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0);
19172
19173 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1)) {
19174 const APFloat &AF = N1CFP->getValueAPF();
19175
19176 // minnum(X, nan) -> X
19177 // maxnum(X, nan) -> X
19178 // minimum(X, nan) -> nan
19179 // maximum(X, nan) -> nan
19180 if (AF.isNaN())
19181 return PropagatesNaN ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
19182
19183 // In the following folds, inf can be replaced with the largest finite
19184 // float, if the ninf flag is set.
19185 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
19186 // minnum(X, -inf) -> -inf
19187 // maxnum(X, +inf) -> +inf
19188 // minimum(X, -inf) -> -inf if nnan
19189 // maximum(X, +inf) -> +inf if nnan
19190 if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
19191 return N->getOperand(Num: 1);
19192
19193 // minnum(X, +inf) -> X if nnan
19194 // maxnum(X, -inf) -> X if nnan
19195 // minimum(X, +inf) -> X
19196 // maximum(X, -inf) -> X
19197 if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
19198 return N->getOperand(Num: 0);
19199 }
19200 }
19201
19202 if (SDValue SD = reassociateReduction(
19203 RedOpc: PropagatesNaN
19204 ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
19205 : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
19206 Opc, DL: SDLoc(N), VT, N0, N1, Flags))
19207 return SD;
19208
19209 return SDValue();
19210}
19211
19212SDValue DAGCombiner::visitFABS(SDNode *N) {
19213 SDValue N0 = N->getOperand(Num: 0);
19214 EVT VT = N->getValueType(ResNo: 0);
19215 SDLoc DL(N);
19216
19217 // fold (fabs c1) -> fabs(c1)
19218 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FABS, DL, VT, Ops: {N0}))
19219 return C;
19220
19221 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
19222 return SDValue(N, 0);
19223
19224 if (SDValue Cast = foldSignChangeInBitcast(N))
19225 return Cast;
19226
19227 return SDValue();
19228}
19229
19230SDValue DAGCombiner::visitBRCOND(SDNode *N) {
19231 SDValue Chain = N->getOperand(Num: 0);
19232 SDValue N1 = N->getOperand(Num: 1);
19233 SDValue N2 = N->getOperand(Num: 2);
19234
19235 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
19236 // nondeterministic jumps).
19237 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
19238 return DAG.getNode(Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
19239 N2: N1->getOperand(Num: 0), N3: N2, Flags: N->getFlags());
19240 }
19241
19242 // Variant of the previous fold where there is a SETCC in between:
19243 // BRCOND(SETCC(FREEZE(X), CONST, Cond))
19244 // =>
19245 // BRCOND(FREEZE(SETCC(X, CONST, Cond)))
19246 // =>
19247 // BRCOND(SETCC(X, CONST, Cond))
19248 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
19249 // isn't equivalent to true or false.
19250 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
19251 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
19252 if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
19253 SDValue S0 = N1->getOperand(Num: 0), S1 = N1->getOperand(Num: 1);
19254 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N1->getOperand(Num: 2))->get();
19255 ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(Val&: S0);
19256 ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(Val&: S1);
19257 bool Updated = false;
19258
19259 // Is 'X Cond C' always true or false?
19260 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
19261 bool False = (Cond == ISD::SETULT && C->isZero()) ||
19262 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
19263 (Cond == ISD::SETUGT && C->isAllOnes()) ||
19264 (Cond == ISD::SETGT && C->isMaxSignedValue());
19265 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
19266 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
19267 (Cond == ISD::SETUGE && C->isZero()) ||
19268 (Cond == ISD::SETGE && C->isMinSignedValue());
19269 return True || False;
19270 };
19271
19272 if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
19273 if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
19274 S0 = S0->getOperand(Num: 0);
19275 Updated = true;
19276 }
19277 }
19278 if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
19279 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Operation: Cond), S0C)) {
19280 S1 = S1->getOperand(Num: 0);
19281 Updated = true;
19282 }
19283 }
19284
19285 if (Updated)
19286 return DAG.getNode(
19287 Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
19288 N2: DAG.getSetCC(DL: SDLoc(N1), VT: N1->getValueType(ResNo: 0), LHS: S0, RHS: S1, Cond), N3: N2,
19289 Flags: N->getFlags());
19290 }
19291
19292 // If N is a constant we could fold this into a fallthrough or unconditional
19293 // branch. However that doesn't happen very often in normal code, because
19294 // Instcombine/SimplifyCFG should have handled the available opportunities.
19295 // If we did this folding here, it would be necessary to update the
19296 // MachineBasicBlock CFG, which is awkward.
19297
19298 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
19299 // on the target.
19300 if (N1.getOpcode() == ISD::SETCC &&
19301 TLI.isOperationLegalOrCustom(Op: ISD::BR_CC,
19302 VT: N1.getOperand(i: 0).getValueType())) {
19303 return DAG.getNode(Opcode: ISD::BR_CC, DL: SDLoc(N), VT: MVT::Other,
19304 N1: Chain, N2: N1.getOperand(i: 2),
19305 N3: N1.getOperand(i: 0), N4: N1.getOperand(i: 1), N5: N2);
19306 }
19307
19308 if (N1.hasOneUse()) {
19309 // rebuildSetCC calls visitXor which may change the Chain when there is a
19310 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
19311 HandleSDNode ChainHandle(Chain);
19312 if (SDValue NewN1 = rebuildSetCC(N: N1))
19313 return DAG.getNode(Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other,
19314 N1: ChainHandle.getValue(), N2: NewN1, N3: N2, Flags: N->getFlags());
19315 }
19316
19317 return SDValue();
19318}
19319
19320SDValue DAGCombiner::rebuildSetCC(SDValue N) {
19321 if (N.getOpcode() == ISD::SRL ||
19322 (N.getOpcode() == ISD::TRUNCATE &&
19323 (N.getOperand(i: 0).hasOneUse() &&
19324 N.getOperand(i: 0).getOpcode() == ISD::SRL))) {
19325 // Look past the truncate.
19326 if (N.getOpcode() == ISD::TRUNCATE)
19327 N = N.getOperand(i: 0);
19328
19329 // Match this pattern so that we can generate simpler code:
19330 //
19331 // %a = ...
19332 // %b = and i32 %a, 2
19333 // %c = srl i32 %b, 1
19334 // brcond i32 %c ...
19335 //
19336 // into
19337 //
19338 // %a = ...
19339 // %b = and i32 %a, 2
19340 // %c = setcc eq %b, 0
19341 // brcond %c ...
19342 //
19343 // This applies only when the AND constant value has one bit set and the
19344 // SRL constant is equal to the log2 of the AND constant. The back-end is
19345 // smart enough to convert the result into a TEST/JMP sequence.
19346 SDValue Op0 = N.getOperand(i: 0);
19347 SDValue Op1 = N.getOperand(i: 1);
19348
19349 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
19350 SDValue AndOp1 = Op0.getOperand(i: 1);
19351
19352 if (AndOp1.getOpcode() == ISD::Constant) {
19353 const APInt &AndConst = AndOp1->getAsAPIntVal();
19354
19355 if (AndConst.isPowerOf2() &&
19356 Op1->getAsAPIntVal() == AndConst.logBase2()) {
19357 SDLoc DL(N);
19358 return DAG.getSetCC(DL, VT: getSetCCResultType(VT: Op0.getValueType()),
19359 LHS: Op0, RHS: DAG.getConstant(Val: 0, DL, VT: Op0.getValueType()),
19360 Cond: ISD::SETNE);
19361 }
19362 }
19363 }
19364 }
19365
19366 // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
19367 // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
19368 if (N.getOpcode() == ISD::XOR) {
19369 // Because we may call this on a speculatively constructed
19370 // SimplifiedSetCC Node, we need to simplify this node first.
19371 // Ideally this should be folded into SimplifySetCC and not
19372 // here. For now, grab a handle to N so we don't lose it from
19373 // replacements internal to the visit.
19374 HandleSDNode XORHandle(N);
19375 while (N.getOpcode() == ISD::XOR) {
19376 SDValue Tmp = visitXOR(N: N.getNode());
19377 // No simplification done.
19378 if (!Tmp.getNode())
19379 break;
19380 // Returning N is a form of in-visit replacement that may have
19381 // invalidated N. Grab the value from the handle.
19382 if (Tmp.getNode() == N.getNode())
19383 N = XORHandle.getValue();
19384 else // Node simplified. Try simplifying again.
19385 N = Tmp;
19386 }
19387
19388 if (N.getOpcode() != ISD::XOR)
19389 return N;
19390
19391 SDValue Op0 = N->getOperand(Num: 0);
19392 SDValue Op1 = N->getOperand(Num: 1);
19393
19394 if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
19395 bool Equal = false;
19396 // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
19397 if (isBitwiseNot(V: N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
19398 Op0.getValueType() == MVT::i1) {
19399 N = Op0;
19400 Op0 = N->getOperand(Num: 0);
19401 Op1 = N->getOperand(Num: 1);
19402 Equal = true;
19403 }
19404
19405 EVT SetCCVT = N.getValueType();
19406 if (LegalTypes)
19407 SetCCVT = getSetCCResultType(VT: SetCCVT);
19408 // Replace the uses of XOR with SETCC. Note, avoid this transformation if
19409 // it would introduce illegal operations post-legalization as this can
19410 // result in infinite looping between converting xor->setcc here, and
19411 // expanding setcc->xor in LegalizeSetCCCondCode if requested.
19412 const ISD::CondCode CC = Equal ? ISD::SETEQ : ISD::SETNE;
19413 if (!LegalOperations || TLI.isCondCodeLegal(CC, VT: Op0.getSimpleValueType()))
19414 return DAG.getSetCC(DL: SDLoc(N), VT: SetCCVT, LHS: Op0, RHS: Op1, Cond: CC);
19415 }
19416 }
19417
19418 return SDValue();
19419}
19420
19421// Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
19422//
19423SDValue DAGCombiner::visitBR_CC(SDNode *N) {
19424 CondCodeSDNode *CC = cast<CondCodeSDNode>(Val: N->getOperand(Num: 1));
19425 SDValue CondLHS = N->getOperand(Num: 2), CondRHS = N->getOperand(Num: 3);
19426
19427 // If N is a constant we could fold this into a fallthrough or unconditional
19428 // branch. However that doesn't happen very often in normal code, because
19429 // Instcombine/SimplifyCFG should have handled the available opportunities.
19430 // If we did this folding here, it would be necessary to update the
19431 // MachineBasicBlock CFG, which is awkward.
19432
19433 // Use SimplifySetCC to simplify SETCC's.
19434 SDValue Simp = SimplifySetCC(VT: getSetCCResultType(VT: CondLHS.getValueType()),
19435 N0: CondLHS, N1: CondRHS, Cond: CC->get(), DL: SDLoc(N),
19436 foldBooleans: false);
19437 if (Simp.getNode()) AddToWorklist(N: Simp.getNode());
19438
19439 // fold to a simpler setcc
19440 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
19441 return DAG.getNode(Opcode: ISD::BR_CC, DL: SDLoc(N), VT: MVT::Other,
19442 N1: N->getOperand(Num: 0), N2: Simp.getOperand(i: 2),
19443 N3: Simp.getOperand(i: 0), N4: Simp.getOperand(i: 1),
19444 N5: N->getOperand(Num: 4));
19445
19446 return SDValue();
19447}
19448
19449static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
19450 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
19451 const TargetLowering &TLI) {
19452 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: N)) {
19453 if (LD->isIndexed())
19454 return false;
19455 EVT VT = LD->getMemoryVT();
19456 if (!TLI.isIndexedLoadLegal(IdxMode: Inc, VT) && !TLI.isIndexedLoadLegal(IdxMode: Dec, VT))
19457 return false;
19458 Ptr = LD->getBasePtr();
19459 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: N)) {
19460 if (ST->isIndexed())
19461 return false;
19462 EVT VT = ST->getMemoryVT();
19463 if (!TLI.isIndexedStoreLegal(IdxMode: Inc, VT) && !TLI.isIndexedStoreLegal(IdxMode: Dec, VT))
19464 return false;
19465 Ptr = ST->getBasePtr();
19466 IsLoad = false;
19467 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: N)) {
19468 if (LD->isIndexed())
19469 return false;
19470 EVT VT = LD->getMemoryVT();
19471 if (!TLI.isIndexedMaskedLoadLegal(IdxMode: Inc, VT) &&
19472 !TLI.isIndexedMaskedLoadLegal(IdxMode: Dec, VT))
19473 return false;
19474 Ptr = LD->getBasePtr();
19475 IsMasked = true;
19476 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: N)) {
19477 if (ST->isIndexed())
19478 return false;
19479 EVT VT = ST->getMemoryVT();
19480 if (!TLI.isIndexedMaskedStoreLegal(IdxMode: Inc, VT) &&
19481 !TLI.isIndexedMaskedStoreLegal(IdxMode: Dec, VT))
19482 return false;
19483 Ptr = ST->getBasePtr();
19484 IsLoad = false;
19485 IsMasked = true;
19486 } else {
19487 return false;
19488 }
19489 return true;
19490}
19491
19492/// Try turning a load/store into a pre-indexed load/store when the base
19493/// pointer is an add or subtract and it has other uses besides the load/store.
19494/// After the transformation, the new indexed load/store has effectively folded
19495/// the add/subtract in and all of its other uses are redirected to the
19496/// new load/store.
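/// For example (illustrative), given
///   t0 = add BasePtr, 16
///   v = load t0
/// plus other users of t0, the load becomes a pre-indexed load producing both
/// the loaded value and the incremented pointer, and the remaining users of
/// the add are rewritten to use that pointer result.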
19497bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
19498 if (Level < AfterLegalizeDAG)
19499 return false;
19500
19501 bool IsLoad = true;
19502 bool IsMasked = false;
19503 SDValue Ptr;
19504 if (!getCombineLoadStoreParts(N, Inc: ISD::PRE_INC, Dec: ISD::PRE_DEC, IsLoad, IsMasked,
19505 Ptr, TLI))
19506 return false;
19507
19508 // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
19509 // out. There is no reason to make this a preinc/predec.
19510 if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
19511 Ptr->hasOneUse())
19512 return false;
19513
19514 // Ask the target to do addressing mode selection.
19515 SDValue BasePtr;
19516 SDValue Offset;
19517 ISD::MemIndexedMode AM = ISD::UNINDEXED;
19518 if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
19519 return false;
19520
19521 // Backends without true r+i pre-indexed forms may need to pass a
19522 // constant base with a variable offset so that constant coercion
19523 // will work with the patterns in canonical form.
19524 bool Swapped = false;
19525 if (isa<ConstantSDNode>(Val: BasePtr)) {
19526 std::swap(a&: BasePtr, b&: Offset);
19527 Swapped = true;
19528 }
19529
19530 // Don't create an indexed load / store with zero offset.
19531 if (isNullConstant(V: Offset))
19532 return false;
19533
19534 // Try turning it into a pre-indexed load / store except when:
19535 // 1) The new base ptr is a frame index.
19536 // 2) If N is a store and the new base ptr is either the same as or is a
19537 // predecessor of the value being stored.
19538 // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
19539 // that would create a cycle.
19540 // 4) All uses are load / store ops that use it as old base ptr.
19541
19542 // Check #1. Preinc'ing a frame index would require copying the stack pointer
19543 // (plus the implicit offset) to a register to preinc anyway.
19544 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
19545 return false;
19546
19547 // Check #2.
19548 if (!IsLoad) {
19549 SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(Val: N)->getValue()
19550 : cast<StoreSDNode>(Val: N)->getValue();
19551
19552 // Would require a copy.
19553 if (Val == BasePtr)
19554 return false;
19555
19556 // Would create a cycle.
19557 if (Val == Ptr || Ptr->isPredecessorOf(N: Val.getNode()))
19558 return false;
19559 }
19560
19561 // Caches for hasPredecessorHelper.
19562 SmallPtrSet<const SDNode *, 32> Visited;
19563 SmallVector<const SDNode *, 16> Worklist;
19564 Worklist.push_back(Elt: N);
19565
19566 // If the offset is a constant, there may be other adds of constants that
19567 // can be folded with this one. We should do this to avoid having to keep
19568 // a copy of the original base pointer.
19569 SmallVector<SDNode *, 16> OtherUses;
19570 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
19571 if (isa<ConstantSDNode>(Val: Offset))
19572 for (SDUse &Use : BasePtr->uses()) {
19573 // Skip the use that is Ptr and uses of other results from BasePtr's
19574 // node (important for nodes that return multiple results).
19575 if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
19576 continue;
19577
19578 if (SDNode::hasPredecessorHelper(N: Use.getUser(), Visited, Worklist,
19579 MaxSteps))
19580 continue;
19581
19582 if (Use.getUser()->getOpcode() != ISD::ADD &&
19583 Use.getUser()->getOpcode() != ISD::SUB) {
19584 OtherUses.clear();
19585 break;
19586 }
19587
19588 SDValue Op1 = Use.getUser()->getOperand(Num: (Use.getOperandNo() + 1) & 1);
19589 if (!isa<ConstantSDNode>(Val: Op1)) {
19590 OtherUses.clear();
19591 break;
19592 }
19593
19594 // FIXME: In some cases, we can be smarter about this.
19595 if (Op1.getValueType() != Offset.getValueType()) {
19596 OtherUses.clear();
19597 break;
19598 }
19599
19600 OtherUses.push_back(Elt: Use.getUser());
19601 }
19602
19603 if (Swapped)
19604 std::swap(a&: BasePtr, b&: Offset);
19605
19606 // Now check for #3 and #4.
19607 bool RealUse = false;
19608
19609 for (SDNode *User : Ptr->users()) {
19610 if (User == N)
19611 continue;
19612 if (SDNode::hasPredecessorHelper(N: User, Visited, Worklist, MaxSteps))
19613 return false;
19614
19615 // If Ptr may be folded into the addressing mode of the other use, then
19616 // it's not profitable to do this transformation.
19617 if (!canFoldInAddressingMode(N: Ptr.getNode(), Use: User, DAG, TLI))
19618 RealUse = true;
19619 }
19620
19621 if (!RealUse)
19622 return false;
19623
19624 SDValue Result;
19625 if (!IsMasked) {
19626 if (IsLoad)
19627 Result = DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
19628 else
19629 Result =
19630 DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
19631 } else {
19632 if (IsLoad)
19633 Result = DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
19634 Offset, AM);
19635 else
19636 Result = DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
19637 Offset, AM);
19638 }
19639 ++PreIndexedNodes;
19640 ++NodesCombined;
19641 LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
19642 Result.dump(&DAG); dbgs() << '\n');
19643 WorklistRemover DeadNodes(*this);
19644 if (IsLoad) {
19645 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
19646 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
19647 } else {
19648 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
19649 }
19650
19651 // Finally, since the node is now dead, remove it from the graph.
19652 deleteAndRecombine(N);
19653
19654 if (Swapped)
19655 std::swap(a&: BasePtr, b&: Offset);
19656
19657 // Replace other uses of BasePtr that can be updated to use Ptr
19658 for (SDNode *OtherUse : OtherUses) {
19659 unsigned OffsetIdx = 1;
19660 if (OtherUse->getOperand(Num: OffsetIdx).getNode() == BasePtr.getNode())
19661 OffsetIdx = 0;
19662 assert(OtherUse->getOperand(!OffsetIdx).getNode() == BasePtr.getNode() &&
19663 "Expected BasePtr operand");
19664
19665 // We need to replace ptr0 in the following expression:
19666 // x0 * offset0 + y0 * ptr0 = t0
19667 // knowing that
19668 // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
19669 //
19670 // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
19671 // indexed load/store and the expression that needs to be re-written.
19672 //
19673 // Therefore, we have:
19674 // t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
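// Worked example (all signs +1): if the indexed access uses ptr0 + 16
// (offset1 = 16) and the other use computes t0 = ptr0 + 4 (offset0 = 4),
// then t0 = (4 - 16) + t1 = t1 - 12, so the other use becomes an add of -12
// to the updated pointer t1.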
19675
19676 auto *CN = cast<ConstantSDNode>(Val: OtherUse->getOperand(Num: OffsetIdx));
19677 const APInt &Offset0 = CN->getAPIntValue();
19678 const APInt &Offset1 = Offset->getAsAPIntVal();
19679 int X0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
19680 int Y0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
19681 int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
19682 int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
19683
19684 unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
19685
19686 APInt CNV = Offset0;
19687 if (X0 < 0) CNV = -CNV;
19688 if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
19689 else CNV = CNV - Offset1;
19690
19691 SDLoc DL(OtherUse);
19692
19693 // We can now generate the new expression.
19694 SDValue NewOp1 = DAG.getConstant(Val: CNV, DL, VT: CN->getValueType(ResNo: 0));
19695 SDValue NewOp2 = Result.getValue(R: IsLoad ? 1 : 0);
19696
19697 SDValue NewUse =
19698 DAG.getNode(Opcode, DL, VT: OtherUse->getValueType(ResNo: 0), N1: NewOp1, N2: NewOp2);
19699 DAG.ReplaceAllUsesOfValueWith(From: SDValue(OtherUse, 0), To: NewUse);
19700 deleteAndRecombine(N: OtherUse);
19701 }
19702
19703 // Replace the uses of Ptr with uses of the updated base value.
19704 DAG.ReplaceAllUsesOfValueWith(From: Ptr, To: Result.getValue(R: IsLoad ? 1 : 0));
19705 deleteAndRecombine(N: Ptr.getNode());
19706 AddToWorklist(N: Result.getNode());
19707
19708 return true;
19709}
19710
19711static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
19712 SDValue &BasePtr, SDValue &Offset,
19713 ISD::MemIndexedMode &AM,
19714 SelectionDAG &DAG,
19715 const TargetLowering &TLI) {
19716 if (PtrUse == N ||
19717 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
19718 return false;
19719
19720 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
19721 return false;
19722
19723 // Don't create an indexed load / store with zero offset.
19724 if (isNullConstant(V: Offset))
19725 return false;
19726
19727 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
19728 return false;
19729
19730 SmallPtrSet<const SDNode *, 32> Visited;
19731 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
19732 for (SDNode *User : BasePtr->users()) {
19733 if (User == Ptr.getNode())
19734 continue;
19735
19736 // Don't combine if a later user could perform the indexing instead.
19737 if (isa<MemSDNode>(Val: User)) {
19738 bool IsLoad = true;
19739 bool IsMasked = false;
19740 SDValue OtherPtr;
19741 if (getCombineLoadStoreParts(N: User, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
19742 IsMasked, Ptr&: OtherPtr, TLI)) {
19743 SmallVector<const SDNode *, 2> Worklist;
19744 Worklist.push_back(Elt: User);
19745 if (SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps))
19746 return false;
19747 }
19748 }
19749
19750 // If all the uses are load / store addresses, then don't do the
19751 // transformation.
19752 if (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SUB) {
19753 for (SDNode *UserUser : User->users())
19754 if (canFoldInAddressingMode(N: User, Use: UserUser, DAG, TLI))
19755 return false;
19756 }
19757 }
19758 return true;
19759}
19760
19761static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
19762 bool &IsMasked, SDValue &Ptr,
19763 SDValue &BasePtr, SDValue &Offset,
19764 ISD::MemIndexedMode &AM,
19765 SelectionDAG &DAG,
19766 const TargetLowering &TLI) {
19767 if (!getCombineLoadStoreParts(N, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
19768 IsMasked, Ptr, TLI) ||
19769 Ptr->hasOneUse())
19770 return nullptr;
19771
19772 // Try turning it into a post-indexed load / store except when
19773 // 1) All uses are load / store ops that use it as base ptr (and
19774 // it may be folded into the addressing mode).
19775 // 2) Op must be independent of N, i.e. Op is neither a predecessor
19776 // nor a successor of N. Otherwise, if Op is folded that would
19777 // create a cycle.
19778 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
19779 for (SDNode *Op : Ptr->users()) {
19780 // Check for #1.
19781 if (!shouldCombineToPostInc(N, Ptr, PtrUse: Op, BasePtr, Offset, AM, DAG, TLI))
19782 continue;
19783
19784 // Check for #2.
19785 SmallPtrSet<const SDNode *, 32> Visited;
19786 SmallVector<const SDNode *, 8> Worklist;
19787 // Ptr is predecessor to both N and Op.
19788 Visited.insert(Ptr: Ptr.getNode());
19789 Worklist.push_back(Elt: N);
19790 Worklist.push_back(Elt: Op);
19791 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
19792 !SDNode::hasPredecessorHelper(N: Op, Visited, Worklist, MaxSteps))
19793 return Op;
19794 }
19795 return nullptr;
19796}
19797
19798/// Try to combine a load/store with an add/sub of the base pointer node into
19799/// a post-indexed load/store. The transformation effectively folds the
19800/// add/subtract into the new indexed load/store, and all of its uses are
19801/// redirected to the new load/store.
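/// For example (illustrative), given
///   v = load BasePtr
///   t0 = add BasePtr, 4
/// the load becomes a post-indexed load yielding both v and BasePtr + 4, and
/// uses of the add are redirected to the pointer result of the new load.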
19802bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
19803 if (Level < AfterLegalizeDAG)
19804 return false;
19805
19806 bool IsLoad = true;
19807 bool IsMasked = false;
19808 SDValue Ptr;
19809 SDValue BasePtr;
19810 SDValue Offset;
19811 ISD::MemIndexedMode AM = ISD::UNINDEXED;
19812 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
19813 Offset, AM, DAG, TLI);
19814 if (!Op)
19815 return false;
19816
19817 SDValue Result;
19818 if (!IsMasked)
19819 Result = IsLoad ? DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
19820 Offset, AM)
19821 : DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
19822 Base: BasePtr, Offset, AM);
19823 else
19824 Result = IsLoad ? DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N),
19825 Base: BasePtr, Offset, AM)
19826 : DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
19827 Base: BasePtr, Offset, AM);
19828 ++PostIndexedNodes;
19829 ++NodesCombined;
19830 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
19831 Result.dump(&DAG); dbgs() << '\n');
19832 WorklistRemover DeadNodes(*this);
19833 if (IsLoad) {
19834 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
19835 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
19836 } else {
19837 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
19838 }
19839
19840 // Finally, since the node is now dead, remove it from the graph.
19841 deleteAndRecombine(N);
19842
19843 // Replace the uses of Op with uses of the updated base value.
19844 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Op, 0),
19845 To: Result.getValue(R: IsLoad ? 1 : 0));
19846 deleteAndRecombine(N: Op);
19847 return true;
19848}
19849
19850/// Return the base-pointer arithmetic from an indexed \p LD.
19851SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
19852 ISD::MemIndexedMode AM = LD->getAddressingMode();
19853 assert(AM != ISD::UNINDEXED);
19854 SDValue BP = LD->getOperand(Num: 1);
19855 SDValue Inc = LD->getOperand(Num: 2);
19856
19857 // Some backends use TargetConstants for load offsets, but don't expect
19858 // TargetConstants in general ADD nodes. We can convert these constants into
19859 // regular Constants (if the constant is not opaque).
19860 assert((Inc.getOpcode() != ISD::TargetConstant ||
19861 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
19862 "Cannot split out indexing using opaque target constants");
19863 if (Inc.getOpcode() == ISD::TargetConstant) {
19864 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Val&: Inc);
19865 Inc = DAG.getConstant(Val: *ConstInc->getConstantIntValue(), DL: SDLoc(Inc),
19866 VT: ConstInc->getValueType(ResNo: 0));
19867 }
19868
19869 unsigned Opc =
19870 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
19871 return DAG.getNode(Opcode: Opc, DL: SDLoc(LD), VT: BP.getSimpleValueType(), N1: BP, N2: Inc);
19872}
19873
19874static inline ElementCount numVectorEltsOrZero(EVT T) {
19875 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(MinVal: 0);
19876}
19877
19878bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
19879 EVT STType = Val.getValueType();
19880 EVT STMemType = ST->getMemoryVT();
19881 if (STType == STMemType)
19882 return true;
19883 if (isTypeLegal(VT: STMemType))
19884 return false; // fail.
19885 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
19886 TLI.isOperationLegal(Op: ISD::FTRUNC, VT: STMemType)) {
19887 Val = DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(ST), VT: STMemType, Operand: Val);
19888 return true;
19889 }
19890 if (numVectorEltsOrZero(T: STType) == numVectorEltsOrZero(T: STMemType) &&
19891 STType.isInteger() && STMemType.isInteger()) {
19892 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ST), VT: STMemType, Operand: Val);
19893 return true;
19894 }
19895 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
19896 Val = DAG.getBitcast(VT: STMemType, V: Val);
19897 return true;
19898 }
19899 return false; // fail.
19900}
19901
19902bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
19903 EVT LDMemType = LD->getMemoryVT();
19904 EVT LDType = LD->getValueType(ResNo: 0);
19905 assert(Val.getValueType() == LDMemType &&
19906 "Attempting to extend value of non-matching type");
19907 if (LDType == LDMemType)
19908 return true;
19909 if (LDMemType.isInteger() && LDType.isInteger()) {
19910 switch (LD->getExtensionType()) {
19911 case ISD::NON_EXTLOAD:
19912 Val = DAG.getBitcast(VT: LDType, V: Val);
19913 return true;
19914 case ISD::EXTLOAD:
19915 Val = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
19916 return true;
19917 case ISD::SEXTLOAD:
19918 Val = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
19919 return true;
19920 case ISD::ZEXTLOAD:
19921 Val = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
19922 return true;
19923 }
19924 }
19925 return false;
19926}
19927
19928StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
19929 int64_t &Offset) {
19930 SDValue Chain = LD->getOperand(Num: 0);
19931
19932 // Look through CALLSEQ_START.
19933 if (Chain.getOpcode() == ISD::CALLSEQ_START)
19934 Chain = Chain->getOperand(Num: 0);
19935
19936 StoreSDNode *ST = nullptr;
19937 SmallVector<SDValue, 8> Aliases;
19938 if (Chain.getOpcode() == ISD::TokenFactor) {
19939 // Look for unique store within the TokenFactor.
19940 for (SDValue Op : Chain->ops()) {
19941 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Op.getNode());
19942 if (!Store)
19943 continue;
19944 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
19945 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
19946 if (!BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
19947 continue;
19948 // Make sure the store is not aliased with any nodes in TokenFactor.
19949 GatherAllAliases(N: Store, OriginalChain: Chain, Aliases);
19950 if (Aliases.empty() ||
19951 (Aliases.size() == 1 && Aliases.front().getNode() == Store))
19952 ST = Store;
19953 break;
19954 }
19955 } else {
19956 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Chain.getNode());
19957 if (Store) {
19958 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
19959 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
19960 if (BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
19961 ST = Store;
19962 }
19963 }
19964
19965 return ST;
19966}
19967
19968SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
19969 if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
19970 return SDValue();
19971 SDValue Chain = LD->getOperand(Num: 0);
19972 int64_t Offset;
19973
19974 StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
19975 // TODO: Relax this restriction for unordered atomics (see D66309)
19976 if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
19977 return SDValue();
19978
19979 EVT LDType = LD->getValueType(ResNo: 0);
19980 EVT LDMemType = LD->getMemoryVT();
19981 EVT STMemType = ST->getMemoryVT();
19982 EVT STType = ST->getValue().getValueType();
19983
19984 // There are two cases to consider here:
19985 // 1. The store is fixed width and the load is scalable. In this case we
19986 // don't know at compile time if the store completely envelops the load
19987 // so we abandon the optimisation.
19988 // 2. The store is scalable and the load is fixed width. We could
19989 // potentially support a limited number of cases here, but there has been
19990 // no cost-benefit analysis to prove it's worth it.
19991 bool LdStScalable = LDMemType.isScalableVT();
19992 if (LdStScalable != STMemType.isScalableVT())
19993 return SDValue();
19994
19995 // If we are dealing with scalable vectors on a big endian platform the
19996 // calculation of offsets below becomes trickier, since we do not know at
19997 // compile time the absolute size of the vector. Until we've done more
19998 // analysis on big-endian platforms it seems better to bail out for now.
19999 if (LdStScalable && DAG.getDataLayout().isBigEndian())
20000 return SDValue();
20001
20002 // Normalize for Endianness. After this Offset=0 will denote that the least
20003 // significant bit in the loaded value maps to the least significant bit in
20004 // the stored value. With Offset=n (for n > 0) the loaded value starts at the
20005 // n:th least significant byte of the stored value.
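// Worked example: a 32-bit store followed by an 8-bit load at byte offset 3
// has Offset = 3. On a big-endian target that byte holds the least
// significant byte of the stored value, so the normalized Offset becomes
// (32 - 8) / 8 - 3 = 0.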
20006 int64_t OrigOffset = Offset;
20007 if (DAG.getDataLayout().isBigEndian())
20008 Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
20009 (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
20010 8 -
20011 Offset;
20012
20013 // Check that the stored value covers all bits that are loaded.
20014 bool STCoversLD;
20015
20016 TypeSize LdMemSize = LDMemType.getSizeInBits();
20017 TypeSize StMemSize = STMemType.getSizeInBits();
20018 if (LdStScalable)
20019 STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
20020 else
20021 STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
20022 StMemSize.getFixedValue());
20023
20024 auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
20025 if (LD->isIndexed()) {
20026 // Cannot handle opaque target constants and we must respect the user's
20027 // request not to split indexes from loads.
20028 if (!canSplitIdx(LD))
20029 return SDValue();
20030 SDValue Idx = SplitIndexingFromLoad(LD);
20031 SDValue Ops[] = {Val, Idx, Chain};
20032 return CombineTo(N: LD, To: Ops, NumTo: 3);
20033 }
20034 return CombineTo(N: LD, Res0: Val, Res1: Chain);
20035 };
20036
20037 if (!STCoversLD)
20038 return SDValue();
20039
20040 // Memory as copy space (potentially masked).
20041 if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
20042 // Simple case: Direct non-truncating forwarding
20043 if (LDType.getSizeInBits() == LdMemSize)
20044 return ReplaceLd(LD, ST->getValue(), Chain);
20045 // Can we model the truncate and extension with an and mask?
20046 if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
20047 !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
20048 // Mask to size of LDMemType
20049 auto Mask =
20050 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: STType.getFixedSizeInBits(),
20051 loBitsSet: StMemSize.getFixedValue()),
20052 DL: SDLoc(ST), VT: STType);
20053 auto Val = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(LD), VT: LDType, N1: ST->getValue(), N2: Mask);
20054 return ReplaceLd(LD, Val, Chain);
20055 }
20056 }
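// As an illustration of the masked case above: a truncating store of an i32
// value X to i16 memory, followed by a zero- or any-extending i16 load of
// the same address, is forwarded as (and X, 0xffff) without touching memory.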
20057
20058 // Handle some big-endian cases that would have Offset 0 (and thus be
20059 // handled above) on a little-endian target.
20060 SDValue Val = ST->getValue();
20061 if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
20062 if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
20063 !LDType.isVector() && isTypeLegal(VT: STType) &&
20064 TLI.isOperationLegal(Op: ISD::SRL, VT: STType)) {
20065 Val = DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(LD), VT: STType, N1: Val,
20066 N2: DAG.getConstant(Val: Offset * 8, DL: SDLoc(LD), VT: STType));
20067 Offset = 0;
20068 }
20069 }
20070
20071 // TODO: Deal with nonzero offset.
20072 if (LD->getBasePtr().isUndef() || Offset != 0)
20073 return SDValue();
20074 // Model necessary truncations / extensions.
20075 // Truncate the value to the stored memory size.
20076 do {
20077 if (!getTruncatedStoreValue(ST, Val))
20078 break;
20079 if (!isTypeLegal(VT: LDMemType))
20080 break;
20081 if (STMemType != LDMemType) {
20082 // TODO: Support vectors? This requires extract_subvector/bitcast.
20083 if (!STMemType.isVector() && !LDMemType.isVector() &&
20084 STMemType.isInteger() && LDMemType.isInteger())
20085 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LD), VT: LDMemType, Operand: Val);
20086 else
20087 break;
20088 }
20089 if (!extendLoadedValueToExtension(LD, Val))
20090 break;
20091 return ReplaceLd(LD, Val, Chain);
20092 } while (false);
20093
20094 // On failure, cleanup dead nodes we may have created.
20095 if (Val->use_empty())
20096 deleteAndRecombine(N: Val.getNode());
20097 return SDValue();
20098}
20099
20100SDValue DAGCombiner::visitLOAD(SDNode *N) {
20101 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
20102 SDValue Chain = LD->getChain();
20103 SDValue Ptr = LD->getBasePtr();
20104
20105 // If load is not volatile and there are no uses of the loaded value (and
20106 // the updated indexed value in case of indexed loads), change uses of the
20107 // chain value into uses of the chain input (i.e. delete the dead load).
20108 // TODO: Allow this for unordered atomics (see D66309)
20109 if (LD->isSimple()) {
20110 if (N->getValueType(ResNo: 1) == MVT::Other) {
20111 // Unindexed loads.
20112 if (!N->hasAnyUseOfValue(Value: 0)) {
20113 // It's not safe to use the two value CombineTo variant here. e.g.
20114 // v1, chain2 = load chain1, loc
20115 // v2, chain3 = load chain2, loc
20116 // v3 = add v2, c
20117 // Now we replace use of chain2 with chain1. This makes the second load
20118 // isomorphic to the one we are deleting, and thus makes this load live.
20119 LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
20120 dbgs() << "\nWith chain: "; Chain.dump(&DAG);
20121 dbgs() << "\n");
20122 WorklistRemover DeadNodes(*this);
20123 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
20124 AddUsersToWorklist(N: Chain.getNode());
20125 if (N->use_empty())
20126 deleteAndRecombine(N);
20127
20128 return SDValue(N, 0); // Return N so it doesn't get rechecked!
20129 }
20130 } else {
20131 // Indexed loads.
20132 assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
20133
20134 // If this load has an opaque TargetConstant offset, then we cannot split
20135 // the indexing into an add/sub directly (that TargetConstant may not be
20136 // valid for a different type of node, and we cannot convert an opaque
20137 // target constant into a regular constant).
20138 bool CanSplitIdx = canSplitIdx(LD);
20139
20140 if (!N->hasAnyUseOfValue(Value: 0) && (CanSplitIdx || !N->hasAnyUseOfValue(Value: 1))) {
20141 SDValue Undef = DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
20142 SDValue Index;
20143 if (N->hasAnyUseOfValue(Value: 1) && CanSplitIdx) {
20144 Index = SplitIndexingFromLoad(LD);
20145 // Try to fold the base pointer arithmetic into subsequent loads and
20146 // stores.
20147 AddUsersToWorklist(N);
20148 } else
20149 Index = DAG.getUNDEF(VT: N->getValueType(ResNo: 1));
20150 LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
20151 dbgs() << "\nWith: "; Undef.dump(&DAG);
20152 dbgs() << " and 2 other values\n");
20153 WorklistRemover DeadNodes(*this);
20154 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Undef);
20155 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Index);
20156 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 2), To: Chain);
20157 deleteAndRecombine(N);
20158 return SDValue(N, 0); // Return N so it doesn't get rechecked!
20159 }
20160 }
20161 }
20162
20163 // If this load is directly stored, replace the load value with the stored
20164 // value.
20165 if (auto V = ForwardStoreValueToDirectLoad(LD))
20166 return V;
20167
20168 // Try to infer better alignment information than the load already has.
20169 if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
20170 !LD->isAtomic()) {
20171 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
20172 if (*Alignment > LD->getAlign() &&
20173 isAligned(Lhs: *Alignment, SizeInBytes: LD->getSrcValueOffset())) {
20174 SDValue NewLoad = DAG.getExtLoad(
20175 ExtType: LD->getExtensionType(), dl: SDLoc(N), VT: LD->getValueType(ResNo: 0), Chain, Ptr,
20176 PtrInfo: LD->getPointerInfo(), MemVT: LD->getMemoryVT(), Alignment: *Alignment,
20177 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
20178 // NewLoad will always be N as we are only refining the alignment
20179 assert(NewLoad.getNode() == N);
20180 (void)NewLoad;
20181 }
20182 }
20183 }
20184
20185 if (LD->isUnindexed()) {
20186 // Walk up chain skipping non-aliasing memory nodes.
20187 SDValue BetterChain = FindBetterChain(N: LD, Chain);
20188
20189 // If there is a better chain.
20190 if (Chain != BetterChain) {
20191 SDValue ReplLoad;
20192
20193 // Replace the chain to avoid the dependency.
20194 if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
20195 ReplLoad = DAG.getLoad(VT: N->getValueType(ResNo: 0), dl: SDLoc(LD),
20196 Chain: BetterChain, Ptr, MMO: LD->getMemOperand());
20197 } else {
20198 ReplLoad = DAG.getExtLoad(ExtType: LD->getExtensionType(), dl: SDLoc(LD),
20199 VT: LD->getValueType(ResNo: 0),
20200 Chain: BetterChain, Ptr, MemVT: LD->getMemoryVT(),
20201 MMO: LD->getMemOperand());
20202 }
20203
20204 // Create token factor to keep old chain connected.
20205 SDValue Token = DAG.getNode(Opcode: ISD::TokenFactor, DL: SDLoc(N),
20206 VT: MVT::Other, N1: Chain, N2: ReplLoad.getValue(R: 1));
20207
20208 // Replace uses with load result and token factor
20209 return CombineTo(N, Res0: ReplLoad.getValue(R: 0), Res1: Token);
20210 }
20211 }
20212
20213 // Try transforming N to an indexed load.
20214 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
20215 return SDValue(N, 0);
20216
20217 // Try to slice up N to more direct loads if the slices are mapped to
20218 // different register banks or pairing can take place.
20219 if (SliceUpLoad(N))
20220 return SDValue(N, 0);
20221
20222 return SDValue();
20223}
20224
20225namespace {
20226
20227/// Helper structure used to slice a load into smaller loads.
20228/// Basically a slice is obtained from the following sequence:
20229/// Origin = load Ty1, Base
20230/// Shift = srl Ty1 Origin, CstTy Amount
20231/// Inst = trunc Shift to Ty2
20232///
20233/// Then, it will be rewritten into:
20234/// Slice = load SliceTy, Base + SliceOffset
20235/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
20236///
20237/// SliceTy is deduced from the number of bits that are actually used to
20238/// build Inst.
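/// For example, on a little-endian target, Inst = trunc (srl (load i32 Base),
/// 16) to i16 can be rewritten as a direct i16 load from Base + 2, removing
/// both the shift and the truncate.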
20239struct LoadedSlice {
20240 /// Helper structure used to compute the cost of a slice.
20241 struct Cost {
20242 /// Are we optimizing for code size.
20243 bool ForCodeSize = false;
20244
20245 /// Various costs.
20246 unsigned Loads = 0;
20247 unsigned Truncates = 0;
20248 unsigned CrossRegisterBanksCopies = 0;
20249 unsigned ZExts = 0;
20250 unsigned Shift = 0;
20251
20252 explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
20253
20254 /// Get the cost of one isolated slice.
20255 Cost(const LoadedSlice &LS, bool ForCodeSize)
20256 : ForCodeSize(ForCodeSize), Loads(1) {
20257 EVT TruncType = LS.Inst->getValueType(ResNo: 0);
20258 EVT LoadedType = LS.getLoadedType();
20259 if (TruncType != LoadedType &&
20260 !LS.DAG->getTargetLoweringInfo().isZExtFree(FromTy: LoadedType, ToTy: TruncType))
20261 ZExts = 1;
20262 }
20263
20264 /// Account for slicing gain in the current cost.
20265 /// Slicing provides a few gains, like removing a shift or a
20266 /// truncate. This method allows growing the cost of the original
20267 /// load with the gain from this slice.
20268 void addSliceGain(const LoadedSlice &LS) {
20269 // Each slice saves a truncate.
20270 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
20271 if (!TLI.isTruncateFree(Val: LS.Inst->getOperand(Num: 0), VT2: LS.Inst->getValueType(ResNo: 0)))
20272 ++Truncates;
20273 // If there is a shift amount, this slice gets rid of it.
20274 if (LS.Shift)
20275 ++Shift;
20276 // If this slice can merge a cross register bank copy, account for it.
20277 if (LS.canMergeExpensiveCrossRegisterBankCopy())
20278 ++CrossRegisterBanksCopies;
20279 }
20280
20281 Cost &operator+=(const Cost &RHS) {
20282 Loads += RHS.Loads;
20283 Truncates += RHS.Truncates;
20284 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
20285 ZExts += RHS.ZExts;
20286 Shift += RHS.Shift;
20287 return *this;
20288 }
20289
20290 bool operator==(const Cost &RHS) const {
20291 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
20292 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
20293 ZExts == RHS.ZExts && Shift == RHS.Shift;
20294 }
20295
20296 bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
20297
20298 bool operator<(const Cost &RHS) const {
20299 // Assume cross register banks copies are as expensive as loads.
20300 // FIXME: Do we want some more target hooks?
20301 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
20302 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
20303 // Unless we are optimizing for code size, consider the
20304 // expensive operation first.
20305 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
20306 return ExpensiveOpsLHS < ExpensiveOpsRHS;
20307 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
20308 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
20309 }
20310
20311 bool operator>(const Cost &RHS) const { return RHS < *this; }
20312
20313 bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
20314
20315 bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
20316 };
20317
20318 // The last instruction that represents the slice. This should be a
20319 // truncate instruction.
20320 SDNode *Inst;
20321
20322 // The original load instruction.
20323 LoadSDNode *Origin;
20324
20325 // The right shift amount in bits from the original load.
20326 unsigned Shift;
20327
20328 // The DAG from which Origin comes.
20329 // This is used to get some contextual information about legal types, etc.
20330 SelectionDAG *DAG;
20331
20332 LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
20333 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
20334 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
20335
20336 /// Get the bits used in a chunk as wide as the original loaded value.
20337 /// \return Result is that wide and has the used bits set to 1 and the
20338 /// unused bits set to 0.
20339 APInt getUsedBits() const {
20340 // Reproduce the trunc(lshr) sequence:
20341 // - Start from the truncated value.
20342 // - Zero extend to the desired bit width.
20343 // - Shift left.
20344 assert(Origin && "No original load to compare against.");
20345 unsigned BitWidth = Origin->getValueSizeInBits(ResNo: 0);
20346 assert(Inst && "This slice is not bound to an instruction");
20347 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
20348 "Extracted slice is bigger than the whole type!");
20349 APInt UsedBits(Inst->getValueSizeInBits(ResNo: 0), 0);
20350 UsedBits.setAllBits();
20351 UsedBits = UsedBits.zext(width: BitWidth);
20352 UsedBits <<= Shift;
20353 return UsedBits;
20354 }
20355
20356 /// Get the size of the slice to be loaded in bytes.
20357 unsigned getLoadedSize() const {
20358 unsigned SliceSize = getUsedBits().popcount();
20359 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
20360 return SliceSize / 8;
20361 }
20362
20363 /// Get the type that will be loaded for this slice.
20364 /// Note: This may not be the final type for the slice.
20365 EVT getLoadedType() const {
20366 assert(DAG && "Missing context");
20367 LLVMContext &Ctxt = *DAG->getContext();
20368 return EVT::getIntegerVT(Context&: Ctxt, BitWidth: getLoadedSize() * 8);
20369 }
20370
20371 /// Get the alignment of the load used for this slice.
20372 Align getAlign() const {
20373 Align Alignment = Origin->getAlign();
20374 uint64_t Offset = getOffsetFromBase();
20375 if (Offset != 0)
20376 Alignment = commonAlignment(A: Alignment, Offset: Alignment.value() + Offset);
20377 return Alignment;
20378 }
20379
20380 /// Check if this slice can be rewritten with legal operations.
20381 bool isLegal() const {
20382 // An invalid slice is not legal.
20383 if (!Origin || !Inst || !DAG)
20384 return false;
20385
20386 // Offsets are only used by indexed loads; we do not handle that.
20387 if (!Origin->getOffset().isUndef())
20388 return false;
20389
20390 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
20391
20392 // Check that the type is legal.
20393 EVT SliceType = getLoadedType();
20394 if (!TLI.isTypeLegal(VT: SliceType))
20395 return false;
20396
20397 // Check that the load is legal for this type.
20398 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: SliceType))
20399 return false;
20400
20401 // Check that the offset can be computed.
20402 // 1. Check its type.
20403 EVT PtrType = Origin->getBasePtr().getValueType();
20404 if (PtrType == MVT::Untyped || PtrType.isExtended())
20405 return false;
20406
20407 // 2. Check that it fits in the immediate.
20408 if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
20409 return false;
20410
20411 // 3. Check that the computation is legal.
20412 if (!TLI.isOperationLegal(Op: ISD::ADD, VT: PtrType))
20413 return false;
20414
20415 // Check that the zext is legal if it needs one.
20416 EVT TruncateType = Inst->getValueType(ResNo: 0);
20417 if (TruncateType != SliceType &&
20418 !TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: TruncateType))
20419 return false;
20420
20421 return true;
20422 }
20423
20424 /// Get the offset in bytes of this slice in the original chunk of
20425 /// bits.
20426 /// \pre DAG != nullptr.
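/// Illustrative example (assumed values, not from the source): for an i32
/// Origin with Shift == 16 and a 2-byte slice, the offset is 16 / 8 == 2 on
/// a little-endian target and 4 - 2 - 2 == 0 on a big-endian target.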
20427 uint64_t getOffsetFromBase() const {
20428 assert(DAG && "Missing context.");
20429 bool IsBigEndian = DAG->getDataLayout().isBigEndian();
20430 assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
20431 uint64_t Offset = Shift / 8;
20432 unsigned TySizeInBytes = Origin->getValueSizeInBits(ResNo: 0) / 8;
20433 assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
20434 "The size of the original loaded type is not a multiple of a"
20435 " byte.");
20436 // If Offset is bigger than TySizeInBytes, it means we are loading all
20437 // zeros. This should have been optimized away earlier in the process.
20438 assert(TySizeInBytes > Offset &&
20439 "Invalid shift amount for given loaded size");
20440 if (IsBigEndian)
20441 Offset = TySizeInBytes - Offset - getLoadedSize();
20442 return Offset;
20443 }
20444
20445 /// Generate the sequence of instructions to load the slice
20446 /// represented by this object and redirect the uses of this slice to
20447 /// this new sequence of instructions.
20448 /// \pre this->Inst && this->Origin are valid Instructions and this
20449 /// object passed the legal check: LoadedSlice::isLegal returned true.
20450 /// \return The last instruction of the sequence used to load the slice.
20451 SDValue loadSlice() const {
20452 assert(Inst && Origin && "Unable to replace a non-existing slice.");
20453 const SDValue &OldBaseAddr = Origin->getBasePtr();
20454 SDValue BaseAddr = OldBaseAddr;
20455 // Get the offset in that chunk of bytes w.r.t. the endianness.
20456 int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
20457 assert(Offset >= 0 && "Offset too big to fit in int64_t!");
20458 if (Offset) {
20459 // BaseAddr = BaseAddr + Offset.
20460 EVT ArithType = BaseAddr.getValueType();
20461 SDLoc DL(Origin);
20462 BaseAddr = DAG->getNode(Opcode: ISD::ADD, DL, VT: ArithType, N1: BaseAddr,
20463 N2: DAG->getConstant(Val: Offset, DL, VT: ArithType));
20464 }
20465
20466 // Create the type of the loaded slice according to its size.
20467 EVT SliceType = getLoadedType();
20468
20469 // Create the load for the slice.
20470 SDValue LastInst =
20471 DAG->getLoad(VT: SliceType, dl: SDLoc(Origin), Chain: Origin->getChain(), Ptr: BaseAddr,
20472 PtrInfo: Origin->getPointerInfo().getWithOffset(O: Offset), Alignment: getAlign(),
20473 MMOFlags: Origin->getMemOperand()->getFlags());
20474 // If the final type is not the same as the loaded type, this means that
20475 // we have to pad with zero. Create a zero extend for that.
20476 EVT FinalType = Inst->getValueType(ResNo: 0);
20477 if (SliceType != FinalType)
20478 LastInst =
20479 DAG->getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LastInst), VT: FinalType, Operand: LastInst);
20480 return LastInst;
20481 }
20482
20483 /// Check if this slice can be merged with an expensive cross register
20484 /// bank copy. E.g.,
20485 /// i = load i32
20486 /// f = bitcast i32 i to float
20487 bool canMergeExpensiveCrossRegisterBankCopy() const {
20488 if (!Inst || !Inst->hasOneUse())
20489 return false;
20490 SDNode *User = *Inst->user_begin();
20491 if (User->getOpcode() != ISD::BITCAST)
20492 return false;
20493 assert(DAG && "Missing context");
20494 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
20495 EVT ResVT = User->getValueType(ResNo: 0);
20496 const TargetRegisterClass *ResRC =
20497 TLI.getRegClassFor(VT: ResVT.getSimpleVT(), isDivergent: User->isDivergent());
20498 const TargetRegisterClass *ArgRC =
20499 TLI.getRegClassFor(VT: User->getOperand(Num: 0).getValueType().getSimpleVT(),
20500 isDivergent: User->getOperand(Num: 0)->isDivergent());
20501 if (ArgRC == ResRC || !TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
20502 return false;
20503
20504 // At this point, we know that we perform a cross-register-bank copy.
20505 // Check if it is expensive.
20506 const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
20507 // Assume bitcasts are cheap, unless the two register classes do not
20508 // explicitly share a common subclass.
20509 if (!TRI || TRI->getCommonSubClass(A: ArgRC, B: ResRC))
20510 return false;
20511
20512 // Check if it will be merged with the load.
20513 // 1. Check the alignment / fast memory access constraint.
20514 unsigned IsFast = 0;
20515 if (!TLI.allowsMemoryAccess(Context&: *DAG->getContext(), DL: DAG->getDataLayout(), VT: ResVT,
20516 AddrSpace: Origin->getAddressSpace(), Alignment: getAlign(),
20517 Flags: Origin->getMemOperand()->getFlags(), Fast: &IsFast) ||
20518 !IsFast)
20519 return false;
20520
20521 // 2. Check that the load is a legal operation for that type.
20522 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
20523 return false;
20524
20525 // 3. Check that we do not have a zext in the way.
20526 if (Inst->getValueType(ResNo: 0) != getLoadedType())
20527 return false;
20528
20529 return true;
20530 }
20531};
20532
20533} // end anonymous namespace
20534
20535/// Check that all bits set in \p UsedBits form a dense region, i.e.,
20536/// \p UsedBits looks like 0..0 1..1 0..0.
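/// For example (illustrative values): 0b00111100 is dense, while 0b00110011
/// is not, because of the hole between the two runs of ones.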
20537static bool areUsedBitsDense(const APInt &UsedBits) {
20538 // If all the bits are one, this is dense!
20539 if (UsedBits.isAllOnes())
20540 return true;
20541
20542 // Get rid of the unused bits on the right.
20543 APInt NarrowedUsedBits = UsedBits.lshr(shiftAmt: UsedBits.countr_zero());
20544 // Get rid of the unused bits on the left.
20545 if (NarrowedUsedBits.countl_zero())
20546 NarrowedUsedBits = NarrowedUsedBits.trunc(width: NarrowedUsedBits.getActiveBits());
20547 // Check that the chunk of bits is completely used.
20548 return NarrowedUsedBits.isAllOnes();
20549}
20550
20551/// Check whether or not \p First and \p Second are next to each other
20552/// in memory. This means that there is no hole between the bits loaded
20553/// by \p First and the bits loaded by \p Second.
20554static bool areSlicesNextToEachOther(const LoadedSlice &First,
20555 const LoadedSlice &Second) {
20556 assert(First.Origin == Second.Origin && First.Origin &&
20557 "Unable to match different memory origins.");
20558 APInt UsedBits = First.getUsedBits();
20559 assert((UsedBits & Second.getUsedBits()) == 0 &&
20560 "Slices are not supposed to overlap.");
20561 UsedBits |= Second.getUsedBits();
20562 return areUsedBitsDense(UsedBits);
20563}
20564
20565/// Adjust the \p GlobalLSCost according to the target
20566 /// pairing capabilities and the layout of the slices.
20567 /// \pre \p GlobalLSCost should account for at least as many loads as
20568 /// there are in the slices in \p LoadedSlices.
20569static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
20570 LoadedSlice::Cost &GlobalLSCost) {
20571 unsigned NumberOfSlices = LoadedSlices.size();
20572 // If there are fewer than 2 elements, no pairing is possible.
20573 if (NumberOfSlices < 2)
20574 return;
20575
20576 // Sort the slices so that elements that are likely to be next to each
20577 // other in memory are next to each other in the list.
20578 llvm::sort(C&: LoadedSlices, Comp: [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
20579 assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
20580 return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
20581 });
20582 const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
20583 // First (resp. Second) is the first (resp. second) potential candidate
20584 // to be placed in a paired load.
20585 const LoadedSlice *First = nullptr;
20586 const LoadedSlice *Second = nullptr;
20587 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
20588 // Set the beginning of the pair.
20589 First = Second) {
20590 Second = &LoadedSlices[CurrSlice];
20591
20592 // If First is NULL, it means we start a new pair.
20593 // Get to the next slice.
20594 if (!First)
20595 continue;
20596
20597 EVT LoadedType = First->getLoadedType();
20598
20599 // If the types of the slices are different, we cannot pair them.
20600 if (LoadedType != Second->getLoadedType())
20601 continue;
20602
20603 // Check if the target supplies paired loads for this type.
20604 Align RequiredAlignment;
20605 if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
20606 // Move to the next pair; this type is hopeless.
20607 Second = nullptr;
20608 continue;
20609 }
20610 // Check if we meet the alignment requirement.
20611 if (First->getAlign() < RequiredAlignment)
20612 continue;
20613
20614 // Check that both loads are next to each other in memory.
20615 if (!areSlicesNextToEachOther(First: *First, Second: *Second))
20616 continue;
20617
20618 assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
20619 --GlobalLSCost.Loads;
20620 // Move to the next pair.
20621 Second = nullptr;
20622 }
20623}
20624
20625/// Check the profitability of all involved LoadedSlice.
20626 /// Currently, it is considered profitable if there are exactly two
20627/// involved slices (1) which are (2) next to each other in memory, and
20628/// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
20629///
20630/// Note: The order of the elements in \p LoadedSlices may be modified, but not
20631/// the elements themselves.
20632///
20633 /// FIXME: When the cost model is mature enough, we can relax
20634/// constraints (1) and (2).
20635static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
20636 const APInt &UsedBits, bool ForCodeSize) {
20637 unsigned NumberOfSlices = LoadedSlices.size();
20638 if (StressLoadSlicing)
20639 return NumberOfSlices > 1;
20640
20641 // Check (1).
20642 if (NumberOfSlices != 2)
20643 return false;
20644
20645 // Check (2).
20646 if (!areUsedBitsDense(UsedBits))
20647 return false;
20648
20649 // Check (3).
20650 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
20651 // The original code has one big load.
20652 OrigCost.Loads = 1;
20653 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
20654 const LoadedSlice &LS = LoadedSlices[CurrSlice];
20655 // Accumulate the cost of all the slices.
20656 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
20657 GlobalSlicingCost += SliceCost;
20658
20659 // Account for the gain obtained with the current slices as extra cost
20660 // charged to the original configuration.
20661 OrigCost.addSliceGain(LS);
20662 }
20663
20664 // If the target supports paired loads, adjust the cost accordingly.
20665 adjustCostForPairing(LoadedSlices, GlobalLSCost&: GlobalSlicingCost);
20666 return OrigCost > GlobalSlicingCost;
20667}
20668
20669 /// If the given load, \p N, is used only by trunc or trunc(lshr)
20670 /// operations, split it into the various pieces being extracted.
20671 ///
20672 /// This sort of thing is introduced by SROA.
20673 /// This slicing takes care not to insert overlapping loads.
20674 /// \pre \p N is a simple load (i.e., not an atomic or volatile load).
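///
/// Illustrative sketch only (little-endian, assuming i32 loads are legal and
/// the slicing is deemed profitable): a pattern such as
///   x  = load i64, Base
///   lo = trunc i64 x to i32
///   hi = trunc i64 (srl i64 x, 32) to i32
/// is rewritten into two independent loads:
///   lo = load i32, Base
///   hi = load i32, Base + 4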
20675bool DAGCombiner::SliceUpLoad(SDNode *N) {
20676 if (Level < AfterLegalizeDAG)
20677 return false;
20678
20679 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
20680 if (!LD->isSimple() || !ISD::isNormalLoad(N: LD) ||
20681 !LD->getValueType(ResNo: 0).isInteger())
20682 return false;
20683
20684 // The algorithm to split up a load of a scalable vector into individual
20685 // elements currently requires knowing the length of the loaded type,
20686 // so will need adjusting to work on scalable vectors.
20687 if (LD->getValueType(ResNo: 0).isScalableVector())
20688 return false;
20689
20690 // Keep track of already used bits to detect overlapping values.
20691 // In that case, we will just abort the transformation.
20692 APInt UsedBits(LD->getValueSizeInBits(ResNo: 0), 0);
20693
20694 SmallVector<LoadedSlice, 4> LoadedSlices;
20695
20696 // Check if this load is used as several smaller chunks of bits.
20697 // Basically, look for uses in trunc or trunc(lshr) and record a new chain
20698 // of computation for each trunc.
20699 for (SDUse &U : LD->uses()) {
20700 // Skip the uses of the chain.
20701 if (U.getResNo() != 0)
20702 continue;
20703
20704 SDNode *User = U.getUser();
20705 unsigned Shift = 0;
20706
20707 // Check if this is a trunc(lshr).
20708 if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
20709 isa<ConstantSDNode>(Val: User->getOperand(Num: 1))) {
20710 Shift = User->getConstantOperandVal(Num: 1);
20711 User = *User->user_begin();
20712 }
20713
20714 // At this point, User is a TRUNCATE iff we encountered trunc or
20715 // trunc(lshr).
20716 if (User->getOpcode() != ISD::TRUNCATE)
20717 return false;
20718
20719 // The width of the type must be a power of 2 and at least 8 bits.
20720 // Otherwise the load cannot be represented in LLVM IR.
20721 // Moreover, if we shifted by an amount that is not a multiple of 8 bits,
20722 // the slice would not start on a byte boundary. We do not support that.
20723 unsigned Width = User->getValueSizeInBits(ResNo: 0);
20724 if (Width < 8 || !isPowerOf2_32(Value: Width) || (Shift & 0x7))
20725 return false;
20726
20727 // Build the slice for this chain of computations.
20728 LoadedSlice LS(User, LD, Shift, &DAG);
20729 APInt CurrentUsedBits = LS.getUsedBits();
20730
20731 // Check if this slice overlaps with another.
20732 if ((CurrentUsedBits & UsedBits) != 0)
20733 return false;
20734 // Update the bits used globally.
20735 UsedBits |= CurrentUsedBits;
20736
20737 // Check if the new slice would be legal.
20738 if (!LS.isLegal())
20739 return false;
20740
20741 // Record the slice.
20742 LoadedSlices.push_back(Elt: LS);
20743 }
20744
20745 // Abort slicing if it does not seem to be profitable.
20746 if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
20747 return false;
20748
20749 ++SlicedLoads;
20750
20751 // Rewrite each chain to use an independent load.
20752 // By construction, each chain can be represented by a unique load.
20753
20754 // Prepare the argument for the new token factor for all the slices.
20755 SmallVector<SDValue, 8> ArgChains;
20756 for (const LoadedSlice &LS : LoadedSlices) {
20757 SDValue SliceInst = LS.loadSlice();
20758 CombineTo(N: LS.Inst, Res: SliceInst, AddTo: true);
20759 if (SliceInst.getOpcode() != ISD::LOAD)
20760 SliceInst = SliceInst.getOperand(i: 0);
20761 assert(SliceInst->getOpcode() == ISD::LOAD &&
20762 "It takes more than a zext to get to the loaded slice!!");
20763 ArgChains.push_back(Elt: SliceInst.getValue(R: 1));
20764 }
20765
20766 SDValue Chain = DAG.getNode(Opcode: ISD::TokenFactor, DL: SDLoc(LD), VT: MVT::Other,
20767 Ops: ArgChains);
20768 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
20769 AddToWorklist(N: Chain.getNode());
20770 return true;
20771}
20772
20773 /// Check to see if V is (and (load ptr), imm), where the load has
20774 /// specific bytes cleared out. If so, return the number of bytes being
20775 /// masked out and the shift amount (both in bytes).
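///
/// Worked example (illustrative constants, not from the source): for
/// V = (and (load i32 Ptr), 0xFFFF00FF), the inverted mask is 0x0000FF00, so
/// exactly one byte is cleared starting at byte offset 1 and the function
/// returns {1, 1}.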
20776static std::pair<unsigned, unsigned>
20777CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
20778 std::pair<unsigned, unsigned> Result(0, 0);
20779
20780 // Check for the structure we're looking for.
20781 if (V->getOpcode() != ISD::AND ||
20782 !isa<ConstantSDNode>(Val: V->getOperand(Num: 1)) ||
20783 !ISD::isNormalLoad(N: V->getOperand(Num: 0).getNode()))
20784 return Result;
20785
20786 // Check the chain and pointer.
20787 LoadSDNode *LD = cast<LoadSDNode>(Val: V->getOperand(Num: 0));
20788 if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
20789
20790 // This only handles simple types.
20791 if (V.getValueType() != MVT::i16 &&
20792 V.getValueType() != MVT::i32 &&
20793 V.getValueType() != MVT::i64)
20794 return Result;
20795
20796 // Check the constant mask. Invert it so that the bits being masked out are
20797 // 1 and the bits being kept are 0. Use getSExtValue so that leading bits
20798 // follow the sign bit for uniformity.
20799 uint64_t NotMask = ~cast<ConstantSDNode>(Val: V->getOperand(Num: 1))->getSExtValue();
20800 unsigned NotMaskLZ = llvm::countl_zero(Val: NotMask);
20801 if (NotMaskLZ & 7) return Result; // Must be multiple of a byte.
20802 unsigned NotMaskTZ = llvm::countr_zero(Val: NotMask);
20803 if (NotMaskTZ & 7) return Result; // Must be multiple of a byte.
20804 if (NotMaskLZ == 64) return Result; // All zero mask.
20805
20806 // See if we have a continuous run of bits. If so, we have 0*1+0*
20807 if (llvm::countr_one(Value: NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
20808 return Result;
20809
20810 // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
20811 if (V.getValueType() != MVT::i64 && NotMaskLZ)
20812 NotMaskLZ -= 64-V.getValueSizeInBits();
20813
20814 unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
20815 switch (MaskedBytes) {
20816 case 1:
20817 case 2:
20818 case 4: break;
20819 default: return Result; // All one mask, or 5-byte mask.
20820 }
20821
20822 // Verify that the masked region starts at a multiple of its width so that
20823 // the access is aligned the same as the access width.
20824 if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
20825
20826 // For narrowing to be valid, it must be the case that the load is the
20827 // memory operation immediately preceding the store.
20828 if (LD == Chain.getNode())
20829 ; // ok.
20830 else if (Chain->getOpcode() == ISD::TokenFactor &&
20831 SDValue(LD, 1).hasOneUse()) {
20832 // LD has only 1 chain use, so there are no indirect dependencies.
20833 if (!LD->isOperandOf(N: Chain.getNode()))
20834 return Result;
20835 } else
20836 return Result; // Fail.
20837
20838 Result.first = MaskedBytes;
20839 Result.second = NotMaskTZ/8;
20840 return Result;
20841}
20842
20843/// Check to see if IVal is something that provides a value as specified by
20844/// MaskInfo. If so, replace the specified store with a narrower store of
20845/// truncated IVal.
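///
/// Sketch of the effect (illustrative, little-endian): with MaskInfo == {1, 1}
/// and an IVal known to be zero outside byte 1, the original i32 store is
/// replaced by an i8 store of (trunc (srl IVal, 8)) at Ptr + 1.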
20846static SDValue
20847ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
20848 SDValue IVal, StoreSDNode *St,
20849 DAGCombiner *DC) {
20850 unsigned NumBytes = MaskInfo.first;
20851 unsigned ByteShift = MaskInfo.second;
20852 SelectionDAG &DAG = DC->getDAG();
20853
20854 // Check to see if IVal is all zeros in the part being masked in by the 'or'
20855 // that uses this. If not, this is not a replacement.
20856 APInt Mask = ~APInt::getBitsSet(numBits: IVal.getValueSizeInBits(),
20857 loBit: ByteShift*8, hiBit: (ByteShift+NumBytes)*8);
20858 if (!DAG.MaskedValueIsZero(Op: IVal, Mask)) return SDValue();
20859
20860 // Check that it is legal on the target to do this. It is legal if the new
20861 // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
20862 // legalization. If the source type is legal, but the store type isn't, see
20863 // if we can use a truncating store.
20864 MVT VT = MVT::getIntegerVT(BitWidth: NumBytes * 8);
20865 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20866 bool UseTruncStore;
20867 if (DC->isTypeLegal(VT))
20868 UseTruncStore = false;
20869 else if (TLI.isTypeLegal(VT: IVal.getValueType()) &&
20870 TLI.isTruncStoreLegal(ValVT: IVal.getValueType(), MemVT: VT))
20871 UseTruncStore = true;
20872 else
20873 return SDValue();
20874
20875 // Can't do this for indexed stores.
20876 if (St->isIndexed())
20877 return SDValue();
20878
20879 // Check that the target doesn't think this is a bad idea.
20880 if (St->getMemOperand() &&
20881 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
20882 MMO: *St->getMemOperand()))
20883 return SDValue();
20884
20885 // Okay, we can do this! Replace the 'St' store with a store of IVal that is
20886 // shifted by ByteShift and truncated down to NumBytes.
20887 if (ByteShift) {
20888 SDLoc DL(IVal);
20889 IVal = DAG.getNode(
20890 Opcode: ISD::SRL, DL, VT: IVal.getValueType(), N1: IVal,
20891 N2: DAG.getShiftAmountConstant(Val: ByteShift * 8, VT: IVal.getValueType(), DL));
20892 }
20893
20894 // Figure out the offset for the store and the alignment of the access.
20895 unsigned StOffset;
20896 if (DAG.getDataLayout().isLittleEndian())
20897 StOffset = ByteShift;
20898 else
20899 StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
20900
20901 SDValue Ptr = St->getBasePtr();
20902 if (StOffset) {
20903 SDLoc DL(IVal);
20904 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: StOffset), DL);
20905 }
20906
20907 ++OpsNarrowed;
20908 if (UseTruncStore)
20909 return DAG.getTruncStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
20910 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset), SVT: VT,
20911 Alignment: St->getBaseAlign());
20912
20913 // Truncate down to the new size.
20914 IVal = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(IVal), VT, Operand: IVal);
20915
20916 return DAG.getStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
20917 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset),
20918 Alignment: St->getBaseAlign());
20919}
20920
20921 /// Look for a sequence of load / op / store where op is one of 'or', 'xor',
20922 /// or 'and' of immediates. If 'op' only touches some of the loaded bits, try
20923 /// narrowing the load and store if it would end up being a win for performance
20924 /// or code size.
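///
/// Illustrative sketch (assuming the target considers the narrow access legal,
/// fast and profitable): on a little-endian target,
///   store (or (load i32 p), 0x0000FF00), p
/// can be narrowed to
///   store (or (load i8 p+1), 0xFF), p+1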
20925SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
20926 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
20927 if (!ST->isSimple())
20928 return SDValue();
20929
20930 SDValue Chain = ST->getChain();
20931 SDValue Value = ST->getValue();
20932 SDValue Ptr = ST->getBasePtr();
20933 EVT VT = Value.getValueType();
20934
20935 if (ST->isTruncatingStore() || VT.isVector())
20936 return SDValue();
20937
20938 unsigned Opc = Value.getOpcode();
20939
20940 if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
20941 !Value.hasOneUse())
20942 return SDValue();
20943
20944 // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
20945 // is a byte mask indicating a consecutive number of bytes, check to see if
20946 // Y is known to provide just those bytes. If so, we try to replace the
20947 // load + replace + store sequence with a single (narrower) store, which makes
20948 // the load dead.
20949 if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
20950 std::pair<unsigned, unsigned> MaskedLoad;
20951 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 0), Ptr, Chain);
20952 if (MaskedLoad.first)
20953 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
20954 IVal: Value.getOperand(i: 1), St: ST,DC: this))
20955 return NewST;
20956
20957 // Or is commutative, so try swapping X and Y.
20958 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 1), Ptr, Chain);
20959 if (MaskedLoad.first)
20960 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
20961 IVal: Value.getOperand(i: 0), St: ST,DC: this))
20962 return NewST;
20963 }
20964
20965 if (!EnableReduceLoadOpStoreWidth)
20966 return SDValue();
20967
20968 if (Value.getOperand(i: 1).getOpcode() != ISD::Constant)
20969 return SDValue();
20970
20971 SDValue N0 = Value.getOperand(i: 0);
20972 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
20973 Chain == SDValue(N0.getNode(), 1)) {
20974 LoadSDNode *LD = cast<LoadSDNode>(Val&: N0);
20975 if (LD->getBasePtr() != Ptr ||
20976 LD->getPointerInfo().getAddrSpace() !=
20977 ST->getPointerInfo().getAddrSpace())
20978 return SDValue();
20979
20980 // Find the type NewVT to narrow the load / op / store to.
20981 SDValue N1 = Value.getOperand(i: 1);
20982 unsigned BitWidth = N1.getValueSizeInBits();
20983 APInt Imm = N1->getAsAPIntVal();
20984 if (Opc == ISD::AND)
20985 Imm.flipAllBits();
20986 if (Imm == 0 || Imm.isAllOnes())
20987 return SDValue();
20988 // Find the least/most significant bits that need to be part of the narrowed
20989 // operation. We assume the target will need to address/access full bytes, so
20990 // we make sure to align LSB and MSB at byte boundaries.
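// Illustrative values (not from the source): for Imm == 0x0000FF00 this
// gives LSB == 8 and MSB == 15, so NewBW starts out as 8 below.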
20991 unsigned BitsPerByteMask = 7u;
20992 unsigned LSB = Imm.countr_zero() & ~BitsPerByteMask;
20993 unsigned MSB = (Imm.getActiveBits() - 1) | BitsPerByteMask;
20994 unsigned NewBW = NextPowerOf2(A: MSB - LSB);
20995 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
20996 // The narrowing should be profitable, the load/store operation should be
20997 // legal (or custom) and the store size should be equal to the NewVT width.
20998 while (NewBW < BitWidth &&
20999 (NewVT.getStoreSizeInBits() != NewBW ||
21000 !TLI.isOperationLegalOrCustom(Op: Opc, VT: NewVT) ||
21001 (!ReduceLoadOpStoreWidthForceNarrowingProfitable &&
21002 !TLI.isNarrowingProfitable(N, SrcVT: VT, DestVT: NewVT)))) {
21003 NewBW = NextPowerOf2(A: NewBW);
21004 NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
21005 }
21006 if (NewBW >= BitWidth)
21007 return SDValue();
21008
21009 // If we get this far, NewVT/NewBW reflect a power-of-2 sized type that is
21010 // large enough to cover all bits that should be modified. This type might
21011 // however be larger than really needed (such as i32 while we actually only
21012 // need to modify one byte). Now we need to find out how to align the memory
21013 // accesses to satisfy preferred alignments as well as to avoid accessing
21014 // memory outside the store size of the original access.
21015
21016 unsigned VTStoreSize = VT.getStoreSizeInBits().getFixedValue();
21017
21018 // Let ShAmt denote the number of bits to skip, counted from the least
21019 // significant bits of Imm, and let PtrOff denote how much the pointer needs
21020 // to be offset (in bytes) for the new access.
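// Illustrative values (continuing the example above): with VT == i32,
// NewBW == 8, LSB == 8 and MSB == 15, the loop below accepts ShAmt == 8,
// giving PtrOff == 1 on little-endian and (32 - 8 - 8) / 8 == 2 on
// big-endian targets.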
21021 unsigned ShAmt = 0;
21022 uint64_t PtrOff = 0;
21023 for (; ShAmt + NewBW <= VTStoreSize; ShAmt += 8) {
21024 // Make sure the range [ShAmt, ShAmt+NewBW) covers both LSB and MSB.
21025 if (ShAmt > LSB)
21026 return SDValue();
21027 if (ShAmt + NewBW < MSB)
21028 continue;
21029
21030 // Calculate PtrOff.
21031 unsigned PtrAdjustmentInBits = DAG.getDataLayout().isBigEndian()
21032 ? VTStoreSize - NewBW - ShAmt
21033 : ShAmt;
21034 PtrOff = PtrAdjustmentInBits / 8;
21035
21036 // Now check if narrow access is allowed and fast, considering alignments.
21037 unsigned IsFast = 0;
21038 Align NewAlign = commonAlignment(A: LD->getAlign(), Offset: PtrOff);
21039 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: NewVT,
21040 AddrSpace: LD->getAddressSpace(), Alignment: NewAlign,
21041 Flags: LD->getMemOperand()->getFlags(), Fast: &IsFast) &&
21042 IsFast)
21043 break;
21044 }
21045 // If the loop above did not find an accepted ShAmt, we need to exit here.
21046 if (ShAmt + NewBW > VTStoreSize)
21047 return SDValue();
21048
21049 APInt NewImm = Imm.lshr(shiftAmt: ShAmt).trunc(width: NewBW);
21050 if (Opc == ISD::AND)
21051 NewImm.flipAllBits();
21052 Align NewAlign = commonAlignment(A: LD->getAlign(), Offset: PtrOff);
21053 SDValue NewPtr =
21054 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: PtrOff), DL: SDLoc(LD));
21055 SDValue NewLD =
21056 DAG.getLoad(VT: NewVT, dl: SDLoc(N0), Chain: LD->getChain(), Ptr: NewPtr,
21057 PtrInfo: LD->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
21058 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
21059 SDValue NewVal = DAG.getNode(Opcode: Opc, DL: SDLoc(Value), VT: NewVT, N1: NewLD,
21060 N2: DAG.getConstant(Val: NewImm, DL: SDLoc(Value), VT: NewVT));
21061 SDValue NewST =
21062 DAG.getStore(Chain, dl: SDLoc(N), Val: NewVal, Ptr: NewPtr,
21063 PtrInfo: ST->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign);
21064
21065 AddToWorklist(N: NewPtr.getNode());
21066 AddToWorklist(N: NewLD.getNode());
21067 AddToWorklist(N: NewVal.getNode());
21068 WorklistRemover DeadNodes(*this);
21069 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLD.getValue(R: 1));
21070 ++OpsNarrowed;
21071 return NewST;
21072 }
21073
21074 return SDValue();
21075}
21076
21077/// For a given floating point load / store pair, if the load value isn't used
21078/// by any other operations, then consider transforming the pair to integer
21079/// load / store operations if the target deems the transformation profitable.
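///
/// Illustrative sketch (assuming the target reports the integer ops as legal
/// and desirable): (store f64 (load f64 p), q) becomes
/// (store i64 (load i64 p), q), avoiding a round trip through FP registers.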
21080SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
21081 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
21082 SDValue Value = ST->getValue();
21083 if (ISD::isNormalStore(N: ST) && ISD::isNormalLoad(N: Value.getNode()) &&
21084 Value.hasOneUse()) {
21085 LoadSDNode *LD = cast<LoadSDNode>(Val&: Value);
21086 EVT VT = LD->getMemoryVT();
21087 if (!VT.isSimple() || !VT.isFloatingPoint() || VT != ST->getMemoryVT() ||
21088 LD->isNonTemporal() || ST->isNonTemporal() ||
21089 LD->getPointerInfo().getAddrSpace() != 0 ||
21090 ST->getPointerInfo().getAddrSpace() != 0)
21091 return SDValue();
21092
21093 TypeSize VTSize = VT.getSizeInBits();
21094
21095 // We don't know the size of scalable types at compile time so we cannot
21096 // create an integer of the equivalent size.
21097 if (VTSize.isScalable())
21098 return SDValue();
21099
21100 unsigned FastLD = 0, FastST = 0;
21101 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VTSize.getFixedValue());
21102 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: IntVT) ||
21103 !TLI.isOperationLegal(Op: ISD::STORE, VT: IntVT) ||
21104 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
21105 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
21106 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
21107 MMO: *LD->getMemOperand(), Fast: &FastLD) ||
21108 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
21109 MMO: *ST->getMemOperand(), Fast: &FastST) ||
21110 !FastLD || !FastST)
21111 return SDValue();
21112
21113 SDValue NewLD = DAG.getLoad(VT: IntVT, dl: SDLoc(Value), Chain: LD->getChain(),
21114 Ptr: LD->getBasePtr(), MMO: LD->getMemOperand());
21115
21116 SDValue NewST = DAG.getStore(Chain: ST->getChain(), dl: SDLoc(N), Val: NewLD,
21117 Ptr: ST->getBasePtr(), MMO: ST->getMemOperand());
21118
21119 AddToWorklist(N: NewLD.getNode());
21120 AddToWorklist(N: NewST.getNode());
21121 WorklistRemover DeadNodes(*this);
21122 DAG.ReplaceAllUsesOfValueWith(From: Value.getValue(R: 1), To: NewLD.getValue(R: 1));
21123 ++LdStFP2Int;
21124 return NewST;
21125 }
21126
21127 return SDValue();
21128}
21129
21130// This is a helper function for visitMUL to check the profitability
21131// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
21132// MulNode is the original multiply, AddNode is (add x, c1),
21133// and ConstNode is c2.
21134//
21135// If the (add x, c1) has multiple uses, we could increase
21136// the number of adds if we make this transformation.
21137// It would only be worth doing this if we can remove a
21138// multiply in the process. Check for that here.
21139// To illustrate:
21140// (A + c1) * c3
21141// (A + c2) * c3
21142// We're checking for cases where we have common "c3 * A" expressions.
21143bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
21144 SDValue ConstNode) {
21145 // If the add only has one use, and the target thinks the folding is
21146 // profitable or does not lead to worse code, this would be OK to do.
21147 if (AddNode->hasOneUse() &&
21148 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
21149 return true;
21150
21151 // Walk all the users of the constant with which we're multiplying.
21152 for (SDNode *User : ConstNode->users()) {
21153 if (User == MulNode) // This use is the one we're on right now. Skip it.
21154 continue;
21155
21156 if (User->getOpcode() == ISD::MUL) { // We have another multiply use.
21157 SDNode *OtherOp;
21158 SDNode *MulVar = AddNode.getOperand(i: 0).getNode();
21159
21160 // OtherOp is what we're multiplying against the constant.
21161 if (User->getOperand(Num: 0) == ConstNode)
21162 OtherOp = User->getOperand(Num: 1).getNode();
21163 else
21164 OtherOp = User->getOperand(Num: 0).getNode();
21165
21166 // Check to see if multiply is with the same operand of our "add".
21167 //
21168 // ConstNode = CONST
21169 // User = ConstNode * A <-- visiting User. OtherOp is A.
21170 // ...
21171 // AddNode = (A + c1) <-- MulVar is A.
21172 // = AddNode * ConstNode <-- current visiting instruction.
21173 //
21174 // If we make this transformation, we will have a common
21175 // multiply (ConstNode * A) that we can save.
21176 if (OtherOp == MulVar)
21177 return true;
21178
21179 // Now check to see if a future expansion will give us a common
21180 // multiply.
21181 //
21182 // ConstNode = CONST
21183 // AddNode = (A + c1)
21184 // ... = AddNode * ConstNode <-- current visiting instruction.
21185 // ...
21186 // OtherOp = (A + c2)
21187 // User = OtherOp * ConstNode <-- visiting User.
21188 //
21189 // If we make this transformation, we will have a common
21190 // multiply (CONST * A) after we also do the same transformation
21191 // to the User instruction.
21192 if (OtherOp->getOpcode() == ISD::ADD &&
21193 DAG.isConstantIntBuildVectorOrConstantInt(N: OtherOp->getOperand(Num: 1)) &&
21194 OtherOp->getOperand(Num: 0).getNode() == MulVar)
21195 return true;
21196 }
21197 }
21198
21199 // Didn't find a case where this would be profitable.
21200 return false;
21201}
21202
21203SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
21204 unsigned NumStores) {
21205 SmallVector<SDValue, 8> Chains;
21206 SmallPtrSet<const SDNode *, 8> Visited;
21207 SDLoc StoreDL(StoreNodes[0].MemNode);
21208
21209 for (unsigned i = 0; i < NumStores; ++i) {
21210 Visited.insert(Ptr: StoreNodes[i].MemNode);
21211 }
21212
21213 // Don't include nodes that are children or repeated nodes.
21214 for (unsigned i = 0; i < NumStores; ++i) {
21215 if (Visited.insert(Ptr: StoreNodes[i].MemNode->getChain().getNode()).second)
21216 Chains.push_back(Elt: StoreNodes[i].MemNode->getChain());
21217 }
21218
21219 assert(!Chains.empty() && "Chain should have generated a chain");
21220 return DAG.getTokenFactor(DL: StoreDL, Vals&: Chains);
21221}
21222
21223bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
21224 const Value *UnderlyingObj = nullptr;
21225 for (const auto &MemOp : StoreNodes) {
21226 const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
21227 // A pseudo value like a stack frame has its own frame index and size; we
21228 // should not use the first store's frame index for other frames.
21229 if (MMO->getPseudoValue())
21230 return false;
21231
21232 if (!MMO->getValue())
21233 return false;
21234
21235 const Value *Obj = getUnderlyingObject(V: MMO->getValue());
21236
21237 if (UnderlyingObj && UnderlyingObj != Obj)
21238 return false;
21239
21240 if (!UnderlyingObj)
21241 UnderlyingObj = Obj;
21242 }
21243
21244 return true;
21245}
21246
21247bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
21248 SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
21249 bool IsConstantSrc, bool UseVector, bool UseTrunc) {
21250 // Make sure we have something to merge.
21251 if (NumStores < 2)
21252 return false;
21253
21254 assert((!UseTrunc || !UseVector) &&
21255 "This optimization cannot emit a vector truncating store");
21256
21257 // The latest Node in the DAG.
21258 SDLoc DL(StoreNodes[0].MemNode);
21259
21260 TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
21261 unsigned SizeInBits = NumStores * ElementSizeBits;
21262 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21263
21264 std::optional<MachineMemOperand::Flags> Flags;
21265 AAMDNodes AAInfo;
21266 for (unsigned I = 0; I != NumStores; ++I) {
21267 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
21268 if (!Flags) {
21269 Flags = St->getMemOperand()->getFlags();
21270 AAInfo = St->getAAInfo();
21271 continue;
21272 }
21273 // Skip merging if there's an inconsistent flag.
21274 if (Flags != St->getMemOperand()->getFlags())
21275 return false;
21276 // Concatenate AA metadata.
21277 AAInfo = AAInfo.concat(Other: St->getAAInfo());
21278 }
21279
21280 EVT StoreTy;
21281 if (UseVector) {
21282 unsigned Elts = NumStores * NumMemElts;
21283 // Get the type for the merged vector store.
21284 StoreTy = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
21285 } else
21286 StoreTy = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SizeInBits);
21287
21288 SDValue StoredVal;
21289 if (UseVector) {
21290 if (IsConstantSrc) {
21291 SmallVector<SDValue, 8> BuildVector;
21292 for (unsigned I = 0; I != NumStores; ++I) {
21293 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
21294 SDValue Val = St->getValue();
21295 // If constant is of the wrong type, convert it now. This comes up
21296 // when one of our stores was truncating.
21297 if (MemVT != Val.getValueType()) {
21298 Val = peekThroughBitcasts(V: Val);
21299 // Deal with constants of wrong size.
21300 if (ElementSizeBits != Val.getValueSizeInBits()) {
21301 auto *C = dyn_cast<ConstantSDNode>(Val);
21302 if (!C)
21303 // Not clear how to truncate FP values.
21304 // TODO: Handle truncation of build_vector constants
21305 return false;
21306
21307 EVT IntMemVT =
21308 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemVT.getSizeInBits());
21309 Val = DAG.getConstant(Val: C->getAPIntValue()
21310 .zextOrTrunc(width: Val.getValueSizeInBits())
21311 .zextOrTrunc(width: ElementSizeBits),
21312 DL: SDLoc(C), VT: IntMemVT);
21313 }
21314 // Make sure the value has the correctly sized type by bitcasting to MemVT.
21315 Val = DAG.getBitcast(VT: MemVT, V: Val);
21316 }
21317 BuildVector.push_back(Elt: Val);
21318 }
21319 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
21320 : ISD::BUILD_VECTOR,
21321 DL, VT: StoreTy, Ops: BuildVector);
21322 } else {
21323 SmallVector<SDValue, 8> Ops;
21324 for (unsigned i = 0; i < NumStores; ++i) {
21325 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
21326 SDValue Val = peekThroughBitcasts(V: St->getValue());
21327 // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
21328 // type MemVT. If the underlying value is not the correct
21329 // type, but it is an extraction of an appropriate vector we
21330 // can recast Val to be of the correct type. This may require
21331 // converting between EXTRACT_VECTOR_ELT and
21332 // EXTRACT_SUBVECTOR.
21333 if ((MemVT != Val.getValueType()) &&
21334 (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
21335 Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
21336 EVT MemVTScalarTy = MemVT.getScalarType();
21337 // We may need to add a bitcast here to get types to line up.
21338 if (MemVTScalarTy != Val.getValueType().getScalarType()) {
21339 Val = DAG.getBitcast(VT: MemVT, V: Val);
21340 } else if (MemVT.isVector() &&
21341 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
21342 Val = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MemVT, Operand: Val);
21343 } else {
21344 unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
21345 : ISD::EXTRACT_VECTOR_ELT;
21346 SDValue Vec = Val.getOperand(i: 0);
21347 SDValue Idx = Val.getOperand(i: 1);
21348 Val = DAG.getNode(Opcode: OpC, DL: SDLoc(Val), VT: MemVT, N1: Vec, N2: Idx);
21349 }
21350 }
21351 Ops.push_back(Elt: Val);
21352 }
21353
21354 // Build the extracted vector elements back into a vector.
21355 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
21356 : ISD::BUILD_VECTOR,
21357 DL, VT: StoreTy, Ops);
21358 }
21359 } else {
21360 // We should always use a vector store when merging extracted vector
21361 // elements, so this path implies a store of constants.
21362 assert(IsConstantSrc && "Merged vector elements should use vector store");
21363
21364 APInt StoreInt(SizeInBits, 0);
21365
21366 // Construct a single integer constant which is made of the smaller
21367 // constant inputs.
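// Illustrative values (not from the source): merging two i16 constant
// stores of 0x1122 and 0x3344 (in address order) on a little-endian
// target produces the single i32 constant 0x33441122 below.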
21368 bool IsLE = DAG.getDataLayout().isLittleEndian();
21369 for (unsigned i = 0; i < NumStores; ++i) {
21370 unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
21371 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[Idx].MemNode);
21372
21373 SDValue Val = St->getValue();
21374 Val = peekThroughBitcasts(V: Val);
21375 StoreInt <<= ElementSizeBits;
21376 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
21377 StoreInt |= C->getAPIntValue()
21378 .zextOrTrunc(width: ElementSizeBits)
21379 .zextOrTrunc(width: SizeInBits);
21380 } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
21381 StoreInt |= C->getValueAPF()
21382 .bitcastToAPInt()
21383 .zextOrTrunc(width: ElementSizeBits)
21384 .zextOrTrunc(width: SizeInBits);
21385 // If fp truncation is necessary give up for now.
21386 if (MemVT.getSizeInBits() != ElementSizeBits)
21387 return false;
21388 } else if (ISD::isBuildVectorOfConstantSDNodes(N: Val.getNode()) ||
21389 ISD::isBuildVectorOfConstantFPSDNodes(N: Val.getNode())) {
21390 // Not yet handled
21391 return false;
21392 } else {
21393 llvm_unreachable("Invalid constant element type");
21394 }
21395 }
21396
21397 // Create the new Load and Store operations.
21398 StoredVal = DAG.getConstant(Val: StoreInt, DL, VT: StoreTy);
21399 }
21400
21401 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21402 SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
21403 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
21404
21405 // Make sure we use a trunc store if it's necessary to be legal.
21406 // When generating the new widened store, if the first store's pointer info
21407 // cannot be reused, discard the pointer info except for the address space,
21408 // because the widened store can no longer be represented by the original
21409 // pointer info, which describes the narrower memory object.
21410 SDValue NewStore;
21411 if (!UseTrunc) {
21412 NewStore = DAG.getStore(
21413 Chain: NewChain, dl: DL, Val: StoredVal, Ptr: FirstInChain->getBasePtr(),
21414 PtrInfo: CanReusePtrInfo
21415 ? FirstInChain->getPointerInfo()
21416 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
21417 Alignment: FirstInChain->getAlign(), MMOFlags: *Flags, AAInfo);
21418 } else { // Must be realized as a trunc store
21419 EVT LegalizedStoredValTy =
21420 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: StoredVal.getValueType());
21421 unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
21422 ConstantSDNode *C = cast<ConstantSDNode>(Val&: StoredVal);
21423 SDValue ExtendedStoreVal =
21424 DAG.getConstant(Val: C->getAPIntValue().zextOrTrunc(width: LegalizedStoreSize), DL,
21425 VT: LegalizedStoredValTy);
21426 NewStore = DAG.getTruncStore(
21427 Chain: NewChain, dl: DL, Val: ExtendedStoreVal, Ptr: FirstInChain->getBasePtr(),
21428 PtrInfo: CanReusePtrInfo
21429 ? FirstInChain->getPointerInfo()
21430 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
21431 SVT: StoredVal.getValueType() /*TVT*/, Alignment: FirstInChain->getAlign(), MMOFlags: *Flags,
21432 AAInfo);
21433 }
21434
21435 // Replace all merged stores with the new store.
21436 for (unsigned i = 0; i < NumStores; ++i)
21437 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
21438
21439 AddToWorklist(N: NewChain.getNode());
21440 return true;
21441}
21442
21443SDNode *
21444DAGCombiner::getStoreMergeCandidates(StoreSDNode *St,
21445 SmallVectorImpl<MemOpLink> &StoreNodes) {
21446 // This holds the base pointer, index, and the offset in bytes from the base
21447 // pointer. We must have a base and an offset. Do not handle stores to undef
21448 // base pointers.
21449 BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
21450 if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
21451 return nullptr;
21452
21453 SDValue Val = peekThroughBitcasts(V: St->getValue());
21454 StoreSource StoreSrc = getStoreSource(StoreVal: Val);
21455 assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
21456
21457 // Match on loadbaseptr if relevant.
21458 EVT MemVT = St->getMemoryVT();
21459 BaseIndexOffset LBasePtr;
21460 EVT LoadVT;
21461 if (StoreSrc == StoreSource::Load) {
21462 auto *Ld = cast<LoadSDNode>(Val);
21463 LBasePtr = BaseIndexOffset::match(N: Ld, DAG);
21464 LoadVT = Ld->getMemoryVT();
21465 // Load and store should be the same type.
21466 if (MemVT != LoadVT)
21467 return nullptr;
21468 // Loads must only have one use.
21469 if (!Ld->hasNUsesOfValue(NUses: 1, Value: 0))
21470 return nullptr;
21471 // The memory operands must not be volatile/indexed/atomic.
21472 // TODO: May be able to relax for unordered atomics (see D66309)
21473 if (!Ld->isSimple() || Ld->isIndexed())
21474 return nullptr;
21475 }
21476 auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
21477 int64_t &Offset) -> bool {
21478 // The memory operands must not be volatile/indexed/atomic.
21479 // TODO: May be able to relax for unordered atomics (see D66309)
21480 if (!Other->isSimple() || Other->isIndexed())
21481 return false;
21482 // Don't mix temporal stores with non-temporal stores.
21483 if (St->isNonTemporal() != Other->isNonTemporal())
21484 return false;
21485 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *St, NodeY: *Other))
21486 return false;
21487 SDValue OtherBC = peekThroughBitcasts(V: Other->getValue());
21488 // Allow merging constants of different types as integers.
21489 bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(VT: Other->getMemoryVT())
21490 : Other->getMemoryVT() != MemVT;
21491 switch (StoreSrc) {
21492 case StoreSource::Load: {
21493 if (NoTypeMatch)
21494 return false;
21495 // The Load's Base Ptr must also match.
21496 auto *OtherLd = dyn_cast<LoadSDNode>(Val&: OtherBC);
21497 if (!OtherLd)
21498 return false;
21499 BaseIndexOffset LPtr = BaseIndexOffset::match(N: OtherLd, DAG);
21500 if (LoadVT != OtherLd->getMemoryVT())
21501 return false;
21502 // Loads must only have one use.
21503 if (!OtherLd->hasNUsesOfValue(NUses: 1, Value: 0))
21504 return false;
21505 // The memory operands must not be volatile/indexed/atomic.
21506 // TODO: May be able to relax for unordered atomics (see D66309)
21507 if (!OtherLd->isSimple() || OtherLd->isIndexed())
21508 return false;
21509 // Don't mix temporal loads with non-temporal loads.
21510 if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
21511 return false;
21512 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *cast<LoadSDNode>(Val),
21513 NodeY: *OtherLd))
21514 return false;
21515 if (!(LBasePtr.equalBaseIndex(Other: LPtr, DAG)))
21516 return false;
21517 break;
21518 }
21519 case StoreSource::Constant:
21520 if (NoTypeMatch)
21521 return false;
21522 if (getStoreSource(StoreVal: OtherBC) != StoreSource::Constant)
21523 return false;
21524 break;
21525 case StoreSource::Extract:
21526 // Do not merge truncated stores here.
21527 if (Other->isTruncatingStore())
21528 return false;
21529 if (!MemVT.bitsEq(VT: OtherBC.getValueType()))
21530 return false;
21531 if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
21532 OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
21533 return false;
21534 break;
21535 default:
21536 llvm_unreachable("Unhandled store source for merging");
21537 }
21538 Ptr = BaseIndexOffset::match(N: Other, DAG);
21539 return (BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset));
21540 };
21541
21542 // We are looking for a root node which is an ancestor to all mergeable
21543 // stores. We search up through a load, to our root and then down
21544 // through all children. For instance we will find Store{1,2,3} if
21545 // St is Store1, Store2, or Store3, where the root is not a load,
21546 // which is always true for non-volatile ops. TODO: Expand
21547 // the search to find all valid candidates through multiple layers of loads.
21548 //
21549 // Root
21550 // |-------|-------|
21551 // Load Load Store3
21552 // | |
21553 // Store1 Store2
21554 //
21555 // FIXME: We should be able to climb and
21556 // descend TokenFactors to find candidates as well.
21557
21558 SDNode *RootNode = St->getChain().getNode();
21559 // Bail out if we already analyzed this root node and found nothing.
21560 if (ChainsWithoutMergeableStores.contains(Ptr: RootNode))
21561 return nullptr;
21562
21563 // Check if the pair of StoreNode and RootNode has already bailed out more
21564 // times than the limit in the dependence check.
21565 auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
21566 SDNode *RootNode) -> bool {
21567 auto RootCount = StoreRootCountMap.find(Val: StoreNode);
21568 return RootCount != StoreRootCountMap.end() &&
21569 RootCount->second.first == RootNode &&
21570 RootCount->second.second > StoreMergeDependenceLimit;
21571 };
21572
21573 auto TryToAddCandidate = [&](SDUse &Use) {
21574 // This must be a chain use.
21575 if (Use.getOperandNo() != 0)
21576 return;
21577 if (auto *OtherStore = dyn_cast<StoreSDNode>(Val: Use.getUser())) {
21578 BaseIndexOffset Ptr;
21579 int64_t PtrDiff;
21580 if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
21581 !OverLimitInDependenceCheck(OtherStore, RootNode))
21582 StoreNodes.push_back(Elt: MemOpLink(OtherStore, PtrDiff));
21583 }
21584 };
21585
21586 unsigned NumNodesExplored = 0;
21587 const unsigned MaxSearchNodes = 1024;
21588 if (auto *Ldn = dyn_cast<LoadSDNode>(Val: RootNode)) {
21589 RootNode = Ldn->getChain().getNode();
21590 // Bail out if we already analyzed this root node and found nothing.
21591 if (ChainsWithoutMergeableStores.contains(Ptr: RootNode))
21592 return nullptr;
21593 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
21594 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
21595 SDNode *User = I->getUser();
21596 if (I->getOperandNo() == 0 && isa<LoadSDNode>(Val: User)) { // walk down chain
21597 for (SDUse &U2 : User->uses())
21598 TryToAddCandidate(U2);
21599 }
21600 // Check stores that depend on the root (e.g. Store 3 in the chart above).
21601 if (I->getOperandNo() == 0 && isa<StoreSDNode>(Val: User)) {
21602 TryToAddCandidate(*I);
21603 }
21604 }
21605 } else {
21606 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
21607 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
21608 TryToAddCandidate(*I);
21609 }
21610
21611 return RootNode;
21612}
21613
21614// We need to check that merging these stores does not cause a loop in the
21615// DAG. Any store candidate may depend on another candidate indirectly through
21616// its operands. Check in parallel by searching up from operands of candidates.
21617bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
21618 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
21619 SDNode *RootNode) {
  // FIXME: We should be able to truncate a full search of
  // predecessors by doing a BFS and keeping tabs on the originating
  // stores that worklist nodes come from, in a similar way to
  // TokenFactor simplification.
21624
21625 SmallPtrSet<const SDNode *, 32> Visited;
21626 SmallVector<const SDNode *, 8> Worklist;
21627
21628 // RootNode is a predecessor to all candidates so we need not search
21629 // past it. Add RootNode (peeking through TokenFactors). Do not count
21630 // these towards size check.
21631
21632 Worklist.push_back(Elt: RootNode);
21633 while (!Worklist.empty()) {
21634 auto N = Worklist.pop_back_val();
21635 if (!Visited.insert(Ptr: N).second)
21636 continue; // Already present in Visited.
21637 if (N->getOpcode() == ISD::TokenFactor) {
21638 for (SDValue Op : N->ops())
21639 Worklist.push_back(Elt: Op.getNode());
21640 }
21641 }
21642
21643 // Don't count pruning nodes towards max.
21644 unsigned int Max = 1024 + Visited.size();
21645 // Search Ops of store candidates.
21646 for (unsigned i = 0; i < NumStores; ++i) {
21647 SDNode *N = StoreNodes[i].MemNode;
    // Of the 4 Store Operands:
    // * Chain (Op 0) -> We have already considered these
    //                   in candidate selection, but only by following the
    //                   chain dependencies. We could still have a chain
    //                   dependency to a load, that has a non-chain dep to
    //                   another load, that depends on a store, etc. So it is
    //                   possible to have dependencies that consist of a mix
    //                   of chain and non-chain deps, and we need to include
    //                   chain operands in the analysis here.
    // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
    // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
    //                     but aren't necessarily from the same base node, so
    //                     cycles are possible (e.g. via indexed store).
    // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
    //             non-indexed stores). Not constant on all targets (e.g. ARM)
    //             and so can participate in a cycle.
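    // For example (illustrative): a candidate's value operand (Op 1) may be a
    // load that itself depends, through its own chain, on another candidate
    // store; that dependency only becomes visible by walking these non-chain
    // operands.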
21664 for (const SDValue &Op : N->op_values())
21665 Worklist.push_back(Elt: Op.getNode());
21666 }
21667 // Search through DAG. We can stop early if we find a store node.
21668 for (unsigned i = 0; i < NumStores; ++i)
21669 if (SDNode::hasPredecessorHelper(N: StoreNodes[i].MemNode, Visited, Worklist,
21670 MaxSteps: Max)) {
      // If the search bails out, record the StoreNode and RootNode in
      // StoreRootCountMap. If we have seen the pair more times than the limit,
      // we won't add the StoreNode into the StoreNodes set again.
21674 if (Visited.size() >= Max) {
21675 auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
21676 if (RootCount.first == RootNode)
21677 RootCount.second++;
21678 else
21679 RootCount = {RootNode, 1};
21680 }
21681 return false;
21682 }
21683 return true;
21684}
21685
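// Walk the chain upwards from St looking for Ld, and report whether a call
// boundary (CALLSEQ_END) is crossed on the way. A sketch of why this matters
// (illustrative): merging a store-of-load pair across a call can extend the
// live range of the loaded value over the call, which some targets prefer to
// avoid; see the use together with shouldMergeStoreOfLoadsOverCall below.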
21686bool DAGCombiner::hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld) {
21687 SmallPtrSet<const SDNode *, 32> Visited;
21688 SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21689 Worklist.emplace_back(Args: St->getChain().getNode(), Args: false);
21690
21691 while (!Worklist.empty()) {
21692 auto [Node, FoundCall] = Worklist.pop_back_val();
21693 if (!Visited.insert(Ptr: Node).second || Node->getNumOperands() == 0)
21694 continue;
21695
21696 switch (Node->getOpcode()) {
21697 case ISD::CALLSEQ_END:
21698 Worklist.emplace_back(Args: Node->getOperand(Num: 0).getNode(), Args: true);
21699 break;
21700 case ISD::TokenFactor:
21701 for (SDValue Op : Node->ops())
21702 Worklist.emplace_back(Args: Op.getNode(), Args&: FoundCall);
21703 break;
21704 case ISD::LOAD:
21705 if (Node == Ld)
21706 return FoundCall;
21707 [[fallthrough]];
21708 default:
21709 assert(Node->getOperand(0).getValueType() == MVT::Other &&
21710 "Invalid chain type");
21711 Worklist.emplace_back(Args: Node->getOperand(Num: 0).getNode(), Args&: FoundCall);
21712 break;
21713 }
21714 }
21715 return false;
21716}
21717
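// Worked example (illustrative): with ElementSizeBytes == 4 and sorted
// candidate offsets {0, 4, 8, 16}, the scan below keeps {0, 4, 8} and
// returns 3; offset 16 breaks the consecutive run and is typically
// reconsidered on a later iteration by the caller.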
21718unsigned
21719DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
21720 int64_t ElementSizeBytes) const {
21721 while (true) {
21722 // Find a store past the width of the first store.
21723 size_t StartIdx = 0;
21724 while ((StartIdx + 1 < StoreNodes.size()) &&
21725 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
21726 StoreNodes[StartIdx + 1].OffsetFromBase)
21727 ++StartIdx;
21728
21729 // Bail if we don't have enough candidates to merge.
21730 if (StartIdx + 1 >= StoreNodes.size())
21731 return 0;
21732
21733 // Trim stores that overlapped with the first store.
21734 if (StartIdx)
21735 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + StartIdx);
21736
21737 // Scan the memory operations on the chain and find the first
21738 // non-consecutive store memory address.
21739 unsigned NumConsecutiveStores = 1;
21740 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
21741 // Check that the addresses are consecutive starting from the second
21742 // element in the list of stores.
21743 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
21744 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
21745 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
21746 break;
21747 NumConsecutiveStores = i + 1;
21748 }
21749 if (NumConsecutiveStores > 1)
21750 return NumConsecutiveStores;
21751
21752 // There are no consecutive stores at the start of the list.
21753 // Remove the first store and try again.
21754 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 1);
21755 }
21756}
21757
21758bool DAGCombiner::tryStoreMergeOfConstants(
21759 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
21760 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
21761 LLVMContext &Context = *DAG.getContext();
21762 const DataLayout &DL = DAG.getDataLayout();
21763 int64_t ElementSizeBytes = MemVT.getStoreSize();
21764 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21765 bool MadeChange = false;
21766
21767 // Store the constants into memory as one consecutive store.
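  // E.g. (illustrative): four consecutive i8 stores of the constants
  // 0x11, 0x22, 0x33, 0x44 may become a single i32 store of 0x44332211 on a
  // little-endian target, provided the wider type is legal and the access is
  // fast per the checks below.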
21768 while (NumConsecutiveStores >= 2) {
21769 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21770 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21771 Align FirstStoreAlign = FirstInChain->getAlign();
21772 unsigned LastLegalType = 1;
21773 unsigned LastLegalVectorType = 1;
21774 bool LastIntegerTrunc = false;
21775 bool NonZero = false;
21776 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
21777 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21778 StoreSDNode *ST = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
21779 SDValue StoredVal = ST->getValue();
21780 bool IsElementZero = false;
21781 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val&: StoredVal))
21782 IsElementZero = C->isZero();
21783 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val&: StoredVal))
21784 IsElementZero = C->getConstantFPValue()->isNullValue();
21785 else if (ISD::isBuildVectorAllZeros(N: StoredVal.getNode()))
21786 IsElementZero = true;
21787 if (IsElementZero) {
21788 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
21789 FirstZeroAfterNonZero = i;
21790 }
21791 NonZero |= !IsElementZero;
21792
21793 // Find a legal type for the constant store.
21794 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
21795 EVT StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
21796 unsigned IsFast = 0;
21797
21798 // Break early when size is too large to be legal.
21799 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
21800 break;
21801
21802 if (TLI.isTypeLegal(VT: StoreTy) &&
21803 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
21804 MF: DAG.getMachineFunction()) &&
21805 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
21806 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
21807 IsFast) {
21808 LastIntegerTrunc = false;
21809 LastLegalType = i + 1;
21810 // Or check whether a truncstore is legal.
21811 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
21812 TargetLowering::TypePromoteInteger) {
21813 EVT LegalizedStoredValTy =
21814 TLI.getTypeToTransformTo(Context, VT: StoredVal.getValueType());
21815 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
21816 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
21817 MF: DAG.getMachineFunction()) &&
21818 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
21819 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
21820 IsFast) {
21821 LastIntegerTrunc = true;
21822 LastLegalType = i + 1;
21823 }
21824 }
21825
21826 // We only use vectors if the target allows it and the function is not
21827 // marked with the noimplicitfloat attribute.
21828 if (TLI.storeOfVectorConstantIsCheap(IsZero: !NonZero, MemVT, NumElem: i + 1, AddrSpace: FirstStoreAS) &&
21829 AllowVectors) {
21830 // Find a legal type for the vector store.
21831 unsigned Elts = (i + 1) * NumMemElts;
21832 EVT Ty = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
21833 if (TLI.isTypeLegal(VT: Ty) && TLI.isTypeLegal(VT: MemVT) &&
21834 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
21835 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
21836 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
21837 IsFast)
21838 LastLegalVectorType = i + 1;
21839 }
21840 }
21841
21842 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
21843 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
21844 bool UseTrunc = LastIntegerTrunc && !UseVector;
21845
21846 // Check if we found a legal integer type that creates a meaningful
21847 // merge.
21848 if (NumElem < 2) {
21849 // We know that candidate stores are in order and of correct
21850 // shape. While there is no mergeable sequence from the
      // beginning, one may start later in the sequence. The only
      // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment has
21854 // improved or we've dropped a non-zero value. Drop as many
21855 // candidates as we can here.
21856 unsigned NumSkip = 1;
21857 while ((NumSkip < NumConsecutiveStores) &&
21858 (NumSkip < FirstZeroAfterNonZero) &&
21859 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21860 NumSkip++;
21861
21862 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
21863 NumConsecutiveStores -= NumSkip;
21864 continue;
21865 }
21866
21867 // Check that we can merge these candidates without causing a cycle.
21868 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
21869 RootNode)) {
21870 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
21871 NumConsecutiveStores -= NumElem;
21872 continue;
21873 }
21874
21875 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStores: NumElem,
21876 /*IsConstantSrc*/ true,
21877 UseVector, UseTrunc);
21878
21879 // Remove merged stores for next iteration.
21880 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
21881 NumConsecutiveStores -= NumElem;
21882 }
21883 return MadeChange;
21884}
21885
21886bool DAGCombiner::tryStoreMergeOfExtracts(
21887 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
21888 EVT MemVT, SDNode *RootNode) {
21889 LLVMContext &Context = *DAG.getContext();
21890 const DataLayout &DL = DAG.getDataLayout();
21891 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21892 bool MadeChange = false;
21893
  // Loop over the consecutive stores; after each attempt, continue with the
  // remaining candidates.
21895 while (NumConsecutiveStores >= 2) {
21896 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21897 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21898 Align FirstStoreAlign = FirstInChain->getAlign();
21899 unsigned NumStoresToMerge = 1;
21900 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21901 // Find a legal type for the vector store.
21902 unsigned Elts = (i + 1) * NumMemElts;
21903 EVT Ty = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
21904 unsigned IsFast = 0;
21905
21906 // Break early when size is too large to be legal.
21907 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
21908 break;
21909
21910 if (TLI.isTypeLegal(VT: Ty) &&
21911 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
21912 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
21913 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
21914 IsFast)
21915 NumStoresToMerge = i + 1;
21916 }
21917
    // Check if we found a legal vector type that creates a meaningful
    // merge.
21920 if (NumStoresToMerge < 2) {
21921 // We know that candidate stores are in order and of correct
21922 // shape. While there is no mergeable sequence from the
      // beginning, one may start later in the sequence. The only
      // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment has
21926 // improved. Drop as many candidates as we can here.
21927 unsigned NumSkip = 1;
21928 while ((NumSkip < NumConsecutiveStores) &&
21929 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21930 NumSkip++;
21931
21932 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
21933 NumConsecutiveStores -= NumSkip;
21934 continue;
21935 }
21936
21937 // Check that we can merge these candidates without causing a cycle.
21938 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumStoresToMerge,
21939 RootNode)) {
21940 StoreNodes.erase(CS: StoreNodes.begin(),
21941 CE: StoreNodes.begin() + NumStoresToMerge);
21942 NumConsecutiveStores -= NumStoresToMerge;
21943 continue;
21944 }
21945
21946 MadeChange |= mergeStoresOfConstantsOrVecElts(
21947 StoreNodes, MemVT, NumStores: NumStoresToMerge, /*IsConstantSrc*/ false,
21948 /*UseVector*/ true, /*UseTrunc*/ false);
21949
21950 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumStoresToMerge);
21951 NumConsecutiveStores -= NumStoresToMerge;
21952 }
21953 return MadeChange;
21954}
21955
21956bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
21957 unsigned NumConsecutiveStores, EVT MemVT,
21958 SDNode *RootNode, bool AllowVectors,
21959 bool IsNonTemporalStore,
21960 bool IsNonTemporalLoad) {
21961 LLVMContext &Context = *DAG.getContext();
21962 const DataLayout &DL = DAG.getDataLayout();
21963 int64_t ElementSizeBytes = MemVT.getStoreSize();
21964 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21965 bool MadeChange = false;
21966
21967 // Look for load nodes which are used by the stored values.
21968 SmallVector<MemOpLink, 8> LoadNodes;
21969
  // Find acceptable loads. Loads need to have the same chain (token factor),
  // must not be zext, volatile, or indexed, and they must be consecutive.
21972 BaseIndexOffset LdBasePtr;
21973
21974 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21975 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
21976 SDValue Val = peekThroughBitcasts(V: St->getValue());
21977 LoadSDNode *Ld = cast<LoadSDNode>(Val);
21978
21979 BaseIndexOffset LdPtr = BaseIndexOffset::match(N: Ld, DAG);
21980 // If this is not the first ptr that we check.
21981 int64_t LdOffset = 0;
21982 if (LdBasePtr.getBase().getNode()) {
21983 // The base ptr must be the same.
21984 if (!LdBasePtr.equalBaseIndex(Other: LdPtr, DAG, Off&: LdOffset))
21985 break;
21986 } else {
21987 // Check that all other base pointers are the same as this one.
21988 LdBasePtr = LdPtr;
21989 }
21990
21991 // We found a potential memory operand to merge.
21992 LoadNodes.push_back(Elt: MemOpLink(Ld, LdOffset));
21993 }
21994
21995 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
21996 Align RequiredAlignment;
21997 bool NeedRotate = false;
21998 if (LoadNodes.size() == 2) {
21999 // If we have load/store pair instructions and we only have two values,
22000 // don't bother merging.
22001 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
22002 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
22003 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 2);
22004 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + 2);
22005 break;
22006 }
22007 // If the loads are reversed, see if we can rotate the halves into place.
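      // Illustrative example: if the stored values are load(p+4) and load(p)
      // (i.e. the halves are swapped relative to the stores), a single wide
      // load from p rotated by half its width can produce the value for a
      // single wide store, assuming ROTL or ROTR is legal for the paired type.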
22008 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
22009 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
22010 EVT PairVT = EVT::getIntegerVT(Context, BitWidth: ElementSizeBytes * 8 * 2);
22011 if (Offset0 - Offset1 == ElementSizeBytes &&
22012 (hasOperation(Opcode: ISD::ROTL, VT: PairVT) ||
22013 hasOperation(Opcode: ISD::ROTR, VT: PairVT))) {
22014 std::swap(a&: LoadNodes[0], b&: LoadNodes[1]);
22015 NeedRotate = true;
22016 }
22017 }
22018 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
22019 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
22020 Align FirstStoreAlign = FirstInChain->getAlign();
22021 LoadSDNode *FirstLoad = cast<LoadSDNode>(Val: LoadNodes[0].MemNode);
22022
    // Scan the memory operations on the chain and find the first
    // non-consecutive load memory address. This variable holds the index into
    // the load node array.
22026
22027 unsigned LastConsecutiveLoad = 1;
22028
22029 // This variable refers to the size and not index in the array.
22030 unsigned LastLegalVectorType = 1;
22031 unsigned LastLegalIntegerType = 1;
22032 bool isDereferenceable = true;
22033 bool DoIntegerTruncate = false;
22034 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
22035 SDValue LoadChain = FirstLoad->getChain();
22036 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
22037 // All loads must share the same chain.
22038 if (LoadNodes[i].MemNode->getChain() != LoadChain)
22039 break;
22040
22041 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
22042 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
22043 break;
22044 LastConsecutiveLoad = i;
22045
22046 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
22047 isDereferenceable = false;
22048
22049 // Find a legal type for the vector store.
22050 unsigned Elts = (i + 1) * NumMemElts;
22051 EVT StoreTy = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
22052
22053 // Break early when size is too large to be legal.
22054 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
22055 break;
22056
22057 unsigned IsFastSt = 0;
22058 unsigned IsFastLd = 0;
22059 // Don't try vector types if we need a rotate. We may still fail the
22060 // legality checks for the integer type, but we can't handle the rotate
22061 // case with vectors.
22062 // FIXME: We could use a shuffle in place of the rotate.
22063 if (!NeedRotate && TLI.isTypeLegal(VT: StoreTy) &&
22064 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
22065 MF: DAG.getMachineFunction()) &&
22066 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22067 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
22068 IsFastSt &&
22069 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22070 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
22071 IsFastLd) {
22072 LastLegalVectorType = i + 1;
22073 }
22074
22075 // Find a legal type for the integer store.
22076 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
22077 StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
22078 if (TLI.isTypeLegal(VT: StoreTy) &&
22079 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
22080 MF: DAG.getMachineFunction()) &&
22081 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22082 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
22083 IsFastSt &&
22084 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22085 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
22086 IsFastLd) {
22087 LastLegalIntegerType = i + 1;
22088 DoIntegerTruncate = false;
22089 // Or check whether a truncstore and extload is legal.
22090 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
22091 TargetLowering::TypePromoteInteger) {
22092 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, VT: StoreTy);
22093 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22094 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
22095 MF: DAG.getMachineFunction()) &&
22096 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22097 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22098 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22099 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22100 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
22101 IsFastSt &&
22102 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22103 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
22104 IsFastLd) {
22105 LastLegalIntegerType = i + 1;
22106 DoIntegerTruncate = true;
22107 }
22108 }
22109 }
22110
22111 // Only use vector types if the vector type is larger than the integer
22112 // type. If they are the same, use integers.
22113 bool UseVectorTy =
22114 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
22115 unsigned LastLegalType =
22116 std::max(a: LastLegalVectorType, b: LastLegalIntegerType);
22117
    // We add +1 here because the LastXXX variables refer to an index
    // (location) while NumElem refers to a count of elements.
22120 unsigned NumElem = std::min(a: NumConsecutiveStores, b: LastConsecutiveLoad + 1);
22121 NumElem = std::min(a: LastLegalType, b: NumElem);
22122 Align FirstLoadAlign = FirstLoad->getAlign();
22123
22124 if (NumElem < 2) {
22125 // We know that candidate stores are in order and of correct
22126 // shape. While there is no mergeable sequence from the
      // beginning, one may start later in the sequence. The only
      // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment of either
      // the load or store has improved. Drop as many candidates as we
22131 // can here.
22132 unsigned NumSkip = 1;
22133 while ((NumSkip < LoadNodes.size()) &&
22134 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
22135 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
22136 NumSkip++;
22137 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
22138 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumSkip);
22139 NumConsecutiveStores -= NumSkip;
22140 continue;
22141 }
22142
22143 // Check that we can merge these candidates without causing a cycle.
22144 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
22145 RootNode)) {
22146 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22147 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
22148 NumConsecutiveStores -= NumElem;
22149 continue;
22150 }
22151
22152 // Find if it is better to use vectors or integers to load and store
22153 // to memory.
22154 EVT JointMemOpVT;
22155 if (UseVectorTy) {
22156 // Find a legal type for the vector store.
22157 unsigned Elts = NumElem * NumMemElts;
22158 JointMemOpVT = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
22159 } else {
22160 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
22161 JointMemOpVT = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
22162 }
22163
22164 // Check if there is a call in the load/store chain.
22165 if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT, JointMemOpVT) &&
22166 hasCallInLdStChain(St: cast<StoreSDNode>(Val: StoreNodes[0].MemNode),
22167 Ld: cast<LoadSDNode>(Val: LoadNodes[0].MemNode))) {
22168 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22169 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
22170 NumConsecutiveStores -= NumElem;
22171 continue;
22172 }
22173
22174 SDLoc LoadDL(LoadNodes[0].MemNode);
22175 SDLoc StoreDL(StoreNodes[0].MemNode);
22176
22177 // The merged loads are required to have the same incoming chain, so
22178 // using the first's chain is acceptable.
22179
22180 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumStores: NumElem);
22181 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
22182 AddToWorklist(N: NewStoreChain.getNode());
22183
22184 MachineMemOperand::Flags LdMMOFlags =
22185 isDereferenceable ? MachineMemOperand::MODereferenceable
22186 : MachineMemOperand::MONone;
22187 if (IsNonTemporalLoad)
22188 LdMMOFlags |= MachineMemOperand::MONonTemporal;
22189
22190 LdMMOFlags |= TLI.getTargetMMOFlags(Node: *FirstLoad);
22191
22192 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
22193 ? MachineMemOperand::MONonTemporal
22194 : MachineMemOperand::MONone;
22195
22196 StMMOFlags |= TLI.getTargetMMOFlags(Node: *StoreNodes[0].MemNode);
22197
22198 SDValue NewLoad, NewStore;
22199 if (UseVectorTy || !DoIntegerTruncate) {
22200 NewLoad = DAG.getLoad(
22201 VT: JointMemOpVT, dl: LoadDL, Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
22202 PtrInfo: FirstLoad->getPointerInfo(), Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
22203 SDValue StoreOp = NewLoad;
22204 if (NeedRotate) {
22205 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
22206 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
22207 "Unexpected type for rotate-able load pair");
22208 SDValue RotAmt =
22209 DAG.getShiftAmountConstant(Val: LoadWidth / 2, VT: JointMemOpVT, DL: LoadDL);
22210 // Target can convert to the identical ROTR if it does not have ROTL.
22211 StoreOp = DAG.getNode(Opcode: ISD::ROTL, DL: LoadDL, VT: JointMemOpVT, N1: NewLoad, N2: RotAmt);
22212 }
22213 NewStore = DAG.getStore(
22214 Chain: NewStoreChain, dl: StoreDL, Val: StoreOp, Ptr: FirstInChain->getBasePtr(),
22215 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
22216 : MachinePointerInfo(FirstStoreAS),
22217 Alignment: FirstStoreAlign, MMOFlags: StMMOFlags);
22218 } else { // This must be the truncstore/extload case
22219 EVT ExtendedTy =
22220 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: JointMemOpVT);
22221 NewLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: LoadDL, VT: ExtendedTy,
22222 Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
22223 PtrInfo: FirstLoad->getPointerInfo(), MemVT: JointMemOpVT,
22224 Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
22225 NewStore = DAG.getTruncStore(
22226 Chain: NewStoreChain, dl: StoreDL, Val: NewLoad, Ptr: FirstInChain->getBasePtr(),
22227 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
22228 : MachinePointerInfo(FirstStoreAS),
22229 SVT: JointMemOpVT, Alignment: FirstInChain->getAlign(),
22230 MMOFlags: FirstInChain->getMemOperand()->getFlags());
22231 }
22232
22233 // Transfer chain users from old loads to the new load.
22234 for (unsigned i = 0; i < NumElem; ++i) {
22235 LoadSDNode *Ld = cast<LoadSDNode>(Val: LoadNodes[i].MemNode);
22236 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1),
22237 To: SDValue(NewLoad.getNode(), 1));
22238 }
22239
22240 // Replace all stores with the new store. Recursively remove corresponding
22241 // values if they are no longer used.
22242 for (unsigned i = 0; i < NumElem; ++i) {
22243 SDValue Val = StoreNodes[i].MemNode->getOperand(Num: 1);
22244 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
22245 if (Val->use_empty())
22246 recursivelyDeleteUnusedNodes(N: Val.getNode());
22247 }
22248
22249 MadeChange = true;
22250 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22251 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
22252 NumConsecutiveStores -= NumElem;
22253 }
22254 return MadeChange;
22255}
22256
22257bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
22258 if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
22259 return false;
22260
22261 // TODO: Extend this function to merge stores of scalable vectors.
22262 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
22263 // store since we know <vscale x 16 x i8> is exactly twice as large as
22264 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
22265 EVT MemVT = St->getMemoryVT();
22266 if (MemVT.isScalableVT())
22267 return false;
22268 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
22269 return false;
22270
22271 // This function cannot currently deal with non-byte-sized memory sizes.
22272 int64_t ElementSizeBytes = MemVT.getStoreSize();
22273 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
22274 return false;
22275
22276 // Do not bother looking at stored values that are not constants, loads, or
22277 // extracted vector elements.
22278 SDValue StoredVal = peekThroughBitcasts(V: St->getValue());
22279 const StoreSource StoreSrc = getStoreSource(StoreVal: StoredVal);
22280 if (StoreSrc == StoreSource::Unknown)
22281 return false;
22282
22283 SmallVector<MemOpLink, 8> StoreNodes;
22284 // Find potential store merge candidates by searching through chain sub-DAG
22285 SDNode *RootNode = getStoreMergeCandidates(St, StoreNodes);
22286
22287 // Check if there is anything to merge.
22288 if (StoreNodes.size() < 2)
22289 return false;
22290
22291 // Sort the memory operands according to their distance from the
22292 // base pointer.
22293 llvm::sort(C&: StoreNodes, Comp: [](MemOpLink LHS, MemOpLink RHS) {
22294 return LHS.OffsetFromBase < RHS.OffsetFromBase;
22295 });
22296
22297 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
22298 Kind: Attribute::NoImplicitFloat);
22299 bool IsNonTemporalStore = St->isNonTemporal();
22300 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
22301 cast<LoadSDNode>(Val&: StoredVal)->isNonTemporal();
22302
  // Store merging attempts to merge the lowest-addressed stores first. This
  // generally works out fine: if the first merge succeeds, the remaining
  // stores are re-examined afterwards. However, in the case that a
  // non-mergeable store is found first, e.g., {p[-2], p[0], p[1], p[2],
  // p[3]}, we would fail and miss the subsequent mergeable cases. To prevent
  // this, we prune such stores from the front of StoreNodes here.
22310 bool MadeChange = false;
22311 while (StoreNodes.size() > 1) {
22312 unsigned NumConsecutiveStores =
22313 getConsecutiveStores(StoreNodes, ElementSizeBytes);
22314 // There are no more stores in the list to examine.
22315 if (NumConsecutiveStores == 0)
22316 return MadeChange;
22317
22318 // We have at least 2 consecutive stores. Try to merge them.
22319 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
22320 switch (StoreSrc) {
22321 case StoreSource::Constant:
22322 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
22323 MemVT, RootNode, AllowVectors);
22324 break;
22325
22326 case StoreSource::Extract:
22327 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
22328 MemVT, RootNode);
22329 break;
22330
22331 case StoreSource::Load:
22332 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
22333 MemVT, RootNode, AllowVectors,
22334 IsNonTemporalStore, IsNonTemporalLoad);
22335 break;
22336
22337 default:
22338 llvm_unreachable("Unhandled store source type");
22339 }
22340 }
22341
22342 // Remember if we failed to optimize, to save compile time.
22343 if (!MadeChange)
22344 ChainsWithoutMergeableStores.insert(Ptr: RootNode);
22345
22346 return MadeChange;
22347}
22348
22349SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
22350 SDLoc SL(ST);
22351 SDValue ReplStore;
22352
22353 // Replace the chain to avoid dependency.
22354 if (ST->isTruncatingStore()) {
22355 ReplStore = DAG.getTruncStore(Chain: BetterChain, dl: SL, Val: ST->getValue(),
22356 Ptr: ST->getBasePtr(), SVT: ST->getMemoryVT(),
22357 MMO: ST->getMemOperand());
22358 } else {
22359 ReplStore = DAG.getStore(Chain: BetterChain, dl: SL, Val: ST->getValue(), Ptr: ST->getBasePtr(),
22360 MMO: ST->getMemOperand());
22361 }
22362
22363 // Create token to keep both nodes around.
22364 SDValue Token = DAG.getNode(Opcode: ISD::TokenFactor, DL: SL,
22365 VT: MVT::Other, N1: ST->getChain(), N2: ReplStore);
22366
22367 // Make sure the new and old chains are cleaned up.
22368 AddToWorklist(N: Token.getNode());
22369
22370 // Don't add users to work list.
22371 return CombineTo(N: ST, Res: Token, AddTo: false);
22372}
22373
22374SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
22375 SDValue Value = ST->getValue();
22376 if (Value.getOpcode() == ISD::TargetConstantFP)
22377 return SDValue();
22378
22379 if (!ISD::isNormalStore(N: ST))
22380 return SDValue();
22381
22382 SDLoc DL(ST);
22383
22384 SDValue Chain = ST->getChain();
22385 SDValue Ptr = ST->getBasePtr();
22386
22387 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Val&: Value);
22388
22389 // NOTE: If the original store is volatile, this transform must not increase
22390 // the number of stores. For example, on x86-32 an f64 can be stored in one
22391 // processor operation but an i64 (which is not legal) requires two. So the
22392 // transform should not be done in this case.
22393
22394 SDValue Tmp;
22395 switch (CFP->getSimpleValueType(ResNo: 0).SimpleTy) {
22396 default:
22397 llvm_unreachable("Unknown FP type");
22398 case MVT::f16: // We don't do this for these yet.
22399 case MVT::bf16:
22400 case MVT::f80:
22401 case MVT::f128:
22402 case MVT::ppcf128:
22403 return SDValue();
22404 case MVT::f32:
22405 if ((isTypeLegal(VT: MVT::i32) && !LegalOperations && ST->isSimple()) ||
22406 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i32)) {
22407 Tmp = DAG.getConstant(Val: (uint32_t)CFP->getValueAPF().
22408 bitcastToAPInt().getZExtValue(), DL: SDLoc(CFP),
22409 VT: MVT::i32);
22410 return DAG.getStore(Chain, dl: DL, Val: Tmp, Ptr, MMO: ST->getMemOperand());
22411 }
22412
22413 return SDValue();
22414 case MVT::f64:
22415 if ((TLI.isTypeLegal(VT: MVT::i64) && !LegalOperations &&
22416 ST->isSimple()) ||
22417 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i64)) {
22418 Tmp = DAG.getConstant(Val: CFP->getValueAPF().bitcastToAPInt().
22419 getZExtValue(), DL: SDLoc(CFP), VT: MVT::i64);
22420 return DAG.getStore(Chain, dl: DL, Val: Tmp,
22421 Ptr, MMO: ST->getMemOperand());
22422 }
22423
22424 if (ST->isSimple() && TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i32) &&
22425 !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
22426 // Many FP stores are not made apparent until after legalize, e.g. for
22427 // argument passing. Since this is so common, custom legalize the
22428 // 64-bit integer store into two 32-bit stores.
22429 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
22430 SDValue Lo = DAG.getConstant(Val: Val & 0xFFFFFFFF, DL: SDLoc(CFP), VT: MVT::i32);
22431 SDValue Hi = DAG.getConstant(Val: Val >> 32, DL: SDLoc(CFP), VT: MVT::i32);
22432 if (DAG.getDataLayout().isBigEndian())
22433 std::swap(a&: Lo, b&: Hi);
22434
22435 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
22436 AAMDNodes AAInfo = ST->getAAInfo();
22437
22438 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
22439 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
22440 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: 4), DL);
22441 SDValue St1 = DAG.getStore(Chain, dl: DL, Val: Hi, Ptr,
22442 PtrInfo: ST->getPointerInfo().getWithOffset(O: 4),
22443 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
22444 return DAG.getNode(Opcode: ISD::TokenFactor, DL, VT: MVT::Other,
22445 N1: St0, N2: St1);
22446 }
22447
22448 return SDValue();
22449 }
22450}
22451
22452// (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
22453//
22454// If a store of a load with an element inserted into it has no other
22455// uses in between the chain, then we can consider the vector store
22456// dead and replace it with just the single scalar element store.
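// Illustrative example (hypothetical types): for
//   (store (insert_vector_elt (load <4 x i32> p), x, 2), p)
// only element 2 actually changes in memory, so the whole vector store can be
// replaced by a single scalar store of x at p + 8, provided nothing else
// observes the intermediate vector value or the memory in between.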
22457SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
22458 SDLoc DL(ST);
22459 SDValue Value = ST->getValue();
22460 SDValue Ptr = ST->getBasePtr();
22461 SDValue Chain = ST->getChain();
22462 if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
22463 return SDValue();
22464
22465 SDValue Elt = Value.getOperand(i: 1);
22466 SDValue Idx = Value.getOperand(i: 2);
22467
22468 // If the element isn't byte sized or is implicitly truncated then we can't
22469 // compute an offset.
22470 EVT EltVT = Elt.getValueType();
22471 if (!EltVT.isByteSized() ||
22472 EltVT != Value.getOperand(i: 0).getValueType().getVectorElementType())
22473 return SDValue();
22474
22475 auto *Ld = dyn_cast<LoadSDNode>(Val: Value.getOperand(i: 0));
22476 if (!Ld || Ld->getBasePtr() != Ptr ||
22477 ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
22478 !ISD::isNormalStore(N: ST) ||
22479 Ld->getAddressSpace() != ST->getAddressSpace() ||
22480 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1)))
22481 return SDValue();
22482
22483 unsigned IsFast;
22484 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
22485 VT: Elt.getValueType(), AddrSpace: ST->getAddressSpace(),
22486 Alignment: ST->getAlign(), Flags: ST->getMemOperand()->getFlags(),
22487 Fast: &IsFast) ||
22488 !IsFast)
22489 return SDValue();
22490
22491 MachinePointerInfo PointerInfo(ST->getAddressSpace());
22492
22493 // If the offset is a known constant then try to recover the pointer
22494 // info
22495 SDValue NewPtr;
22496 if (auto *CIdx = dyn_cast<ConstantSDNode>(Val&: Idx)) {
22497 unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
22498 NewPtr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: COffset), DL);
22499 PointerInfo = ST->getPointerInfo().getWithOffset(O: COffset);
22500 } else {
22501 NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: Ptr, VecVT: Value.getValueType(), Index: Idx);
22502 }
22503
22504 return DAG.getStore(Chain, dl: DL, Val: Elt, Ptr: NewPtr, PtrInfo: PointerInfo, Alignment: ST->getAlign(),
22505 MMOFlags: ST->getMemOperand()->getFlags());
22506}
22507
22508SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
22509 AtomicSDNode *ST = cast<AtomicSDNode>(Val: N);
22510 SDValue Val = ST->getVal();
22511 EVT VT = Val.getValueType();
22512 EVT MemVT = ST->getMemoryVT();
22513
22514 if (MemVT.bitsLT(VT)) { // Is truncating store
22515 APInt TruncDemandedBits = APInt::getLowBitsSet(numBits: VT.getScalarSizeInBits(),
22516 loBitsSet: MemVT.getScalarSizeInBits());
22517 // See if we can simplify the operation with SimplifyDemandedBits, which
22518 // only works if the value has a single use.
22519 if (SimplifyDemandedBits(Op: Val, DemandedBits: TruncDemandedBits))
22520 return SDValue(N, 0);
22521 }
22522
22523 return SDValue();
22524}
22525
22526SDValue DAGCombiner::visitSTORE(SDNode *N) {
22527 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
22528 SDValue Chain = ST->getChain();
22529 SDValue Value = ST->getValue();
22530 SDValue Ptr = ST->getBasePtr();
22531
22532 // If this is a store of a bit convert, store the input value if the
22533 // resultant store does not need a higher alignment than the original.
22534 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
22535 ST->isUnindexed()) {
22536 EVT SVT = Value.getOperand(i: 0).getValueType();
22537 // If the store is volatile, we only want to change the store type if the
22538 // resulting store is legal. Otherwise we might increase the number of
22539 // memory accesses. We don't care if the original type was legal or not
22540 // as we assume software couldn't rely on the number of accesses of an
22541 // illegal type.
22542 // TODO: May be able to relax for unordered atomics (see D66309)
22543 if (((!LegalOperations && ST->isSimple()) ||
22544 TLI.isOperationLegal(Op: ISD::STORE, VT: SVT)) &&
22545 TLI.isStoreBitCastBeneficial(StoreVT: Value.getValueType(), BitcastVT: SVT,
22546 DAG, MMO: *ST->getMemOperand())) {
22547 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
22548 MMO: ST->getMemOperand());
22549 }
22550 }
22551
22552 // Turn 'store undef, Ptr' -> nothing.
22553 if (Value.isUndef() && ST->isUnindexed() && !ST->isVolatile())
22554 return Chain;
22555
22556 // Try to infer better alignment information than the store already has.
22557 if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
22558 !ST->isAtomic()) {
22559 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
22560 if (*Alignment > ST->getAlign() &&
22561 isAligned(Lhs: *Alignment, SizeInBytes: ST->getSrcValueOffset())) {
22562 SDValue NewStore =
22563 DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value, Ptr, PtrInfo: ST->getPointerInfo(),
22564 SVT: ST->getMemoryVT(), Alignment: *Alignment,
22565 MMOFlags: ST->getMemOperand()->getFlags(), AAInfo: ST->getAAInfo());
22566 // NewStore will always be N as we are only refining the alignment
22567 assert(NewStore.getNode() == N);
22568 (void)NewStore;
22569 }
22570 }
22571 }
22572
22573 // Try transforming a pair floating point load / store ops to integer
22574 // load / store ops.
22575 if (SDValue NewST = TransformFPLoadStorePair(N))
22576 return NewST;
22577
22578 // Try transforming several stores into STORE (BSWAP).
22579 if (SDValue Store = mergeTruncStores(N: ST))
22580 return Store;
22581
22582 if (ST->isUnindexed()) {
22583 // Walk up chain skipping non-aliasing memory nodes, on this store and any
22584 // adjacent stores.
22585 if (findBetterNeighborChains(St: ST)) {
22586 // replaceStoreChain uses CombineTo, which handled all of the worklist
22587 // manipulation. Return the original node to not do anything else.
22588 return SDValue(ST, 0);
22589 }
22590 Chain = ST->getChain();
22591 }
22592
22593 // FIXME: is there such a thing as a truncating indexed store?
22594 if (ST->isTruncatingStore() && ST->isUnindexed() &&
22595 Value.getValueType().isInteger() &&
22596 (!isa<ConstantSDNode>(Val: Value) ||
22597 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
    // Convert a truncating store of an extension into a standard store.
22599 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
22600 Value.getOpcode() == ISD::SIGN_EXTEND ||
22601 Value.getOpcode() == ISD::ANY_EXTEND) &&
22602 Value.getOperand(i: 0).getValueType() == ST->getMemoryVT() &&
22603 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: ST->getMemoryVT()))
22604 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
22605 MMO: ST->getMemOperand());
22606
22607 APInt TruncDemandedBits =
22608 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
22609 loBitsSet: ST->getMemoryVT().getScalarSizeInBits());
22610
22611 // See if we can simplify the operation with SimplifyDemandedBits, which
22612 // only works if the value has a single use.
22613 AddToWorklist(N: Value.getNode());
22614 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been
      // merged with another node (in which case N is deleted).
      // SimplifyDemandedBits will add Value's node back to the worklist if
      // necessary, but we also need to re-visit the Store node itself.
22619 if (N->getOpcode() != ISD::DELETED_NODE)
22620 AddToWorklist(N);
22621 return SDValue(N, 0);
22622 }
22623
22624 // Otherwise, see if we can simplify the input to this truncstore with
22625 // knowledge that only the low bits are being used. For example:
22626 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
22627 if (SDValue Shorter =
22628 TLI.SimplifyMultipleUseDemandedBits(Op: Value, DemandedBits: TruncDemandedBits, DAG))
22629 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr, SVT: ST->getMemoryVT(),
22630 MMO: ST->getMemOperand());
22631
22632 // If we're storing a truncated constant, see if we can simplify it.
22633 // TODO: Move this to targetShrinkDemandedConstant?
22634 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Value))
22635 if (!Cst->isOpaque()) {
22636 const APInt &CValue = Cst->getAPIntValue();
22637 APInt NewVal = CValue & TruncDemandedBits;
22638 if (NewVal != CValue) {
22639 SDValue Shorter =
22640 DAG.getConstant(Val: NewVal, DL: SDLoc(N), VT: Value.getValueType());
22641 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr,
22642 SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
22643 }
22644 }
22645 }
22646
22647 // If this is a load followed by a store to the same location, then the store
22648 // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
22649 // TODO: Add big-endian truncate support with test coverage.
22650 // TODO: Can relax for unordered atomics (see D66309)
22651 SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
22652 ? peekThroughTruncates(V: Value)
22653 : Value;
22654 if (auto *Ld = dyn_cast<LoadSDNode>(Val&: TruncVal)) {
22655 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
22656 ST->isUnindexed() && ST->isSimple() &&
22657 Ld->getAddressSpace() == ST->getAddressSpace() &&
22658 // There can't be any side effects between the load and store, such as
22659 // a call or store.
22660 Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1))) {
22661 // The store is dead, remove it.
22662 return Chain;
22663 }
22664 }
22665
22666 // Try scalarizing vector stores of loads where we only change one element
22667 if (SDValue NewST = replaceStoreOfInsertLoad(ST))
22668 return NewST;
22669
22670 // TODO: Can relax for unordered atomics (see D66309)
22671 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Val&: Chain)) {
22672 if (ST->isUnindexed() && ST->isSimple() &&
22673 ST1->isUnindexed() && ST1->isSimple()) {
22674 if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
22675 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
22676 ST->getAddressSpace() == ST1->getAddressSpace()) {
22677 // If this is a store followed by a store with the same value to the
22678 // same location, then the store is dead/noop.
22679 return Chain;
22680 }
22681
22682 if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
22683 !ST1->getBasePtr().isUndef() &&
22684 ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If the store we might remove has a scalable vector type and the
        // later store has a fixed-size type, we cannot simply compare sizes:
        // the scalable store's final size is unknown, so it is only safe to
        // remove it when its size is known to be no larger than that of the
        // later store to the same base pointer.
22689 if (ST->getMemoryVT().isScalableVector() ||
22690 ST1->getMemoryVT().isScalableVector()) {
22691 if (ST1->getBasePtr() == Ptr &&
22692 TypeSize::isKnownLE(LHS: ST1->getMemoryVT().getStoreSize(),
22693 RHS: ST->getMemoryVT().getStoreSize())) {
22694 CombineTo(N: ST1, Res: ST1->getChain());
22695 return SDValue(N, 0);
22696 }
22697 } else {
22698 const BaseIndexOffset STBase = BaseIndexOffset::match(N: ST, DAG);
22699 const BaseIndexOffset ChainBase = BaseIndexOffset::match(N: ST1, DAG);
        // If the preceding store writes to a subset of the current store's
        // location and no other node is chained to that store, we can
        // effectively drop the preceding store. Do not remove stores to
        // undef as they may be used as data sinks.
22704 if (STBase.contains(DAG, BitSize: ST->getMemoryVT().getFixedSizeInBits(),
22705 Other: ChainBase,
22706 OtherBitSize: ST1->getMemoryVT().getFixedSizeInBits())) {
22707 CombineTo(N: ST1, Res: ST1->getChain());
22708 return SDValue(N, 0);
22709 }
22710 }
22711 }
22712 }
22713 }
22714
22715 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
22716 // truncating store. We can do this even if this is already a truncstore.
22717 if ((Value.getOpcode() == ISD::FP_ROUND ||
22718 Value.getOpcode() == ISD::TRUNCATE) &&
22719 Value->hasOneUse() && ST->isUnindexed() &&
22720 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
22721 MemVT: ST->getMemoryVT(), LegalOnly: LegalOperations)) {
22722 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0),
22723 Ptr, SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
22724 }
22725
22726 // Always perform this optimization before types are legal. If the target
22727 // prefers, also try this after legalization to catch stores that were created
22728 // by intrinsics or other nodes.
22729 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(MemVT: ST->getMemoryVT()))) {
22730 while (true) {
22731 // There can be multiple store sequences on the same chain.
22732 // Keep trying to merge store sequences until we are unable to do so
22733 // or until we merge the last store on the chain.
22734 bool Changed = mergeConsecutiveStores(St: ST);
22735 if (!Changed) break;
22736 // Return N as merge only uses CombineTo and no worklist clean
22737 // up is necessary.
22738 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(Val: N))
22739 return SDValue(N, 0);
22740 }
22741 }
22742
22743 // Try transforming N to an indexed store.
22744 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
22745 return SDValue(N, 0);
22746
22747 // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
22748 //
22749 // Make sure to do this only after attempting to merge stores in order to
22750 // avoid changing the types of some subset of stores due to visit order,
22751 // preventing their merging.
22752 if (isa<ConstantFPSDNode>(Val: ST->getValue())) {
22753 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
22754 return NewSt;
22755 }
22756
22757 if (SDValue NewSt = splitMergedValStore(ST))
22758 return NewSt;
22759
22760 return ReduceLoadOpStoreWidth(N);
22761}
22762
22763SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
22764 const auto *LifetimeEnd = cast<LifetimeSDNode>(Val: N);
22765 if (!LifetimeEnd->hasOffset())
22766 return SDValue();
22767
22768 const BaseIndexOffset LifetimeEndBase(N->getOperand(Num: 1), SDValue(),
22769 LifetimeEnd->getOffset(), false);
22770
22771 // We walk up the chains to find stores.
22772 SmallVector<SDValue, 8> Chains = {N->getOperand(Num: 0)};
22773 while (!Chains.empty()) {
22774 SDValue Chain = Chains.pop_back_val();
22775 if (!Chain.hasOneUse())
22776 continue;
22777 switch (Chain.getOpcode()) {
22778 case ISD::TokenFactor:
22779 for (unsigned Nops = Chain.getNumOperands(); Nops;)
22780 Chains.push_back(Elt: Chain.getOperand(i: --Nops));
22781 break;
22782 case ISD::LIFETIME_START:
22783 case ISD::LIFETIME_END:
22784 // We can forward past any lifetime start/end that can be proven not to
22785 // alias the node.
22786 if (!mayAlias(Op0: Chain.getNode(), Op1: N))
22787 Chains.push_back(Elt: Chain.getOperand(i: 0));
22788 break;
22789 case ISD::STORE: {
22790 StoreSDNode *ST = dyn_cast<StoreSDNode>(Val&: Chain);
22791 // TODO: Can relax for unordered atomics (see D66309)
22792 if (!ST->isSimple() || ST->isIndexed())
22793 continue;
22794 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
22795 // The bounds of a scalable store are not known until runtime, so this
22796 // store cannot be elided.
22797 if (StoreSize.isScalable())
22798 continue;
22799 const BaseIndexOffset StoreBase = BaseIndexOffset::match(N: ST, DAG);
22800 // If we store purely within object bounds just before its lifetime ends,
22801 // we can remove the store.
22802 if (LifetimeEndBase.contains(DAG, BitSize: LifetimeEnd->getSize() * 8, Other: StoreBase,
22803 OtherBitSize: StoreSize.getFixedValue() * 8)) {
22804 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
22805 dbgs() << "\nwithin LIFETIME_END of : ";
22806 LifetimeEndBase.dump(); dbgs() << "\n");
22807 CombineTo(N: ST, Res: ST->getChain());
22808 return SDValue(N, 0);
22809 }
22810 }
22811 }
22812 }
22813 return SDValue();
22814}
22815
22816/// For the instruction sequence of store below, F and I values
22817/// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
22819/// which can remove the bitwise instructions or sink them to colder places.
22820///
22821/// (store (or (zext (bitcast F to i32) to i64),
22822/// (shl (zext I to i64), 32)), addr) -->
22823/// (store F, addr) and (store I, addr+4)
22824///
22825/// Similarly, splitting for other merged store can also be beneficial, like:
22826/// For pair of {i32, i32}, i64 store --> two i32 stores.
22827/// For pair of {i32, i16}, i64 store --> two i32 stores.
22828/// For pair of {i16, i16}, i32 store --> two i16 stores.
22829/// For pair of {i16, i8}, i32 store --> two i16 stores.
22830/// For pair of {i8, i8}, i16 store --> two i8 stores.
22831///
22832/// We allow each target to determine specifically which kind of splitting is
22833/// supported.
22834///
22835/// The store patterns are commonly seen from the simple code snippet below
/// if only std::make_pair(...) is SROA-transformed before being inlined into hoo.
22837/// void goo(const std::pair<int, float> &);
22838/// hoo() {
22839/// ...
22840/// goo(std::make_pair(tmp, ftmp));
22841/// ...
22842/// }
22843///
22844SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
22845 if (OptLevel == CodeGenOptLevel::None)
22846 return SDValue();
22847
22848 // Can't change the number of memory accesses for a volatile store or break
22849 // atomicity for an atomic one.
22850 if (!ST->isSimple())
22851 return SDValue();
22852
22853 SDValue Val = ST->getValue();
22854 SDLoc DL(ST);
22855
22856 // Match OR operand.
22857 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
22858 return SDValue();
22859
22860 // Match SHL operand and get Lower and Higher parts of Val.
22861 SDValue Op1 = Val.getOperand(i: 0);
22862 SDValue Op2 = Val.getOperand(i: 1);
22863 SDValue Lo, Hi;
22864 if (Op1.getOpcode() != ISD::SHL) {
22865 std::swap(a&: Op1, b&: Op2);
22866 if (Op1.getOpcode() != ISD::SHL)
22867 return SDValue();
22868 }
22869 Lo = Op2;
22870 Hi = Op1.getOperand(i: 0);
22871 if (!Op1.hasOneUse())
22872 return SDValue();
22873
22874 // Match shift amount to HalfValBitSize.
22875 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
22876 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Val: Op1.getOperand(i: 1));
22877 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
22878 return SDValue();
22879
  // Lo and Hi must be zero-extended from integer types whose size is less
  // than or equal to HalfValBitSize.
22882 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
22883 !Lo.getOperand(i: 0).getValueType().isScalarInteger() ||
22884 Lo.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize ||
22885 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
22886 !Hi.getOperand(i: 0).getValueType().isScalarInteger() ||
22887 Hi.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize)
22888 return SDValue();
22889
22890 // Use the EVT of low and high parts before bitcast as the input
22891 // of target query.
22892 EVT LowTy = (Lo.getOperand(i: 0).getOpcode() == ISD::BITCAST)
22893 ? Lo.getOperand(i: 0).getValueType()
22894 : Lo.getValueType();
22895 EVT HighTy = (Hi.getOperand(i: 0).getOpcode() == ISD::BITCAST)
22896 ? Hi.getOperand(i: 0).getValueType()
22897 : Hi.getValueType();
22898 if (!TLI.isMultiStoresCheaperThanBitsMerge(LTy: LowTy, HTy: HighTy))
22899 return SDValue();
22900
22901 // Start to split store.
22902 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
22903 AAMDNodes AAInfo = ST->getAAInfo();
22904
22905 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
22906 EVT VT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: HalfValBitSize);
22907 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Lo.getOperand(i: 0));
22908 Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Hi.getOperand(i: 0));
22909
22910 SDValue Chain = ST->getChain();
22911 SDValue Ptr = ST->getBasePtr();
22912 // Lower value store.
22913 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
22914 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
22915 Ptr =
22916 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: HalfValBitSize / 8), DL);
22917 // Higher value store.
22918 SDValue St1 = DAG.getStore(
22919 Chain: St0, dl: DL, Val: Hi, Ptr, PtrInfo: ST->getPointerInfo().getWithOffset(O: HalfValBitSize / 8),
22920 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
22921 return St1;
22922}
22923
22924// Merge an insertion into an existing shuffle:
22925// (insert_vector_elt (vector_shuffle X, Y, Mask),
//                     (extract_vector_elt X, N), InsIndex)
22927// --> (vector_shuffle X, Y, NewMask)
22928// and variations where shuffle operands may be CONCAT_VECTORS.
22929static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
22930 SmallVectorImpl<int> &NewMask, SDValue Elt,
22931 unsigned InsIndex) {
22932 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22933 !isa<ConstantSDNode>(Val: Elt.getOperand(i: 1)))
22934 return false;
22935
22936 // Vec's operand 0 is using indices from 0 to N-1 and
22937 // operand 1 from N to 2N - 1, where N is the number of
22938 // elements in the vectors.
22939 SDValue InsertVal0 = Elt.getOperand(i: 0);
22940 int ElementOffset = -1;
22941
22942 // We explore the inputs of the shuffle in order to see if we find the
22943 // source of the extract_vector_elt. If so, we can use it to modify the
22944 // shuffle rather than perform an insert_vector_elt.
22945 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
22946 ArgWorkList.emplace_back(Args: Mask.size(), Args&: Y);
22947 ArgWorkList.emplace_back(Args: 0, Args&: X);
22948
22949 while (!ArgWorkList.empty()) {
22950 int ArgOffset;
22951 SDValue ArgVal;
22952 std::tie(args&: ArgOffset, args&: ArgVal) = ArgWorkList.pop_back_val();
22953
22954 if (ArgVal == InsertVal0) {
22955 ElementOffset = ArgOffset;
22956 break;
22957 }
22958
22959 // Peek through concat_vector.
22960 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
22961 int CurrentArgOffset =
22962 ArgOffset + ArgVal.getValueType().getVectorNumElements();
22963 int Step = ArgVal.getOperand(i: 0).getValueType().getVectorNumElements();
22964 for (SDValue Op : reverse(C: ArgVal->ops())) {
22965 CurrentArgOffset -= Step;
22966 ArgWorkList.emplace_back(Args&: CurrentArgOffset, Args&: Op);
22967 }
22968
22969 // Make sure we went through all the elements and did not screw up index
22970 // computation.
22971 assert(CurrentArgOffset == ArgOffset);
22972 }
22973 }
22974
22975 // If we failed to find a match, see if we can replace an UNDEF shuffle
22976 // operand.
22977 if (ElementOffset == -1) {
22978 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
22979 return false;
22980 ElementOffset = Mask.size();
22981 Y = InsertVal0;
22982 }
22983
22984 NewMask.assign(in_start: Mask.begin(), in_end: Mask.end());
22985 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(i: 1);
22986 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
22987 "NewMask[InsIndex] is out of bound");
22988 return true;
22989}
22990
22991// Merge an insertion into an existing shuffle:
22992// (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
22993// InsIndex)
22994// --> (vector_shuffle X, Y) and variations where shuffle operands may be
22995// CONCAT_VECTORS.
22996SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
22997 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
22998 "Expected extract_vector_elt");
22999 SDValue InsertVal = N->getOperand(Num: 1);
23000 SDValue Vec = N->getOperand(Num: 0);
23001
23002 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: Vec);
23003 if (!SVN || !Vec.hasOneUse())
23004 return SDValue();
23005
23006 ArrayRef<int> Mask = SVN->getMask();
23007 SDValue X = Vec.getOperand(i: 0);
23008 SDValue Y = Vec.getOperand(i: 1);
23009
23010 SmallVector<int, 16> NewMask(Mask);
23011 if (mergeEltWithShuffle(X, Y, Mask, NewMask, Elt: InsertVal, InsIndex)) {
23012 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
23013 VT: Vec.getValueType(), DL: SDLoc(N), N0: X, N1: Y, Mask: NewMask, DAG);
23014 if (LegalShuffle)
23015 return LegalShuffle;
23016 }
23017
23018 return SDValue();
23019}
23020
23021// Convert a disguised subvector insertion into a shuffle:
23022// insert_vector_elt V, (bitcast X from vector type), IdxC -->
23023// bitcast(shuffle (bitcast V), (extended X), Mask)
23024// Note: We do not use an insert_subvector node because that requires a
23025// legal subvector type.
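// Illustrative example (types chosen for exposition): inserting a v2i16 value
// X, bitcast to i32, into lane 2 of a v4i32 vector V becomes a v8i16 shuffle
// of (bitcast V) with X padded out by undef elements, using the mask built in
// Step 1 below.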
23026SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
23027 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
23028 "Expected extract_vector_elt");
23029 SDValue InsertVal = N->getOperand(Num: 1);
23030
23031 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
23032 !InsertVal.getOperand(i: 0).getValueType().isVector())
23033 return SDValue();
23034
23035 SDValue SubVec = InsertVal.getOperand(i: 0);
23036 SDValue DestVec = N->getOperand(Num: 0);
23037 EVT SubVecVT = SubVec.getValueType();
23038 EVT VT = DestVec.getValueType();
23039 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
23040 // If the source has only a single vector element, the cost of building a
23041 // whole vector around it is likely to exceed the cost of an insert_vector_elt.
23042 if (NumSrcElts == 1)
23043 return SDValue();
23044 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
23045 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
23046
23047 // Step 1: Create a shuffle mask that implements this insert operation. The
23048 // vector that we are inserting into will be operand 0 of the shuffle, so
23049 // those elements are just 'i'. The inserted subvector is in the first
23050 // positions of operand 1 of the shuffle. Example:
23051 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
23052 SmallVector<int, 16> Mask(NumMaskVals);
23053 for (unsigned i = 0; i != NumMaskVals; ++i) {
23054 if (i / NumSrcElts == InsIndex)
23055 Mask[i] = (i % NumSrcElts) + NumMaskVals;
23056 else
23057 Mask[i] = i;
23058 }
23059
23060 // Bail out if the target can not handle the shuffle we want to create.
23061 EVT SubVecEltVT = SubVecVT.getVectorElementType();
23062 EVT ShufVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SubVecEltVT, NumElements: NumMaskVals);
23063 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
23064 return SDValue();
23065
23066 // Step 2: Create a wide vector from the inserted source vector by appending
23067 // undefined elements. This is the same size as our destination vector.
23068 SDLoc DL(N);
23069 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(VT: SubVecVT));
23070 ConcatOps[0] = SubVec;
23071 SDValue PaddedSubV = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ShufVT, Ops: ConcatOps);
23072
23073 // Step 3: Shuffle in the padded subvector.
23074 SDValue DestVecBC = DAG.getBitcast(VT: ShufVT, V: DestVec);
23075 SDValue Shuf = DAG.getVectorShuffle(VT: ShufVT, dl: DL, N1: DestVecBC, N2: PaddedSubV, Mask);
23076 AddToWorklist(N: PaddedSubV.getNode());
23077 AddToWorklist(N: DestVecBC.getNode());
23078 AddToWorklist(N: Shuf.getNode());
23079 return DAG.getBitcast(VT, V: Shuf);
23080}
23081
23082// Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
23083 // possible, provided the new load will be fast. This may use more loads but
23084 // fewer shuffles and inserts.
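// Illustrative example (assuming adjacent memory, with %q = %p - 4):
//   insert_vector_elt (shuffle (load <4 x i32>, %p), <u,0,1,2>),
//                     (load i32, %q), 0
// can become a single (load <4 x i32>, %q).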
23085SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
23086 EVT VT = N->getValueType(ResNo: 0);
23087
23088 // InsIndex is expected to be the first or the last lane.
23089 if (!VT.isFixedLengthVector() ||
23090 (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
23091 return SDValue();
23092
23093 // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
23094 // depending on the InsIndex.
23095 auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: 0));
23096 SDValue Scalar = N->getOperand(Num: 1);
23097 if (!Shuffle || !all_of(Range: enumerate(First: Shuffle->getMask()), P: [&](auto P) {
23098 return InsIndex == P.index() || P.value() < 0 ||
23099 (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
23100 (InsIndex == VT.getVectorNumElements() - 1 &&
23101 P.value() == (int)P.index() + 1);
23102 }))
23103 return SDValue();
23104
23105 // We optionally skip over an extend so long as both loads are extended in the
23106 // same way from the same type.
23107 unsigned Extend = 0;
23108 if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
23109 Scalar.getOpcode() == ISD::SIGN_EXTEND ||
23110 Scalar.getOpcode() == ISD::ANY_EXTEND) {
23111 Extend = Scalar.getOpcode();
23112 Scalar = Scalar.getOperand(i: 0);
23113 }
23114
23115 auto *ScalarLoad = dyn_cast<LoadSDNode>(Val&: Scalar);
23116 if (!ScalarLoad)
23117 return SDValue();
23118
23119 SDValue Vec = Shuffle->getOperand(Num: 0);
23120 if (Extend) {
23121 if (Vec.getOpcode() != Extend)
23122 return SDValue();
23123 Vec = Vec.getOperand(i: 0);
23124 }
23125 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: Vec);
23126 if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
23127 return SDValue();
23128
23129 int EltSize = ScalarLoad->getValueType(ResNo: 0).getScalarSizeInBits();
23130 if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
23131 !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23132 ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23133 ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
23134 return SDValue();
23135
23136 // Check that the offset between the pointers is such that the two loads
23137 // form a single contiguous load.
23138 if (InsIndex == 0) {
23139 if (!DAG.areNonVolatileConsecutiveLoads(LD: ScalarLoad, Base: VecLoad, Bytes: EltSize / 8,
23140 Dist: -1))
23141 return SDValue();
23142 } else {
23143 if (!DAG.areNonVolatileConsecutiveLoads(
23144 LD: VecLoad, Base: ScalarLoad, Bytes: VT.getVectorNumElements() * EltSize / 8, Dist: -1))
23145 return SDValue();
23146 }
23147
23148 // And that the new unaligned load will be fast.
23149 unsigned IsFast = 0;
23150 Align NewAlign = commonAlignment(A: VecLoad->getAlign(), Offset: EltSize / 8);
23151 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
23152 VT: Vec.getValueType(), AddrSpace: VecLoad->getAddressSpace(),
23153 Alignment: NewAlign, Flags: VecLoad->getMemOperand()->getFlags(),
23154 Fast: &IsFast) ||
23155 !IsFast)
23156 return SDValue();
23157
23158 // Calculate the new Ptr and create the new load.
23159 SDLoc DL(N);
23160 SDValue Ptr = ScalarLoad->getBasePtr();
23161 if (InsIndex != 0)
23162 Ptr = DAG.getNode(Opcode: ISD::ADD, DL, VT: Ptr.getValueType(), N1: VecLoad->getBasePtr(),
23163 N2: DAG.getConstant(Val: EltSize / 8, DL, VT: Ptr.getValueType()));
23164 MachinePointerInfo PtrInfo =
23165 InsIndex == 0 ? ScalarLoad->getPointerInfo()
23166 : VecLoad->getPointerInfo().getWithOffset(O: EltSize / 8);
23167
23168 SDValue Load = DAG.getLoad(VT: VecLoad->getValueType(ResNo: 0), dl: DL,
23169 Chain: ScalarLoad->getChain(), Ptr, PtrInfo, Alignment: NewAlign);
23170 DAG.makeEquivalentMemoryOrdering(OldLoad: ScalarLoad, NewMemOp: Load.getValue(R: 1));
23171 DAG.makeEquivalentMemoryOrdering(OldLoad: VecLoad, NewMemOp: Load.getValue(R: 1));
23172 return Extend ? DAG.getNode(Opcode: Extend, DL, VT, Operand: Load) : Load;
23173}
23174
23175SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
23176 SDValue InVec = N->getOperand(Num: 0);
23177 SDValue InVal = N->getOperand(Num: 1);
23178 SDValue EltNo = N->getOperand(Num: 2);
23179 SDLoc DL(N);
23180
23181 EVT VT = InVec.getValueType();
23182 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: EltNo);
23183
23184 // Insert into out-of-bounds element is undefined.
23185 if (IndexC && VT.isFixedLengthVector() &&
23186 IndexC->getZExtValue() >= VT.getVectorNumElements())
23187 return DAG.getUNDEF(VT);
23188
23189 // Remove redundant insertions:
23190 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
23191 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23192 InVec == InVal.getOperand(i: 0) && EltNo == InVal.getOperand(i: 1))
23193 return InVec;
23194
23195 if (!IndexC) {
23196 // If this is a variable insert into an undef vector, it might be better to splat:
23197 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23198 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23199 return DAG.getSplat(VT, DL, Op: InVal);
23200 return SDValue();
23201 }
23202
23203 if (VT.isScalableVector())
23204 return SDValue();
23205
23206 unsigned NumElts = VT.getVectorNumElements();
23207
23208 // We must know which element is being inserted for folds below here.
23209 unsigned Elt = IndexC->getZExtValue();
23210
23211 // Handle <1 x ???> vector insertion special cases.
23212 if (NumElts == 1) {
23213 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
23214 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23215 InVal.getOperand(i: 0).getValueType() == VT &&
23216 isNullConstant(V: InVal.getOperand(i: 1)))
23217 return InVal.getOperand(i: 0);
23218 }
23219
23220 // Canonicalize insert_vector_elt dag nodes.
23221 // Example:
23222 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
23223 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
23224 //
23225 // Do this only if the child insert_vector node has one use; also
23226 // do this only if indices are both constants and Idx1 < Idx0.
23227 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
23228 && isa<ConstantSDNode>(Val: InVec.getOperand(i: 2))) {
23229 unsigned OtherElt = InVec.getConstantOperandVal(i: 2);
23230 if (Elt < OtherElt) {
23231 // Swap nodes.
23232 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL, VT,
23233 N1: InVec.getOperand(i: 0), N2: InVal, N3: EltNo);
23234 AddToWorklist(N: NewOp.getNode());
23235 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(InVec.getNode()),
23236 VT, N1: NewOp, N2: InVec.getOperand(i: 1), N3: InVec.getOperand(i: 2));
23237 }
23238 }
23239
23240 if (SDValue Shuf = mergeInsertEltWithShuffle(N, InsIndex: Elt))
23241 return Shuf;
23242
23243 if (SDValue Shuf = combineInsertEltToShuffle(N, InsIndex: Elt))
23244 return Shuf;
23245
23246 if (SDValue Shuf = combineInsertEltToLoad(N, InsIndex: Elt))
23247 return Shuf;
23248
23249 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
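// Illustrative example (4-element vector):
//   insert_vector_elt (insert_vector_elt undef, %a, 0), %b, 1
//     --> build_vector %a, %b, undef, undef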
23250 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) {
23251 // Single-element vector - we don't need to recurse.
23252 if (NumElts == 1)
23253 return DAG.getBuildVector(VT, DL, Ops: {InVal});
23254
23255 // If we haven't already collected the element, insert into the op list.
23256 EVT MaxEltVT = InVal.getValueType();
23257 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
23258 unsigned Idx) {
23259 if (!Ops[Idx]) {
23260 Ops[Idx] = Elt;
23261 if (VT.isInteger()) {
23262 EVT EltVT = Elt.getValueType();
23263 MaxEltVT = MaxEltVT.bitsGE(VT: EltVT) ? MaxEltVT : EltVT;
23264 }
23265 }
23266 };
23267
23268 // Ensure all the operands are the same value type, fill any missing
23269 // operands with UNDEF and create the BUILD_VECTOR.
23270 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops,
23271 bool FreezeUndef = false) {
23272 assert(Ops.size() == NumElts && "Unexpected vector size");
23273 SDValue UndefOp = FreezeUndef ? DAG.getFreeze(V: DAG.getUNDEF(VT: MaxEltVT))
23274 : DAG.getUNDEF(VT: MaxEltVT);
23275 for (SDValue &Op : Ops) {
23276 if (Op)
23277 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, VT: MaxEltVT) : Op;
23278 else
23279 Op = UndefOp;
23280 }
23281 return DAG.getBuildVector(VT, DL, Ops);
23282 };
23283
23284 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
23285 Ops[Elt] = InVal;
23286
23287 // Recurse up an INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
23288 for (SDValue CurVec = InVec; CurVec;) {
23289 // UNDEF - build new BUILD_VECTOR from already inserted operands.
23290 if (CurVec.isUndef())
23291 return CanonicalizeBuildVector(Ops);
23292
23293 // FREEZE(UNDEF) - build new BUILD_VECTOR from already inserted operands.
23294 if (ISD::isFreezeUndef(N: CurVec.getNode()) && CurVec.hasOneUse())
23295 return CanonicalizeBuildVector(Ops, /*FreezeUndef=*/true);
23296
23297 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
23298 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
23299 for (unsigned I = 0; I != NumElts; ++I)
23300 AddBuildVectorOp(Ops, CurVec.getOperand(i: I), I);
23301 return CanonicalizeBuildVector(Ops);
23302 }
23303
23304 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
23305 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
23306 AddBuildVectorOp(Ops, CurVec.getOperand(i: 0), 0);
23307 return CanonicalizeBuildVector(Ops);
23308 }
23309
23310 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
23311 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
23312 if (auto *CurIdx = dyn_cast<ConstantSDNode>(Val: CurVec.getOperand(i: 2)))
23313 if (CurIdx->getAPIntValue().ult(RHS: NumElts)) {
23314 unsigned Idx = CurIdx->getZExtValue();
23315 AddBuildVectorOp(Ops, CurVec.getOperand(i: 1), Idx);
23316
23317 // Found entire BUILD_VECTOR.
23318 if (all_of(Range&: Ops, P: [](SDValue Op) { return !!Op; }))
23319 return CanonicalizeBuildVector(Ops);
23320
23321 CurVec = CurVec->getOperand(Num: 0);
23322 continue;
23323 }
23324
23325 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
23326 // update the shuffle mask (and second operand if we started with unary
23327 // shuffle) and create a new legal shuffle.
23328 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
23329 auto *SVN = cast<ShuffleVectorSDNode>(Val&: CurVec);
23330 SDValue LHS = SVN->getOperand(Num: 0);
23331 SDValue RHS = SVN->getOperand(Num: 1);
23332 SmallVector<int, 16> Mask(SVN->getMask());
23333 bool Merged = true;
23334 for (auto I : enumerate(First&: Ops)) {
23335 SDValue &Op = I.value();
23336 if (Op) {
23337 SmallVector<int, 16> NewMask;
23338 if (!mergeEltWithShuffle(X&: LHS, Y&: RHS, Mask, NewMask, Elt: Op, InsIndex: I.index())) {
23339 Merged = false;
23340 break;
23341 }
23342 Mask = std::move(NewMask);
23343 }
23344 }
23345 if (Merged)
23346 if (SDValue NewShuffle =
23347 TLI.buildLegalVectorShuffle(VT, DL, N0: LHS, N1: RHS, Mask, DAG))
23348 return NewShuffle;
23349 }
23350
23351 if (!LegalOperations) {
23352 bool IsNull = llvm::isNullConstant(V: InVal);
23353 // We can convert to AND/OR mask if all insertions are zero or -1
23354 // respectively.
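// Illustrative example: inserting i32 0 into lanes 1 and 3 of a v4i32 %v
// becomes (and %v, build_vector(-1, 0, -1, 0)).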
23355 if ((IsNull || llvm::isAllOnesConstant(V: InVal)) &&
23356 all_of(Range&: Ops, P: [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
23357 count_if(Range&: Ops, P: [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
23358 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: MaxEltVT);
23359 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: MaxEltVT);
23360 SmallVector<SDValue, 8> Mask(NumElts);
23361
23362 // Build the mask and return the corresponding DAG node.
23363 auto BuildMaskAndNode = [&](SDValue TrueVal, SDValue FalseVal,
23364 unsigned MaskOpcode) {
23365 for (unsigned I = 0; I != NumElts; ++I)
23366 Mask[I] = Ops[I] ? TrueVal : FalseVal;
23367 return DAG.getNode(Opcode: MaskOpcode, DL, VT, N1: CurVec,
23368 N2: DAG.getBuildVector(VT, DL, Ops: Mask));
23369 };
23370
23371 // If all elements are zero, we can use AND with all ones.
23372 if (IsNull)
23373 return BuildMaskAndNode(Zero, AllOnes, ISD::AND);
23374
23375 // If all elements are -1, we can use OR with zero.
23376 return BuildMaskAndNode(AllOnes, Zero, ISD::OR);
23377 }
23378 }
23379
23380 // Failed to find a match in the chain - bail.
23381 break;
23382 }
23383
23384 // See if we can fill in the missing constant elements as zeros.
23385 // TODO: Should we do this for any constant?
23386 APInt DemandedZeroElts = APInt::getZero(numBits: NumElts);
23387 for (unsigned I = 0; I != NumElts; ++I)
23388 if (!Ops[I])
23389 DemandedZeroElts.setBit(I);
23390
23391 if (DAG.MaskedVectorIsZero(Op: InVec, DemandedElts: DemandedZeroElts)) {
23392 SDValue Zero = VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT: MaxEltVT)
23393 : DAG.getConstantFP(Val: 0, DL, VT: MaxEltVT);
23394 for (unsigned I = 0; I != NumElts; ++I)
23395 if (!Ops[I])
23396 Ops[I] = Zero;
23397
23398 return CanonicalizeBuildVector(Ops);
23399 }
23400 }
23401
23402 return SDValue();
23403}
23404
23405/// Transform a vector binary operation into a scalar binary operation by moving
23406/// the math/logic after an extract element of a vector.
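/// Illustrative example:
///   extract_vector_elt (add <4 x i32> %x, <1, 2, 3, 4>), 1
///     --> add (extract_vector_elt %x, 1), 2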
23407static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
23408 const SDLoc &DL, bool LegalTypes) {
23409 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23410 SDValue Vec = ExtElt->getOperand(Num: 0);
23411 SDValue Index = ExtElt->getOperand(Num: 1);
23412 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
23413 unsigned Opc = Vec.getOpcode();
23414 if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opcode: Opc) && Opc != ISD::SETCC) ||
23415 Vec->getNumValues() != 1)
23416 return SDValue();
23417
23418 // Targets may want to avoid this to prevent an expensive register transfer.
23419 if (!TLI.shouldScalarizeBinop(VecOp: Vec))
23420 return SDValue();
23421
23422 EVT ResVT = ExtElt->getValueType(ResNo: 0);
23423 if (Opc == ISD::SETCC &&
23424 (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
23425 return SDValue();
23426
23427 // Extracting an element of a vector constant is constant-folded, so this
23428 // transform is just replacing a vector op with a scalar op while moving the
23429 // extract.
23430 SDValue Op0 = Vec.getOperand(i: 0);
23431 SDValue Op1 = Vec.getOperand(i: 1);
23432 APInt SplatVal;
23433 if (!isAnyConstantBuildVector(V: Op0, NoOpaques: true) &&
23434 !ISD::isConstantSplatVector(N: Op0.getNode(), SplatValue&: SplatVal) &&
23435 !isAnyConstantBuildVector(V: Op1, NoOpaques: true) &&
23436 !ISD::isConstantSplatVector(N: Op1.getNode(), SplatValue&: SplatVal))
23437 return SDValue();
23438
23439 // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
23440 // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
23441 if (Opc == ISD::SETCC) {
23442 EVT OpVT = Op0.getValueType().getVectorElementType();
23443 Op0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: OpVT, N1: Op0, N2: Index);
23444 Op1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: OpVT, N1: Op1, N2: Index);
23445 SDValue NewVal = DAG.getSetCC(
23446 DL, VT: ResVT, LHS: Op0, RHS: Op1, Cond: cast<CondCodeSDNode>(Val: Vec->getOperand(Num: 2))->get());
23447 // We may need to sign- or zero-extend the result to match the same
23448 // behaviour as the vector version of SETCC.
23449 unsigned VecBoolContents = TLI.getBooleanContents(Type: Vec.getValueType());
23450 if (ResVT != MVT::i1 &&
23451 VecBoolContents != TargetLowering::UndefinedBooleanContent &&
23452 VecBoolContents != TLI.getBooleanContents(Type: ResVT)) {
23453 if (VecBoolContents == TargetLowering::ZeroOrNegativeOneBooleanContent)
23454 NewVal = DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: ResVT, N1: NewVal,
23455 N2: DAG.getValueType(MVT::i1));
23456 else
23457 NewVal = DAG.getZeroExtendInReg(Op: NewVal, DL, VT: MVT::i1);
23458 }
23459 return NewVal;
23460 }
23461 Op0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ResVT, N1: Op0, N2: Index);
23462 Op1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ResVT, N1: Op1, N2: Index);
23463 return DAG.getNode(Opcode: Opc, DL, VT: ResVT, N1: Op0, N2: Op1);
23464}
23465
23466 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
23467 // recursively analyse all of its users and try to model them as bit sequence
23468 // extractions. If all of them agree on the new, narrower element type, and all
23469 // of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that new element
23470 // type, do so now.
23471 // This is mainly useful to recover from legalization that scalarized the
23472 // vector with wide elements; it rebuilds the extraction with narrower elements.
23473//
23474// Some more nodes could be modelled if that helps cover interesting patterns.
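// Illustrative example (little-endian, with the truncated values feeding
// build_vectors): given
//   %e = extract_vector_elt <2 x i64> %v, 0
//   %a = trunc %e to i32
//   %b = trunc (srl %e, 32) to i32
// both %a and %b can be rebuilt as extracts of lanes 0 and 1 of
// (bitcast %v to <4 x i32>).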
23475bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
23476 SDNode *N) {
23477 // We perform this optimization post type-legalization because
23478 // the type-legalizer often scalarizes integer-promoted vectors.
23479 // Performing this optimization earlier may cause legalization cycles.
23480 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
23481 return false;
23482
23483 // TODO: Add support for big-endian.
23484 if (DAG.getDataLayout().isBigEndian())
23485 return false;
23486
23487 SDValue VecOp = N->getOperand(Num: 0);
23488 EVT VecVT = VecOp.getValueType();
23489 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
23490
23491 // We must start with a constant extraction index.
23492 auto *IndexC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
23493 if (!IndexC)
23494 return false;
23495
23496 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
23497 "Original ISD::EXTRACT_VECTOR_ELT is undefinend?");
23498
23499 // TODO: deal with the case of implicit anyext of the extraction.
23500 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
23501 EVT ScalarVT = N->getValueType(ResNo: 0);
23502 if (VecVT.getScalarType() != ScalarVT)
23503 return false;
23504
23505 // TODO: deal with the cases other than everything being integer-typed.
23506 if (!ScalarVT.isScalarInteger())
23507 return false;
23508
23509 struct Entry {
23510 SDNode *Producer;
23511
23512 // Which bits of VecOp does it contain?
23513 unsigned BitPos;
23514 int NumBits;
23515 // NOTE: the actual width of \p Producer may be wider than NumBits!
23516
23517 Entry(Entry &&) = default;
23518 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
23519 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
23520
23521 Entry() = delete;
23522 Entry(const Entry &) = delete;
23523 Entry &operator=(const Entry &) = delete;
23524 Entry &operator=(Entry &&) = delete;
23525 };
23526 SmallVector<Entry, 32> Worklist;
23527 SmallVector<Entry, 32> Leafs;
23528
23529 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
23530 Worklist.emplace_back(Args&: N, /*BitPos=*/Args: VecEltBitWidth * IndexC->getZExtValue(),
23531 /*NumBits=*/Args&: VecEltBitWidth);
23532
23533 while (!Worklist.empty()) {
23534 Entry E = Worklist.pop_back_val();
23535 // Does the node not even use any of the VecOp bits?
23536 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
23537 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
23538 return false; // Let's allow the other combines to clean this up first.
23539 // Did we fail to model any of the users of the Producer?
23540 bool ProducerIsLeaf = false;
23541 // Look at each user of this Producer.
23542 for (SDNode *User : E.Producer->users()) {
23543 switch (User->getOpcode()) {
23544 // TODO: support ISD::BITCAST
23545 // TODO: support ISD::ANY_EXTEND
23546 // TODO: support ISD::ZERO_EXTEND
23547 // TODO: support ISD::SIGN_EXTEND
23548 case ISD::TRUNCATE:
23549 // Truncation simply means we keep the position, but extract fewer bits.
23550 Worklist.emplace_back(Args&: User, Args&: E.BitPos,
23551 /*NumBits=*/Args: User->getValueSizeInBits(ResNo: 0));
23552 break;
23553 // TODO: support ISD::SRA
23554 // TODO: support ISD::SHL
23555 case ISD::SRL:
23556 // We should be shifting the Producer by a constant amount.
23557 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
23558 User->getOperand(Num: 0).getNode() == E.Producer && ShAmtC) {
23559 // Logical right-shift means that we start extraction later,
23560 // but stop it at the same position we did previously.
23561 unsigned ShAmt = ShAmtC->getZExtValue();
23562 Worklist.emplace_back(Args&: User, Args: E.BitPos + ShAmt, Args: E.NumBits - ShAmt);
23563 break;
23564 }
23565 [[fallthrough]];
23566 default:
23567 // We can not model this user of the Producer.
23568 // Which means the current Producer will be a ISD::EXTRACT_VECTOR_ELT.
23569 ProducerIsLeaf = true;
23570 // Profitability check: all users that we can not model
23571 // must be ISD::BUILD_VECTOR's.
23572 if (User->getOpcode() != ISD::BUILD_VECTOR)
23573 return false;
23574 break;
23575 }
23576 }
23577 if (ProducerIsLeaf)
23578 Leafs.emplace_back(Args: std::move(E));
23579 }
23580
23581 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
23582
23583 // If we are still at the same element granularity, give up.
23584 if (NewVecEltBitWidth == VecEltBitWidth)
23585 return false;
23586
23587 // The vector width must be a multiple of the new element width.
23588 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
23589 return false;
23590
23591 // All leafs must agree on the new element width.
23592 // All leafs must not expect any "padding" bits on top of that width.
23593 // All leafs must start extraction at a multiple of that width.
23594 if (!all_of(Range&: Leafs, P: [NewVecEltBitWidth](const Entry &E) {
23595 return (unsigned)E.NumBits == NewVecEltBitWidth &&
23596 E.Producer->getValueSizeInBits(ResNo: 0) == NewVecEltBitWidth &&
23597 E.BitPos % NewVecEltBitWidth == 0;
23598 }))
23599 return false;
23600
23601 EVT NewScalarVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewVecEltBitWidth);
23602 EVT NewVecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarVT,
23603 NumElements: VecVT.getSizeInBits() / NewVecEltBitWidth);
23604
23605 if (LegalTypes &&
23606 !(TLI.isTypeLegal(VT: NewScalarVT) && TLI.isTypeLegal(VT: NewVecVT)))
23607 return false;
23608
23609 if (LegalOperations &&
23610 !(TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: NewVecVT) &&
23611 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: NewVecVT)))
23612 return false;
23613
23614 SDValue NewVecOp = DAG.getBitcast(VT: NewVecVT, V: VecOp);
23615 for (const Entry &E : Leafs) {
23616 SDLoc DL(E.Producer);
23617 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
23618 assert(NewIndex < NewVecVT.getVectorNumElements() &&
23619 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
23620 SDValue V = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: NewScalarVT, N1: NewVecOp,
23621 N2: DAG.getVectorIdxConstant(Val: NewIndex, DL));
23622 CombineTo(N: E.Producer, Res: V);
23623 }
23624
23625 return true;
23626}
23627
23628SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
23629 SDValue VecOp = N->getOperand(Num: 0);
23630 SDValue Index = N->getOperand(Num: 1);
23631 EVT ScalarVT = N->getValueType(ResNo: 0);
23632 EVT VecVT = VecOp.getValueType();
23633 if (VecOp.isUndef())
23634 return DAG.getUNDEF(VT: ScalarVT);
23635
23636 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
23637 //
23638 // This only really matters if the index is non-constant since other combines
23639 // on the constant elements already work.
23640 SDLoc DL(N);
23641 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
23642 Index == VecOp.getOperand(i: 2)) {
23643 SDValue Elt = VecOp.getOperand(i: 1);
23644 AddUsersToWorklist(N: VecOp.getNode());
23645 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Op: Elt, DL, VT: ScalarVT) : Elt;
23646 }
23647
23648 // (vextract (scalar_to_vector val, 0) -> val
23649 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
23650 // Only 0'th element of SCALAR_TO_VECTOR is defined.
23651 if (DAG.isKnownNeverZero(Op: Index))
23652 return DAG.getUNDEF(VT: ScalarVT);
23653
23654 // Check if the result type doesn't match the inserted element type.
23655 // The inserted element and extracted element may have mismatched bitwidth.
23656 // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted value.
23657 SDValue InOp = VecOp.getOperand(i: 0);
23658 if (InOp.getValueType() != ScalarVT) {
23659 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
23660 if (InOp.getValueType().bitsGT(VT: ScalarVT))
23661 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ScalarVT, Operand: InOp);
23662 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: ScalarVT, Operand: InOp);
23663 }
23664 return InOp;
23665 }
23666
23667 // extract_vector_elt of out-of-bounds element -> UNDEF
23668 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
23669 if (IndexC && VecVT.isFixedLengthVector() &&
23670 IndexC->getAPIntValue().uge(RHS: VecVT.getVectorNumElements()))
23671 return DAG.getUNDEF(VT: ScalarVT);
23672
23673 // extract_vector_elt (build_vector x, y), 1 -> y
23674 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
23675 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
23676 TLI.isTypeLegal(VT: VecVT)) {
23677 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
23678 VecVT.isFixedLengthVector()) &&
23679 "BUILD_VECTOR used for scalable vectors");
23680 unsigned IndexVal =
23681 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
23682 SDValue Elt = VecOp.getOperand(i: IndexVal);
23683 EVT InEltVT = Elt.getValueType();
23684
23685 if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
23686 isNullConstant(V: Elt)) {
23687 // Sometimes build_vector's scalar input types do not match result type.
23688 if (ScalarVT == InEltVT)
23689 return Elt;
23690
23691 // TODO: It may be useful to truncate if free if the build_vector
23692 // implicitly converts.
23693 }
23694 }
23695
23696 if (SDValue BO = scalarizeExtractedBinOp(ExtElt: N, DAG, DL, LegalTypes))
23697 return BO;
23698
23699 if (VecVT.isScalableVector())
23700 return SDValue();
23701
23702 // All the code from this point onwards assumes fixed width vectors, but it's
23703 // possible that some of the combinations could be made to work for scalable
23704 // vectors too.
23705 unsigned NumElts = VecVT.getVectorNumElements();
23706 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
23707
23708 // See if the extracted element is constant, in which case fold it if it's
23709 // a legal fp immediate.
23710 if (IndexC && ScalarVT.isFloatingPoint()) {
23711 APInt EltMask = APInt::getOneBitSet(numBits: NumElts, BitNo: IndexC->getZExtValue());
23712 KnownBits KnownElt = DAG.computeKnownBits(Op: VecOp, DemandedElts: EltMask);
23713 if (KnownElt.isConstant()) {
23714 APFloat CstFP =
23715 APFloat(ScalarVT.getFltSemantics(), KnownElt.getConstant());
23716 if (TLI.isFPImmLegal(CstFP, ScalarVT))
23717 return DAG.getConstantFP(Val: CstFP, DL, VT: ScalarVT);
23718 }
23719 }
23720
23721 // TODO: These transforms should not require the 'hasOneUse' restriction, but
23722 // there are regressions on multiple targets without it. We can end up with a
23723 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
23724 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
23725 VecOp.hasOneUse()) {
23726 // The vector index of the LSBs of the source depends on the endianness.
23727 bool IsLE = DAG.getDataLayout().isLittleEndian();
23728 unsigned ExtractIndex = IndexC->getZExtValue();
23729 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
23730 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
23731 SDValue BCSrc = VecOp.getOperand(i: 0);
23732 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
23733 return DAG.getAnyExtOrTrunc(Op: BCSrc, DL, VT: ScalarVT);
23734
23735 // TODO: Add support for SCALAR_TO_VECTOR implicit truncation.
23736 if (LegalTypes && BCSrc.getValueType().isInteger() &&
23737 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23738 BCSrc.getScalarValueSizeInBits() ==
23739 BCSrc.getOperand(i: 0).getScalarValueSizeInBits()) {
23740 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
23741 // trunc i64 X to i32
23742 SDValue X = BCSrc.getOperand(i: 0);
23743 EVT XVT = X.getValueType();
23744 assert(XVT.isScalarInteger() && ScalarVT.isScalarInteger() &&
23745 "Extract element and scalar to vector can't change element type "
23746 "from FP to integer.");
23747 unsigned XBitWidth = X.getValueSizeInBits();
23748 unsigned Scale = XBitWidth / VecEltBitWidth;
23749 BCTruncElt = IsLE ? 0 : Scale - 1;
23750
23751 // An extract element return value type can be wider than its vector
23752 // operand element type. In that case, the high bits are undefined, so
23753 // it's possible that we may need to extend rather than truncate.
23754 if (ExtractIndex < Scale && XBitWidth > VecEltBitWidth) {
23755 assert(XBitWidth % VecEltBitWidth == 0 &&
23756 "Scalar bitwidth must be a multiple of vector element bitwidth");
23757
23758 if (ExtractIndex != BCTruncElt) {
23759 unsigned ShiftIndex =
23760 IsLE ? ExtractIndex : (Scale - 1) - ExtractIndex;
23761 X = DAG.getNode(
23762 Opcode: ISD::SRL, DL, VT: XVT, N1: X,
23763 N2: DAG.getShiftAmountConstant(Val: ShiftIndex * VecEltBitWidth, VT: XVT, DL));
23764 }
23765
23766 return DAG.getAnyExtOrTrunc(Op: X, DL, VT: ScalarVT);
23767 }
23768 }
23769 }
23770
23771 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
23772 // We only perform this optimization before the op legalization phase because
23773 // we may introduce new vector instructions which are not backed by TD
23774 // patterns, for example extracting elements from a wide vector on AVX
23775 // without using extract_subvector. However, if we can find an underlying
23776 // scalar value, then we can always use that.
23777 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
23778 auto *Shuf = cast<ShuffleVectorSDNode>(Val&: VecOp);
23779 // Find the new index to extract from.
23780 int OrigElt = Shuf->getMaskElt(Idx: IndexC->getZExtValue());
23781
23782 // Extracting an undef index is undef.
23783 if (OrigElt == -1)
23784 return DAG.getUNDEF(VT: ScalarVT);
23785
23786 // Select the right vector half to extract from.
23787 SDValue SVInVec;
23788 if (OrigElt < (int)NumElts) {
23789 SVInVec = VecOp.getOperand(i: 0);
23790 } else {
23791 SVInVec = VecOp.getOperand(i: 1);
23792 OrigElt -= NumElts;
23793 }
23794
23795 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
23796 // TODO: Check if shuffle mask is legal?
23797 if (LegalOperations && TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: VecVT) &&
23798 !VecOp.hasOneUse())
23799 return SDValue();
23800
23801 SDValue InOp = SVInVec.getOperand(i: OrigElt);
23802 if (InOp.getValueType() != ScalarVT) {
23803 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
23804 InOp = DAG.getSExtOrTrunc(Op: InOp, DL, VT: ScalarVT);
23805 }
23806
23807 return InOp;
23808 }
23809
23810 // FIXME: We should handle recursing on other vector shuffles and
23811 // scalar_to_vector here as well.
23812
23813 if (!LegalOperations ||
23814 // FIXME: Should really be just isOperationLegalOrCustom.
23815 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecVT) ||
23816 TLI.isOperationExpand(Op: ISD::VECTOR_SHUFFLE, VT: VecVT)) {
23817 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: SVInVec,
23818 N2: DAG.getVectorIdxConstant(Val: OrigElt, DL));
23819 }
23820 }
23821
23822 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
23823 // simplify it based on the (valid) extraction indices.
23824 if (llvm::all_of(Range: VecOp->users(), P: [&](SDNode *Use) {
23825 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23826 Use->getOperand(Num: 0) == VecOp &&
23827 isa<ConstantSDNode>(Val: Use->getOperand(Num: 1));
23828 })) {
23829 APInt DemandedElts = APInt::getZero(numBits: NumElts);
23830 for (SDNode *User : VecOp->users()) {
23831 auto *CstElt = cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
23832 if (CstElt->getAPIntValue().ult(RHS: NumElts))
23833 DemandedElts.setBit(CstElt->getZExtValue());
23834 }
23835 if (SimplifyDemandedVectorElts(Op: VecOp, DemandedElts, AssumeSingleUse: true)) {
23836 // We simplified the vector operand of this extract element. If this
23837 // extract is not dead, visit it again so it is folded properly.
23838 if (N->getOpcode() != ISD::DELETED_NODE)
23839 AddToWorklist(N);
23840 return SDValue(N, 0);
23841 }
23842 APInt DemandedBits = APInt::getAllOnes(numBits: VecEltBitWidth);
23843 if (SimplifyDemandedBits(Op: VecOp, DemandedBits, DemandedElts, AssumeSingleUse: true)) {
23844 // We simplified the vector operand of this extract element. If this
23845 // extract is not dead, visit it again so it is folded properly.
23846 if (N->getOpcode() != ISD::DELETED_NODE)
23847 AddToWorklist(N);
23848 return SDValue(N, 0);
23849 }
23850 }
23851
23852 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
23853 return SDValue(N, 0);
23854
23855 // Everything under here is trying to match an extract of a loaded value.
23856 // If the result of the load has to be truncated, then it's not necessarily
23857 // profitable.
23858 bool BCNumEltsChanged = false;
23859 EVT ExtVT = VecVT.getVectorElementType();
23860 EVT LVT = ExtVT;
23861 if (ScalarVT.bitsLT(VT: LVT) && !TLI.isTruncateFree(FromVT: LVT, ToVT: ScalarVT))
23862 return SDValue();
23863
23864 if (VecOp.getOpcode() == ISD::BITCAST) {
23865 // Don't duplicate a load with other uses.
23866 if (!VecOp.hasOneUse())
23867 return SDValue();
23868
23869 EVT BCVT = VecOp.getOperand(i: 0).getValueType();
23870 if (!BCVT.isVector() || ExtVT.bitsGT(VT: BCVT.getVectorElementType()))
23871 return SDValue();
23872 if (NumElts != BCVT.getVectorNumElements())
23873 BCNumEltsChanged = true;
23874 VecOp = VecOp.getOperand(i: 0);
23875 ExtVT = BCVT.getVectorElementType();
23876 }
23877
23878 // extract (vector load $addr), i --> load $addr + i * size
23879 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
23880 ISD::isNormalLoad(N: VecOp.getNode()) &&
23881 !Index->hasPredecessor(N: VecOp.getNode())) {
23882 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: VecOp);
23883 if (VecLoad && VecLoad->isSimple()) {
23884 if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
23885 ResultVT: ScalarVT, DL: SDLoc(N), InVecVT: VecVT, EltNo: Index, OriginalLoad: VecLoad, DAG)) {
23886 ++OpsNarrowed;
23887 return Scalarized;
23888 }
23889 }
23890 }
23891
23892 // Perform only after legalization to ensure build_vector / vector_shuffle
23893 // optimizations have already been done.
23894 if (!LegalOperations || !IndexC)
23895 return SDValue();
23896
23897 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
23898 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
23899 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
23900 int Elt = IndexC->getZExtValue();
23901 LoadSDNode *LN0 = nullptr;
23902 if (ISD::isNormalLoad(N: VecOp.getNode())) {
23903 LN0 = cast<LoadSDNode>(Val&: VecOp);
23904 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23905 VecOp.getOperand(i: 0).getValueType() == ExtVT &&
23906 ISD::isNormalLoad(N: VecOp.getOperand(i: 0).getNode())) {
23907 // Don't duplicate a load with other uses.
23908 if (!VecOp.hasOneUse())
23909 return SDValue();
23910
23911 LN0 = cast<LoadSDNode>(Val: VecOp.getOperand(i: 0));
23912 }
23913 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(Val&: VecOp)) {
23914 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
23915 // =>
23916 // (load $addr+1*size)
23917
23918 // Don't duplicate a load with other uses.
23919 if (!VecOp.hasOneUse())
23920 return SDValue();
23921
23922 // If the bit convert changed the number of elements, it is unsafe
23923 // to examine the mask.
23924 if (BCNumEltsChanged)
23925 return SDValue();
23926
23927 // Select the input vector, guarding against an out-of-range extract index.
23928 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Idx: Elt);
23929 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(i: 0) : VecOp.getOperand(i: 1);
23930
23931 if (VecOp.getOpcode() == ISD::BITCAST) {
23932 // Don't duplicate a load with other uses.
23933 if (!VecOp.hasOneUse())
23934 return SDValue();
23935
23936 VecOp = VecOp.getOperand(i: 0);
23937 }
23938 if (ISD::isNormalLoad(N: VecOp.getNode())) {
23939 LN0 = cast<LoadSDNode>(Val&: VecOp);
23940 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
23941 Index = DAG.getConstant(Val: Elt, DL, VT: Index.getValueType());
23942 }
23943 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
23944 VecVT.getVectorElementType() == ScalarVT &&
23945 (!LegalTypes ||
23946 TLI.isTypeLegal(
23947 VT: VecOp.getOperand(i: 0).getValueType().getVectorElementType()))) {
23948 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
23949 // -> extract_vector_elt a, 0
23950 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
23951 // -> extract_vector_elt a, 1
23952 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
23953 // -> extract_vector_elt b, 0
23954 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
23955 // -> extract_vector_elt b, 1
23956 EVT ConcatVT = VecOp.getOperand(i: 0).getValueType();
23957 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
23958 SDValue NewIdx = DAG.getConstant(Val: Elt % ConcatNumElts, DL,
23959 VT: Index.getValueType());
23960
23961 SDValue ConcatOp = VecOp.getOperand(i: Elt / ConcatNumElts);
23962 SDValue Elt = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL,
23963 VT: ConcatVT.getVectorElementType(),
23964 N1: ConcatOp, N2: NewIdx);
23965 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: ScalarVT, Operand: Elt);
23966 }
23967
23968 // Make sure we found a non-volatile load and the extractelement is
23969 // the only use.
23970 if (!LN0 || !LN0->hasNUsesOfValue(NUses: 1,Value: 0) || !LN0->isSimple())
23971 return SDValue();
23972
23973 // If Idx was -1 above, Elt is going to be -1, so just return undef.
23974 if (Elt == -1)
23975 return DAG.getUNDEF(VT: LVT);
23976
23977 if (SDValue Scalarized =
23978 TLI.scalarizeExtractedVectorLoad(ResultVT: LVT, DL, InVecVT: VecVT, EltNo: Index, OriginalLoad: LN0, DAG)) {
23979 ++OpsNarrowed;
23980 return Scalarized;
23981 }
23982
23983 return SDValue();
23984}
23985
23986// Simplify (build_vec (ext )) to (bitcast (build_vec ))
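// Illustrative example (little-endian, v2i16 result):
//   build_vector (zext i8 %a to i16), (zext i8 %b to i16)
//     --> bitcast (build_vector i8 %a, i8 0, i8 %b, i8 0) to v2i16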
23987SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
23988 // We perform this optimization post type-legalization because
23989 // the type-legalizer often scalarizes integer-promoted vectors.
23990 // Performing this optimization earlier may create bit-casts which
23991 // will be type-legalized to complex code sequences.
23992 // We perform this optimization only before the operation legalizer because we
23993 // may introduce illegal operations.
23994 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
23995 return SDValue();
23996
23997 unsigned NumInScalars = N->getNumOperands();
23998 SDLoc DL(N);
23999 EVT VT = N->getValueType(ResNo: 0);
24000
24001 // Check to see if this is a BUILD_VECTOR of a bunch of values
24002 // which come from any_extend or zero_extend nodes. If so, we can create
24003 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
24004 // optimizations. We do not handle sign-extend because we can't fill the sign
24005 // using shuffles.
24006 EVT SourceType = MVT::Other;
24007 bool AllAnyExt = true;
24008
24009 for (unsigned i = 0; i != NumInScalars; ++i) {
24010 SDValue In = N->getOperand(Num: i);
24011 // Ignore undef inputs.
24012 if (In.isUndef()) continue;
24013
24014 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
24015 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
24016
24017 // Abort if the element is not an extension.
24018 if (!ZeroExt && !AnyExt) {
24019 SourceType = MVT::Other;
24020 break;
24021 }
24022
24023 // The input is a ZeroExt or AnyExt. Check the original type.
24024 EVT InTy = In.getOperand(i: 0).getValueType();
24025
24026 // Check that all of the widened source types are the same.
24027 if (SourceType == MVT::Other)
24028 // First time.
24029 SourceType = InTy;
24030 else if (InTy != SourceType) {
24031 // Multiple input types. Abort.
24032 SourceType = MVT::Other;
24033 break;
24034 }
24035
24036 // Check if all of the extends are ANY_EXTENDs.
24037 AllAnyExt &= AnyExt;
24038 }
24039
24040 // In order to have valid types, all of the inputs must be extended from the
24041 // same source type and all of the inputs must be any or zero extend.
24042 // Scalar sizes must be a power of two.
24043 EVT OutScalarTy = VT.getScalarType();
24044 bool ValidTypes =
24045 SourceType != MVT::Other &&
24046 llvm::has_single_bit<uint32_t>(Value: OutScalarTy.getSizeInBits()) &&
24047 llvm::has_single_bit<uint32_t>(Value: SourceType.getSizeInBits());
24048
24049 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
24050 // turn into a single shuffle instruction.
24051 if (!ValidTypes)
24052 return SDValue();
24053
24054 // If we already have a splat buildvector, then don't fold it if it means
24055 // introducing zeros.
24056 if (!AllAnyExt && DAG.isSplatValue(V: SDValue(N, 0), /*AllowUndefs*/ true))
24057 return SDValue();
24058
24059 bool isLE = DAG.getDataLayout().isLittleEndian();
24060 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
24061 assert(ElemRatio > 1 && "Invalid element size ratio");
24062 SDValue Filler = AllAnyExt ? DAG.getUNDEF(VT: SourceType):
24063 DAG.getConstant(Val: 0, DL, VT: SourceType);
24064
24065 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
24066 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
24067
24068 // Populate the new build_vector
24069 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24070 SDValue Cast = N->getOperand(Num: i);
24071 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
24072 Cast.getOpcode() == ISD::ZERO_EXTEND ||
24073 Cast.isUndef()) && "Invalid cast opcode");
24074 SDValue In;
24075 if (Cast.isUndef())
24076 In = DAG.getUNDEF(VT: SourceType);
24077 else
24078 In = Cast->getOperand(Num: 0);
24079 unsigned Index = isLE ? (i * ElemRatio) :
24080 (i * ElemRatio + (ElemRatio - 1));
24081
24082 assert(Index < Ops.size() && "Invalid index");
24083 Ops[Index] = In;
24084 }
24085
24086 // The type of the new BUILD_VECTOR node.
24087 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SourceType, NumElements: NewBVElems);
24088 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
24089 "Invalid vector size");
24090 // Check if the new vector type is legal.
24091 if (!isTypeLegal(VT: VecVT) ||
24092 (!TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: VecVT) &&
24093 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)))
24094 return SDValue();
24095
24096 // Make the new BUILD_VECTOR.
24097 SDValue BV = DAG.getBuildVector(VT: VecVT, DL, Ops);
24098
24099 // The new BUILD_VECTOR node has the potential to be further optimized.
24100 AddToWorklist(N: BV.getNode());
24101 // Bitcast to the desired type.
24102 return DAG.getBitcast(VT, V: BV);
24103}
24104
24105// Simplify (build_vec (trunc $1)
24106// (trunc (srl $1 half-width))
24107// (trunc (srl $1 (2 * half-width))))
24108// to (bitcast $1)
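// Illustrative example (little-endian, $1 : i64, v2i32 result):
//   build_vector (trunc $1), (trunc (srl $1 32)) --> bitcast $1 to v2i32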
24109SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
24110 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24111
24112 EVT VT = N->getValueType(ResNo: 0);
24113
24114 // Don't run this before LegalizeTypes if VT is legal.
24115 // Targets may have other preferences.
24116 if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
24117 return SDValue();
24118
24119 // Only for little endian
24120 if (!DAG.getDataLayout().isLittleEndian())
24121 return SDValue();
24122
24123 EVT OutScalarTy = VT.getScalarType();
24124 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
24125
24126 // Only for power-of-two types, to be sure that the bitcast works well.
24127 if (!isPowerOf2_64(Value: ScalarTypeBitsize))
24128 return SDValue();
24129
24130 unsigned NumInScalars = N->getNumOperands();
24131
24132 // Look through bitcasts
24133 auto PeekThroughBitcast = [](SDValue Op) {
24134 if (Op.getOpcode() == ISD::BITCAST)
24135 return Op.getOperand(i: 0);
24136 return Op;
24137 };
24138
24139 // The source value where all the parts are extracted.
24140 SDValue Src;
24141 for (unsigned i = 0; i != NumInScalars; ++i) {
24142 SDValue In = PeekThroughBitcast(N->getOperand(Num: i));
24143 // Ignore undef inputs.
24144 if (In.isUndef()) continue;
24145
24146 if (In.getOpcode() != ISD::TRUNCATE)
24147 return SDValue();
24148
24149 In = PeekThroughBitcast(In.getOperand(i: 0));
24150
24151 if (In.getOpcode() != ISD::SRL) {
24152 // For now only handle build_vec without shuffling; handle shifts here in
24153 // the future.
24154 if (i != 0)
24155 return SDValue();
24156
24157 Src = In;
24158 } else {
24159 // In is SRL
24160 SDValue part = PeekThroughBitcast(In.getOperand(i: 0));
24161
24162 if (!Src) {
24163 Src = part;
24164 } else if (Src != part) {
24165 // Vector parts do not stem from the same variable
24166 return SDValue();
24167 }
24168
24169 SDValue ShiftAmtVal = In.getOperand(i: 1);
24170 if (!isa<ConstantSDNode>(Val: ShiftAmtVal))
24171 return SDValue();
24172
24173 uint64_t ShiftAmt = In.getConstantOperandVal(i: 1);
24174
24175 // The extracted value is not extracted at the right position
24176 if (ShiftAmt != i * ScalarTypeBitsize)
24177 return SDValue();
24178 }
24179 }
24180
24181 // Only cast if the size is the same
24182 if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
24183 return SDValue();
24184
24185 return DAG.getBitcast(VT, V: Src);
24186}
24187
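// Roughly: try to build a legal shuffle of VecIn1/VecIn2 that produces the
// elements of the BUILD_VECTOR N whose VectorMask entry is LeftIdx or
// LeftIdx + 1, extracting a subvector of the shuffle result when the inputs
// are wider than the output type.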
24188SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
24189 ArrayRef<int> VectorMask,
24190 SDValue VecIn1, SDValue VecIn2,
24191 unsigned LeftIdx, bool DidSplitVec) {
24192 EVT VT = N->getValueType(ResNo: 0);
24193 EVT InVT1 = VecIn1.getValueType();
24194 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
24195
24196 unsigned NumElems = VT.getVectorNumElements();
24197 unsigned ShuffleNumElems = NumElems;
24198
24199 // If we artificially split a vector in two already, then the offsets in the
24200 // operands will all be based off of VecIn1, even those in VecIn2.
24201 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
24202
24203 uint64_t VTSize = VT.getFixedSizeInBits();
24204 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
24205 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
24206
24207 assert(InVT2Size <= InVT1Size &&
24208 "Inputs must be sorted to be in non-increasing vector size order.");
24209
24210 // We can't generate a shuffle node with mismatched input and output types.
24211 // Try to make the types match the type of the output.
24212 if (InVT1 != VT || InVT2 != VT) {
24213 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
24214 // If the output vector length is a multiple of both input lengths,
24215 // we can concatenate them and pad the rest with undefs.
24216 unsigned NumConcats = VTSize / InVT1Size;
24217 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
24218 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(VT: InVT1));
24219 ConcatOps[0] = VecIn1;
24220 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(VT: InVT1);
24221 VecIn1 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
24222 VecIn2 = SDValue();
24223 } else if (InVT1Size == VTSize * 2) {
24224 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems))
24225 return SDValue();
24226
24227 if (!VecIn2.getNode()) {
24228 // If we only have one input vector, and it's twice the size of the
24229 // output, split it in two.
24230 VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1,
24231 N2: DAG.getVectorIdxConstant(Val: NumElems, DL));
24232 VecIn1 = DAG.getExtractSubvector(DL, VT, Vec: VecIn1, Idx: 0);
24233 // Since we now have shorter input vectors, adjust the offset of the
24234 // second vector's start.
24235 Vec2Offset = NumElems;
24236 } else {
24237 assert(InVT2Size <= InVT1Size &&
24238 "Second input is not going to be larger than the first one.");
24239
24240 // VecIn1 is wider than the output, and we have another, possibly
24241 // smaller input. Pad the smaller input with undefs, shuffle at the
24242 // input vector width, and extract the output.
24243 // The shuffle type is different than VT, so check legality again.
24244 if (LegalOperations &&
24245 !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
24246 return SDValue();
24247
24248 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
24249 // lower it back into a BUILD_VECTOR. So if the inserted type is
24250 // illegal, don't even try.
24251 if (InVT1 != InVT2) {
24252 if (!TLI.isTypeLegal(VT: InVT2))
24253 return SDValue();
24254 VecIn2 = DAG.getInsertSubvector(DL, Vec: DAG.getUNDEF(VT: InVT1), SubVec: VecIn2, Idx: 0);
24255 }
24256 ShuffleNumElems = NumElems * 2;
24257 }
24258 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
24259 SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(VT: InVT2));
24260 ConcatOps[0] = VecIn2;
24261 VecIn2 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
24262 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
24263 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems) ||
24264 !TLI.isTypeLegal(VT: InVT1) || !TLI.isTypeLegal(VT: InVT2))
24265 return SDValue();
24266 // If the dest vector has fewer than two elements, then using a shuffle and
24267 // extract from larger regs will cost even more.
24268 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
24269 return SDValue();
24270 assert(InVT2Size <= InVT1Size &&
24271 "Second input is not going to be larger than the first one.");
24272
24273 // VecIn1 is wider than the output, and we have another, possibly
24274 // smaller input. Pad the smaller input with undefs, shuffle at the
24275 // input vector width, and extract the output.
24276 // The shuffle type is different than VT, so check legality again.
24277 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
24278 return SDValue();
24279
24280 if (InVT1 != InVT2) {
24281 VecIn2 = DAG.getInsertSubvector(DL, Vec: DAG.getUNDEF(VT: InVT1), SubVec: VecIn2, Idx: 0);
24282 }
24283 ShuffleNumElems = InVT1Size / VTSize * NumElems;
24284 } else {
24285 // TODO: Support cases where the length mismatch isn't exactly by a
24286 // factor of 2.
24287 // TODO: Move this check upwards, so that if we have bad type
24288 // mismatches, we don't create any DAG nodes.
24289 return SDValue();
24290 }
24291 }
24292
24293 // Initialize mask to undef.
24294 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
24295
24296 // Only need to run up to the number of elements actually used, not the
24297 // total number of elements in the shuffle - if we are shuffling a wider
24298 // vector, the high lanes should be set to undef.
24299 for (unsigned i = 0; i != NumElems; ++i) {
24300 if (VectorMask[i] <= 0)
24301 continue;
24302
24303 unsigned ExtIndex = N->getOperand(Num: i).getConstantOperandVal(i: 1);
24304 if (VectorMask[i] == (int)LeftIdx) {
24305 Mask[i] = ExtIndex;
24306 } else if (VectorMask[i] == (int)LeftIdx + 1) {
24307 Mask[i] = Vec2Offset + ExtIndex;
24308 }
24309 }
24310
  // The type of the input vectors may have changed above.
24312 InVT1 = VecIn1.getValueType();
24313
24314 // If we already have a VecIn2, it should have the same type as VecIn1.
24315 // If we don't, get an undef/zero vector of the appropriate type.
24316 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT: InVT1);
24317 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
24318
24319 SDValue Shuffle = DAG.getVectorShuffle(VT: InVT1, dl: DL, N1: VecIn1, N2: VecIn2, Mask);
24320 if (ShuffleNumElems > NumElems)
24321 Shuffle = DAG.getExtractSubvector(DL, VT, Vec: Shuffle, Idx: 0);
24322
24323 return Shuffle;
24324}
24325
24326static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
24327 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24328
24329 // First, determine where the build vector is not undef.
24330 // TODO: We could extend this to handle zero elements as well as undefs.
24331 int NumBVOps = BV->getNumOperands();
24332 int ZextElt = -1;
24333 for (int i = 0; i != NumBVOps; ++i) {
24334 SDValue Op = BV->getOperand(Num: i);
24335 if (Op.isUndef())
24336 continue;
24337 if (ZextElt == -1)
24338 ZextElt = i;
24339 else
24340 return SDValue();
24341 }
24342 // Bail out if there's no non-undef element.
24343 if (ZextElt == -1)
24344 return SDValue();
24345
24346 // The build vector contains some number of undef elements and exactly
24347 // one other element. That other element must be a zero-extended scalar
24348 // extracted from a vector at a constant index to turn this into a shuffle.
24349 // Also, require that the build vector does not implicitly truncate/extend
24350 // its elements.
24351 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
24352 EVT VT = BV->getValueType(ResNo: 0);
24353 SDValue Zext = BV->getOperand(Num: ZextElt);
24354 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
24355 Zext.getOperand(i: 0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
24356 !isa<ConstantSDNode>(Val: Zext.getOperand(i: 0).getOperand(i: 1)) ||
24357 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
24358 return SDValue();
24359
  // The destination size of the zero-extend must be a multiple of the source
  // size, and we must be building a vector of the same size as the source of
  // the extract element.
24362 SDValue Extract = Zext.getOperand(i: 0);
24363 unsigned DestSize = Zext.getValueSizeInBits();
24364 unsigned SrcSize = Extract.getValueSizeInBits();
24365 if (DestSize % SrcSize != 0 ||
24366 Extract.getOperand(i: 0).getValueSizeInBits() != VT.getSizeInBits())
24367 return SDValue();
24368
24369 // Create a shuffle mask that will combine the extracted element with zeros
24370 // and undefs.
24371 int ZextRatio = DestSize / SrcSize;
24372 int NumMaskElts = NumBVOps * ZextRatio;
24373 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
24374 for (int i = 0; i != NumMaskElts; ++i) {
24375 if (i / ZextRatio == ZextElt) {
24376 // The low bits of the (potentially translated) extracted element map to
24377 // the source vector. The high bits map to zero. We will use a zero vector
24378 // as the 2nd source operand of the shuffle, so use the 1st element of
24379 // that vector (mask value is number-of-elements) for the high bits.
24380 int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
24381 ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(i: 1)
24382 : NumMaskElts;
24383 }
24384
24385 // Undef elements of the build vector remain undef because we initialize
24386 // the shuffle mask with -1.
24387 }
24388
24389 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
24390 // bitcast (shuffle V, ZeroVec, VectorMask)
24391 SDLoc DL(BV);
24392 EVT VecVT = Extract.getOperand(i: 0).getValueType();
24393 SDValue ZeroVec = DAG.getConstant(Val: 0, DL, VT: VecVT);
24394 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24395 SDValue Shuf = TLI.buildLegalVectorShuffle(VT: VecVT, DL, N0: Extract.getOperand(i: 0),
24396 N1: ZeroVec, Mask: ShufMask, DAG);
24397 if (!Shuf)
24398 return SDValue();
24399 return DAG.getBitcast(VT, V: Shuf);
24400}
24401
24402// FIXME: promote to STLExtras.
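// Returns the index of the first occurrence of Val in Range, or -1 if Val is
// not present.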
24403template <typename R, typename T>
24404static auto getFirstIndexOf(R &&Range, const T &Val) {
24405 auto I = find(Range, Val);
24406 if (I == Range.end())
24407 return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
24408 return std::distance(Range.begin(), I);
24409}
24410
24411// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
24412// operations. If the types of the vectors we're extracting from allow it,
24413// turn this into a vector_shuffle node.
24414SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
24415 SDLoc DL(N);
24416 EVT VT = N->getValueType(ResNo: 0);
24417
24418 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
24419 if (!isTypeLegal(VT))
24420 return SDValue();
24421
24422 if (SDValue V = reduceBuildVecToShuffleWithZero(BV: N, DAG))
24423 return V;
24424
24425 // May only combine to shuffle after legalize if shuffle is legal.
24426 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT))
24427 return SDValue();
24428
24429 bool UsesZeroVector = false;
24430 unsigned NumElems = N->getNumOperands();
24431
24432 // Record, for each element of the newly built vector, which input vector
24433 // that element comes from. -1 stands for undef, 0 for the zero vector,
24434 // and positive values for the input vectors.
24435 // VectorMask maps each element to its vector number, and VecIn maps vector
24436 // numbers to their initial SDValues.
24437
24438 SmallVector<int, 8> VectorMask(NumElems, -1);
24439 SmallVector<SDValue, 8> VecIn;
24440 VecIn.push_back(Elt: SDValue());
24441
24442 // If we have a single extract_element with a constant index, track the index
24443 // value.
24444 unsigned OneConstExtractIndex = ~0u;
24445
  // Count the number of extract_vector_elt sources, i.e. operands that are
  // neither undef nor constant zero.
24447 unsigned NumExtracts = 0;
24448
24449 for (unsigned i = 0; i != NumElems; ++i) {
24450 SDValue Op = N->getOperand(Num: i);
24451
24452 if (Op.isUndef())
24453 continue;
24454
24455 // See if we can use a blend with a zero vector.
24456 // TODO: Should we generalize this to a blend with an arbitrary constant
24457 // vector?
24458 if (isNullConstant(V: Op) || isNullFPConstant(V: Op)) {
24459 UsesZeroVector = true;
24460 VectorMask[i] = 0;
24461 continue;
24462 }
24463
24464 // Not an undef or zero. If the input is something other than an
24465 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
24466 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
24467 return SDValue();
24468
24469 SDValue ExtractedFromVec = Op.getOperand(i: 0);
24470 if (ExtractedFromVec.getValueType().isScalableVector())
24471 return SDValue();
24472 auto *ExtractIdx = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1));
24473 if (!ExtractIdx)
24474 return SDValue();
24475
24476 if (ExtractIdx->getAsAPIntVal().uge(
24477 RHS: ExtractedFromVec.getValueType().getVectorNumElements()))
24478 return SDValue();
24479
24480 // All inputs must have the same element type as the output.
24481 if (VT.getVectorElementType() !=
24482 ExtractedFromVec.getValueType().getVectorElementType())
24483 return SDValue();
24484
24485 OneConstExtractIndex = ExtractIdx->getZExtValue();
24486 ++NumExtracts;
24487
24488 // Have we seen this input vector before?
24489 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
24490 // a map back from SDValues to numbers isn't worth it.
24491 int Idx = getFirstIndexOf(Range&: VecIn, Val: ExtractedFromVec);
24492 if (Idx == -1) { // A new source vector?
24493 Idx = VecIn.size();
24494 VecIn.push_back(Elt: ExtractedFromVec);
24495 }
24496
24497 VectorMask[i] = Idx;
24498 }
24499
24500 // If we didn't find at least one input vector, bail out.
24501 if (VecIn.size() < 2)
24502 return SDValue();
24503
24504 // If all the Operands of BUILD_VECTOR extract from same
24505 // vector, then split the vector efficiently based on the maximum
24506 // vector access index and adjust the VectorMask and
24507 // VecIn accordingly.
24508 bool DidSplitVec = false;
24509 if (VecIn.size() == 2) {
24510 // If we only found a single constant indexed extract_vector_elt feeding the
24511 // build_vector, do not produce a more complicated shuffle if the extract is
24512 // cheap with other constant/undef elements. Skip broadcast patterns with
24513 // multiple uses in the build_vector.
24514
24515 // TODO: This should be more aggressive about skipping the shuffle
24516 // formation, particularly if VecIn[1].hasOneUse(), and regardless of the
24517 // index.
24518 if (NumExtracts == 1 &&
24519 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT) &&
24520 TLI.isTypeLegal(VT: VT.getVectorElementType()) &&
24521 TLI.isExtractVecEltCheap(VT, Index: OneConstExtractIndex))
24522 return SDValue();
24523
24524 unsigned MaxIndex = 0;
24525 unsigned NearestPow2 = 0;
24526 SDValue Vec = VecIn.back();
24527 EVT InVT = Vec.getValueType();
24528 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
24529
24530 for (unsigned i = 0; i < NumElems; i++) {
24531 if (VectorMask[i] <= 0)
24532 continue;
24533 unsigned Index = N->getOperand(Num: i).getConstantOperandVal(i: 1);
24534 IndexVec[i] = Index;
24535 MaxIndex = std::max(a: MaxIndex, b: Index);
24536 }
24537
24538 NearestPow2 = PowerOf2Ceil(A: MaxIndex);
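    // Only consider splitting when the highest used index, rounded up to a
    // power of two, is more than twice the number of build_vector elements.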
24539 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
24540 NumElems * 2 < NearestPow2) {
24541 unsigned SplitSize = NearestPow2 / 2;
24542 EVT SplitVT = EVT::getVectorVT(Context&: *DAG.getContext(),
24543 VT: InVT.getVectorElementType(), NumElements: SplitSize);
24544 if (TLI.isTypeLegal(VT: SplitVT) &&
24545 SplitSize + SplitVT.getVectorNumElements() <=
24546 InVT.getVectorNumElements()) {
24547 SDValue VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
24548 N2: DAG.getVectorIdxConstant(Val: SplitSize, DL));
24549 SDValue VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
24550 N2: DAG.getVectorIdxConstant(Val: 0, DL));
24551 VecIn.pop_back();
24552 VecIn.push_back(Elt: VecIn1);
24553 VecIn.push_back(Elt: VecIn2);
24554 DidSplitVec = true;
24555
24556 for (unsigned i = 0; i < NumElems; i++) {
24557 if (VectorMask[i] <= 0)
24558 continue;
24559 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
24560 }
24561 }
24562 }
24563 }
24564
24565 // Sort input vectors by decreasing vector element count,
24566 // while preserving the relative order of equally-sized vectors.
  // Note that we keep the first "implicit" zero vector slot as-is.
24568 SmallVector<SDValue, 8> SortedVecIn(VecIn);
24569 llvm::stable_sort(Range: MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
24570 C: [](const SDValue &a, const SDValue &b) {
24571 return a.getValueType().getVectorNumElements() >
24572 b.getValueType().getVectorNumElements();
24573 });
24574
24575 // We now also need to rebuild the VectorMask, because it referenced element
24576 // order in VecIn, and we just sorted them.
24577 for (int &SourceVectorIndex : VectorMask) {
24578 if (SourceVectorIndex <= 0)
24579 continue;
24580 unsigned Idx = getFirstIndexOf(Range&: SortedVecIn, Val: VecIn[SourceVectorIndex]);
24581 assert(Idx > 0 && Idx < SortedVecIn.size() &&
24582 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
24583 SourceVectorIndex = Idx;
24584 }
24585
24586 VecIn = std::move(SortedVecIn);
24587
  // TODO: Should this fire if some of the input vectors have illegal types
  // (like it does now), or should we let legalization run its course first?
24590
24591 // Shuffle phase:
24592 // Take pairs of vectors, and shuffle them so that the result has elements
24593 // from these vectors in the correct places.
24594 // For example, given:
24595 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
24596 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
24597 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
24598 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
24599 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
24600 // We will generate:
24601 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
24602 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
24603 SmallVector<SDValue, 4> Shuffles;
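  // VecIn[0] is the implicit zero-vector placeholder, so the real input
  // vectors start at index 1 and are consumed in pairs.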
24604 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
24605 unsigned LeftIdx = 2 * In + 1;
24606 SDValue VecLeft = VecIn[LeftIdx];
24607 SDValue VecRight =
24608 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
24609
24610 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecIn1: VecLeft,
24611 VecIn2: VecRight, LeftIdx, DidSplitVec))
24612 Shuffles.push_back(Elt: Shuffle);
24613 else
24614 return SDValue();
24615 }
24616
24617 // If we need the zero vector as an "ingredient" in the blend tree, add it
24618 // to the list of shuffles.
24619 if (UsesZeroVector)
24620 Shuffles.push_back(Elt: VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT)
24621 : DAG.getConstantFP(Val: 0.0, DL, VT));
24622
24623 // If we only have one shuffle, we're done.
24624 if (Shuffles.size() == 1)
24625 return Shuffles[0];
24626
24627 // Update the vector mask to point to the post-shuffle vectors.
24628 for (int &Vec : VectorMask)
24629 if (Vec == 0)
24630 Vec = Shuffles.size() - 1;
24631 else
24632 Vec = (Vec - 1) / 2;
24633
24634 // More than one shuffle. Generate a binary tree of blends, e.g. if from
24635 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
24636 // generate:
24637 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
24638 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
24639 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
24640 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
24641 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
24642 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
24643 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
24644
24645 // Make sure the initial size of the shuffle list is even.
24646 if (Shuffles.size() % 2)
24647 Shuffles.push_back(Elt: DAG.getUNDEF(VT));
24648
24649 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
24650 if (CurSize % 2) {
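      // Shuffles never shrinks, so the slot at index CurSize still holds an
      // entry from an earlier round and can safely be overwritten with undef.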
24651 Shuffles[CurSize] = DAG.getUNDEF(VT);
24652 CurSize++;
24653 }
24654 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
24655 int Left = 2 * In;
24656 int Right = 2 * In + 1;
24657 SmallVector<int, 8> Mask(NumElems, -1);
24658 SDValue L = Shuffles[Left];
24659 ArrayRef<int> LMask;
24660 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
24661 L.use_empty() && L.getOperand(i: 1).isUndef() &&
24662 L.getOperand(i: 0).getValueType() == L.getValueType();
24663 if (IsLeftShuffle) {
24664 LMask = cast<ShuffleVectorSDNode>(Val: L.getNode())->getMask();
24665 L = L.getOperand(i: 0);
24666 }
24667 SDValue R = Shuffles[Right];
24668 ArrayRef<int> RMask;
24669 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
24670 R.use_empty() && R.getOperand(i: 1).isUndef() &&
24671 R.getOperand(i: 0).getValueType() == R.getValueType();
24672 if (IsRightShuffle) {
24673 RMask = cast<ShuffleVectorSDNode>(Val: R.getNode())->getMask();
24674 R = R.getOperand(i: 0);
24675 }
24676 for (unsigned I = 0; I != NumElems; ++I) {
24677 if (VectorMask[I] == Left) {
24678 Mask[I] = I;
24679 if (IsLeftShuffle)
24680 Mask[I] = LMask[I];
24681 VectorMask[I] = In;
24682 } else if (VectorMask[I] == Right) {
24683 Mask[I] = I + NumElems;
24684 if (IsRightShuffle)
24685 Mask[I] = RMask[I] + NumElems;
24686 VectorMask[I] = In;
24687 }
24688 }
24689
24690 Shuffles[In] = DAG.getVectorShuffle(VT, dl: DL, N1: L, N2: R, Mask);
24691 }
24692 }
24693 return Shuffles[0];
24694}
24695
// Try to turn a build vector of zero extends of extract vector elts into a
// vector zero extend and possibly an extract subvector.
24698// TODO: Support sign extend?
24699// TODO: Allow undef elements?
24700SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
24701 if (LegalOperations)
24702 return SDValue();
24703
24704 EVT VT = N->getValueType(ResNo: 0);
24705
24706 bool FoundZeroExtend = false;
24707 SDValue Op0 = N->getOperand(Num: 0);
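  // checkElem returns the constant extract index if Op is a zext/anyext of an
  // extract_vector_elt from the same source vector as the first operand, and
  // -1 otherwise. It also notes whether any element used ZERO_EXTEND.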
24708 auto checkElem = [&](SDValue Op) -> int64_t {
24709 unsigned Opc = Op.getOpcode();
24710 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
24711 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
24712 Op.getOperand(i: 0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24713 Op0.getOperand(i: 0).getOperand(i: 0) == Op.getOperand(i: 0).getOperand(i: 0))
24714 if (auto *C = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 0).getOperand(i: 1)))
24715 return C->getZExtValue();
24716 return -1;
24717 };
24718
24719 // Make sure the first element matches
24720 // (zext (extract_vector_elt X, C))
24721 // Offset must be a constant multiple of the
24722 // known-minimum vector length of the result type.
24723 int64_t Offset = checkElem(Op0);
24724 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
24725 return SDValue();
24726
24727 unsigned NumElems = N->getNumOperands();
24728 SDValue In = Op0.getOperand(i: 0).getOperand(i: 0);
24729 EVT InSVT = In.getValueType().getScalarType();
24730 EVT InVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: InSVT, NumElements: NumElems);
24731
24732 // Don't create an illegal input type after type legalization.
24733 if (LegalTypes && !TLI.isTypeLegal(VT: InVT))
24734 return SDValue();
24735
24736 // Ensure all the elements come from the same vector and are adjacent.
24737 for (unsigned i = 1; i != NumElems; ++i) {
24738 if ((Offset + i) != checkElem(N->getOperand(Num: i)))
24739 return SDValue();
24740 }
24741
24742 SDLoc DL(N);
24743 In = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: InVT, N1: In,
24744 N2: Op0.getOperand(i: 0).getOperand(i: 1));
24745 return DAG.getNode(Opcode: FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
24746 VT, Operand: In);
24747}
24748
// If this is a very simple BUILD_VECTOR whose first element is a ZERO_EXTEND
// and all other elements are constant zeros, granularize the BUILD_VECTOR's
// element width, absorbing the ZERO_EXTEND and turning it into a constant
// zero op. This pattern can appear during legalization.
24753//
// NOTE: This can be generalized to allow more than a single
// non-constant-zero op, to allow UNDEFs, and to be KnownBits-based.
24756SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
24757 // Don't run this after legalization. Targets may have other preferences.
24758 if (Level >= AfterLegalizeDAG)
24759 return SDValue();
24760
24761 // FIXME: support big-endian.
24762 if (DAG.getDataLayout().isBigEndian())
24763 return SDValue();
24764
24765 EVT VT = N->getValueType(ResNo: 0);
24766 EVT OpVT = N->getOperand(Num: 0).getValueType();
24767 assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
24768
24769 EVT OpIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
24770
24771 if (!TLI.isTypeLegal(VT: OpIntVT) ||
24772 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: OpIntVT)))
24773 return SDValue();
24774
24775 unsigned EltBitwidth = VT.getScalarSizeInBits();
24776 // NOTE: the actual width of operands may be wider than that!
24777
24778 // Analyze all operands of this BUILD_VECTOR. What is the largest number of
24779 // active bits they all have? We'll want to truncate them all to that width.
24780 unsigned ActiveBits = 0;
24781 APInt KnownZeroOps(VT.getVectorNumElements(), 0);
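  // KnownZeroOps records which operands are constant zero; each of those is
  // later expanded into Factor zero chunks.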
24782 for (auto I : enumerate(First: N->ops())) {
24783 SDValue Op = I.value();
24784 // FIXME: support UNDEF elements?
24785 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Op)) {
24786 unsigned OpActiveBits =
24787 Cst->getAPIntValue().trunc(width: EltBitwidth).getActiveBits();
24788 if (OpActiveBits == 0) {
24789 KnownZeroOps.setBit(I.index());
24790 continue;
24791 }
24792 // Profitability check: don't allow non-zero constant operands.
24793 return SDValue();
24794 }
24795 // Profitability check: there must only be a single non-zero operand,
24796 // and it must be the first operand of the BUILD_VECTOR.
24797 if (I.index() != 0)
24798 return SDValue();
24799 // The operand must be a zero-extension itself.
24800 // FIXME: this could be generalized to known leading zeros check.
24801 if (Op.getOpcode() != ISD::ZERO_EXTEND)
24802 return SDValue();
24803 unsigned CurrActiveBits =
24804 Op.getOperand(i: 0).getValueSizeInBits().getFixedValue();
24805 assert(!ActiveBits && "Already encountered non-constant-zero operand?");
24806 ActiveBits = CurrActiveBits;
24807 // We want to at least halve the element size.
24808 if (2 * ActiveBits > EltBitwidth)
24809 return SDValue();
24810 }
24811
24812 // This BUILD_VECTOR must have at least one non-constant-zero operand.
24813 if (ActiveBits == 0)
24814 return SDValue();
24815
  // We have EltBitwidth bits and the *minimal* chunk size is ActiveBits;
  // into how many chunks can we split our element width?
24818 EVT NewScalarIntVT, NewIntVT;
24819 std::optional<unsigned> Factor;
  // We can split the element into at least two chunks, but not into more
  // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
  // such that the element width is a multiple of it and the resulting
  // types/operations on that chunk width are legal.
24824 assert(2 * ActiveBits <= EltBitwidth &&
24825 "We know that half or less bits of the element are active.");
24826 for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
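    // Prefer the largest factor: stop at the first (largest) scale for which
    // the chunk type and the required TRUNCATE/BUILD_VECTOR are legal.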
24827 if (EltBitwidth % Scale != 0)
24828 continue;
24829 unsigned ChunkBitwidth = EltBitwidth / Scale;
24830 assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
24831 NewScalarIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ChunkBitwidth);
24832 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarIntVT,
24833 NumElements: Scale * N->getNumOperands());
24834 if (!TLI.isTypeLegal(VT: NewScalarIntVT) || !TLI.isTypeLegal(VT: NewIntVT) ||
24835 (LegalOperations &&
24836 !(TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT: NewScalarIntVT) &&
24837 TLI.isOperationLegalOrCustom(Op: ISD::BUILD_VECTOR, VT: NewIntVT))))
24838 continue;
24839 Factor = Scale;
24840 break;
24841 }
24842 if (!Factor)
24843 return SDValue();
24844
24845 SDLoc DL(N);
24846 SDValue ZeroOp = DAG.getConstant(Val: 0, DL, VT: NewScalarIntVT);
24847
24848 // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
24849 SmallVector<SDValue, 16> NewOps;
24850 NewOps.reserve(N: NewIntVT.getVectorNumElements());
24851 for (auto I : enumerate(First: N->ops())) {
24852 SDValue Op = I.value();
24853 assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
24854 unsigned SrcOpIdx = I.index();
24855 if (KnownZeroOps[SrcOpIdx]) {
24856 NewOps.append(NumInputs: *Factor, Elt: ZeroOp);
24857 continue;
24858 }
24859 Op = DAG.getBitcast(VT: OpIntVT, V: Op);
24860 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: NewScalarIntVT, Operand: Op);
24861 NewOps.emplace_back(Args&: Op);
24862 NewOps.append(NumInputs: *Factor - 1, Elt: ZeroOp);
24863 }
24864 assert(NewOps.size() == NewIntVT.getVectorNumElements());
24865 SDValue NewBV = DAG.getBuildVector(VT: NewIntVT, DL, Ops: NewOps);
24866 NewBV = DAG.getBitcast(VT, V: NewBV);
24867 return NewBV;
24868}
24869
24870SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
24871 EVT VT = N->getValueType(ResNo: 0);
24872
24873 // A vector built entirely of undefs is undef.
24874 if (ISD::allOperandsUndef(N))
24875 return DAG.getUNDEF(VT);
24876
24877 // If this is a splat of a bitcast from another vector, change to a
24878 // concat_vector.
24879 // For example:
24880 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
24881 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
24882 //
24883 // If X is a build_vector itself, the concat can become a larger build_vector.
24884 // TODO: Maybe this is useful for non-splat too?
24885 if (!LegalOperations) {
24886 SDValue Splat = cast<BuildVectorSDNode>(Val: N)->getSplatValue();
    // Only change build_vector to a concat_vector if the splat value type is
    // the same as the vector element type.
24889 if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
24890 Splat = peekThroughBitcasts(V: Splat);
24891 EVT SrcVT = Splat.getValueType();
24892 if (SrcVT.isVector()) {
24893 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
24894 EVT NewVT = EVT::getVectorVT(Context&: *DAG.getContext(),
24895 VT: SrcVT.getVectorElementType(), NumElements: NumElts);
24896 if (!LegalTypes || TLI.isTypeLegal(VT: NewVT)) {
24897 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
24898 SDValue Concat =
24899 DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT: NewVT, Ops);
24900 return DAG.getBitcast(VT, V: Concat);
24901 }
24902 }
24903 }
24904 }
24905
  // Check if we can express the BUILD_VECTOR via a subvector extract.
24907 if (!LegalTypes && (N->getNumOperands() > 1)) {
24908 SDValue Op0 = N->getOperand(Num: 0);
24909 auto checkElem = [&](SDValue Op) -> uint64_t {
24910 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
24911 (Op0.getOperand(i: 0) == Op.getOperand(i: 0)))
24912 if (auto CNode = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1)))
24913 return CNode->getZExtValue();
24914 return -1;
24915 };
24916
24917 int Offset = checkElem(Op0);
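    // Every operand must extract index Offset + i from the same source vector;
    // otherwise Offset is reset to -1 and the fold below is skipped.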
24918 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
24919 if (Offset + i != checkElem(N->getOperand(Num: i))) {
24920 Offset = -1;
24921 break;
24922 }
24923 }
24924
24925 if ((Offset == 0) &&
24926 (Op0.getOperand(i: 0).getValueType() == N->getValueType(ResNo: 0)))
24927 return Op0.getOperand(i: 0);
24928 if ((Offset != -1) &&
24929 ((Offset % N->getValueType(ResNo: 0).getVectorNumElements()) ==
24930 0)) // IDX must be multiple of output size.
24931 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: N->getValueType(ResNo: 0),
24932 N1: Op0.getOperand(i: 0), N2: Op0.getOperand(i: 1));
24933 }
24934
24935 if (SDValue V = convertBuildVecZextToZext(N))
24936 return V;
24937
24938 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
24939 return V;
24940
24941 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
24942 return V;
24943
24944 if (SDValue V = reduceBuildVecTruncToBitCast(N))
24945 return V;
24946
24947 if (SDValue V = reduceBuildVecToShuffle(N))
24948 return V;
24949
24950 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
24951 // Do this late as some of the above may replace the splat.
24952 if (TLI.getOperationAction(Op: ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
24953 if (SDValue V = cast<BuildVectorSDNode>(Val: N)->getSplatValue()) {
24954 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
24955 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: V);
24956 }
24957
24958 return SDValue();
24959}
24960
24961static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
24962 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24963 EVT OpVT = N->getOperand(Num: 0).getValueType();
24964
24965 // If the operands are legal vectors, leave them alone.
24966 if (TLI.isTypeLegal(VT: OpVT) || OpVT.isScalableVector())
24967 return SDValue();
24968
24969 SDLoc DL(N);
24970 EVT VT = N->getValueType(ResNo: 0);
24971 SmallVector<SDValue, 8> Ops;
24972 EVT SVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
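  // Default to an integer scalar type of the operand width; if any operand
  // turns out to be a floating-point scalar, SVT is switched to that type
  // below and everything is bitcast to it.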
24973
24974 // Keep track of what we encounter.
24975 EVT AnyFPVT;
24976
24977 for (const SDValue &Op : N->ops()) {
24978 if (ISD::BITCAST == Op.getOpcode() &&
24979 !Op.getOperand(i: 0).getValueType().isVector())
24980 Ops.push_back(Elt: Op.getOperand(i: 0));
24981 else if (Op.isUndef())
24982 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT));
24983 else
24984 return SDValue();
24985
    // Note whether we encounter an integer or floating point scalar.
    // If it's neither, bail out; it could be something weird like x86mmx.
24988 EVT LastOpVT = Ops.back().getValueType();
24989 if (LastOpVT.isFloatingPoint())
24990 AnyFPVT = LastOpVT;
24991 else if (!LastOpVT.isInteger())
24992 return SDValue();
24993 }
24994
24995 // If any of the operands is a floating point scalar bitcast to a vector,
24996 // use floating point types throughout, and bitcast everything.
24997 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
24998 if (AnyFPVT != EVT()) {
24999 SVT = AnyFPVT;
25000 for (SDValue &Op : Ops) {
25001 if (Op.getValueType() == SVT)
25002 continue;
25003 if (Op.isUndef())
25004 Op = DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT);
25005 else
25006 Op = DAG.getBitcast(VT: SVT, V: Op);
25007 }
25008 }
25009
25010 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SVT,
25011 NumElements: VT.getSizeInBits() / SVT.getSizeInBits());
25012 return DAG.getBitcast(VT, V: DAG.getBuildVector(VT: VecVT, DL, Ops));
25013}
25014
25015// Attempt to merge nested concat_vectors/undefs.
25016// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
25017// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
25018static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
25019 SelectionDAG &DAG) {
25020 EVT VT = N->getValueType(ResNo: 0);
25021
25022 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
25023 EVT SubVT;
25024 SDValue FirstConcat;
25025 for (const SDValue &Op : N->ops()) {
25026 if (Op.isUndef())
25027 continue;
25028 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
25029 return SDValue();
25030 if (!FirstConcat) {
25031 SubVT = Op.getOperand(i: 0).getValueType();
25032 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT: SubVT))
25033 return SDValue();
25034 FirstConcat = Op;
25035 continue;
25036 }
25037 if (SubVT != Op.getOperand(i: 0).getValueType())
25038 return SDValue();
25039 }
25040 assert(FirstConcat && "Concat of all-undefs found");
25041
25042 SmallVector<SDValue> ConcatOps;
25043 for (const SDValue &Op : N->ops()) {
25044 if (Op.isUndef()) {
25045 ConcatOps.append(NumInputs: FirstConcat->getNumOperands(), Elt: DAG.getUNDEF(VT: SubVT));
25046 continue;
25047 }
25048 ConcatOps.append(in_start: Op->op_begin(), in_end: Op->op_end());
25049 }
25050 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops: ConcatOps);
25051}
25052
25053// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
25054// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
25055// most two distinct vectors the same size as the result, attempt to turn this
25056// into a legal shuffle.
25057static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
25058 EVT VT = N->getValueType(ResNo: 0);
25059 EVT OpVT = N->getOperand(Num: 0).getValueType();
25060
25061 // We currently can't generate an appropriate shuffle for a scalable vector.
25062 if (VT.isScalableVector())
25063 return SDValue();
25064
25065 int NumElts = VT.getVectorNumElements();
25066 int NumOpElts = OpVT.getVectorNumElements();
25067
25068 SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
25069 SmallVector<int, 8> Mask;
25070
25071 for (SDValue Op : N->ops()) {
25072 Op = peekThroughBitcasts(V: Op);
25073
25074 // UNDEF nodes convert to UNDEF shuffle mask values.
25075 if (Op.isUndef()) {
25076 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
25077 continue;
25078 }
25079
25080 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
25081 return SDValue();
25082
25083 // What vector are we extracting the subvector from and at what index?
25084 SDValue ExtVec = Op.getOperand(i: 0);
25085 int ExtIdx = Op.getConstantOperandVal(i: 1);
25086
25087 // We want the EVT of the original extraction to correctly scale the
25088 // extraction index.
25089 EVT ExtVT = ExtVec.getValueType();
25090 ExtVec = peekThroughBitcasts(V: ExtVec);
25091
25092 // UNDEF nodes convert to UNDEF shuffle mask values.
25093 if (ExtVec.isUndef()) {
25094 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
25095 continue;
25096 }
25097
25098 // Ensure that we are extracting a subvector from a vector the same
25099 // size as the result.
25100 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
25101 return SDValue();
25102
25103 // Scale the subvector index to account for any bitcast.
25104 int NumExtElts = ExtVT.getVectorNumElements();
25105 if (0 == (NumExtElts % NumElts))
25106 ExtIdx /= (NumExtElts / NumElts);
25107 else if (0 == (NumElts % NumExtElts))
25108 ExtIdx *= (NumElts / NumExtElts);
25109 else
25110 return SDValue();
25111
25112 // At most we can reference 2 inputs in the final shuffle.
25113 if (SV0.isUndef() || SV0 == ExtVec) {
25114 SV0 = ExtVec;
25115 for (int i = 0; i != NumOpElts; ++i)
25116 Mask.push_back(Elt: i + ExtIdx);
25117 } else if (SV1.isUndef() || SV1 == ExtVec) {
25118 SV1 = ExtVec;
25119 for (int i = 0; i != NumOpElts; ++i)
25120 Mask.push_back(Elt: i + ExtIdx + NumElts);
25121 } else {
25122 return SDValue();
25123 }
25124 }
25125
25126 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25127 return TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: DAG.getBitcast(VT, V: SV0),
25128 N1: DAG.getBitcast(VT, V: SV1), Mask, DAG);
25129}
25130
25131static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
25132 unsigned CastOpcode = N->getOperand(Num: 0).getOpcode();
25133 switch (CastOpcode) {
25134 case ISD::SINT_TO_FP:
25135 case ISD::UINT_TO_FP:
25136 case ISD::FP_TO_SINT:
25137 case ISD::FP_TO_UINT:
25138 // TODO: Allow more opcodes?
25139 // case ISD::BITCAST:
25140 // case ISD::TRUNCATE:
25141 // case ISD::ZERO_EXTEND:
25142 // case ISD::SIGN_EXTEND:
25143 // case ISD::FP_EXTEND:
25144 break;
25145 default:
25146 return SDValue();
25147 }
25148
25149 EVT SrcVT = N->getOperand(Num: 0).getOperand(i: 0).getValueType();
25150 if (!SrcVT.isVector())
25151 return SDValue();
25152
25153 // All operands of the concat must be the same kind of cast from the same
25154 // source type.
25155 SmallVector<SDValue, 4> SrcOps;
25156 for (SDValue Op : N->ops()) {
25157 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
25158 Op.getOperand(i: 0).getValueType() != SrcVT)
25159 return SDValue();
25160 SrcOps.push_back(Elt: Op.getOperand(i: 0));
25161 }
25162
25163 // The wider cast must be supported by the target. This is unusual because
25164 // the operation support type parameter depends on the opcode. In addition,
25165 // check the other type in the cast to make sure this is really legal.
25166 EVT VT = N->getValueType(ResNo: 0);
25167 EVT SrcEltVT = SrcVT.getVectorElementType();
25168 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
25169 EVT ConcatSrcVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcEltVT, EC: NumElts);
25170 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
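  // For int-to-fp casts, the operation legality is queried on the integer
  // source type and the FP result type must be legal; for fp-to-int casts the
  // roles are reversed.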
25171 switch (CastOpcode) {
25172 case ISD::SINT_TO_FP:
25173 case ISD::UINT_TO_FP:
25174 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT: ConcatSrcVT) ||
25175 !TLI.isTypeLegal(VT))
25176 return SDValue();
25177 break;
25178 case ISD::FP_TO_SINT:
25179 case ISD::FP_TO_UINT:
25180 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT) ||
25181 !TLI.isTypeLegal(VT: ConcatSrcVT))
25182 return SDValue();
25183 break;
25184 default:
25185 llvm_unreachable("Unexpected cast opcode");
25186 }
25187
25188 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
25189 SDLoc DL(N);
25190 SDValue NewConcat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ConcatSrcVT, Ops: SrcOps);
25191 return DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: NewConcat);
25192}
25193
25194// See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
25195// the operands is a SHUFFLE_VECTOR, and all other operands are also operands
25196// to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR.
25197static SDValue combineConcatVectorOfShuffleAndItsOperands(
25198 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
25199 bool LegalOperations) {
25200 EVT VT = N->getValueType(ResNo: 0);
25201 EVT OpVT = N->getOperand(Num: 0).getValueType();
25202 if (VT.isScalableVector())
25203 return SDValue();
25204
25205 // For now, only allow simple 2-operand concatenations.
25206 if (N->getNumOperands() != 2)
25207 return SDValue();
25208
25209 // Don't create illegal types/shuffles when not allowed to.
25210 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
25211 (LegalOperations &&
25212 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT)))
25213 return SDValue();
25214
25215 // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
25216 // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
25217 // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
25218 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
25219 // (4) and for now, the SHUFFLE_VECTOR must be unary.
25220 ShuffleVectorSDNode *SVN = nullptr;
25221 for (SDValue Op : N->ops()) {
25222 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Val&: Op);
25223 CurSVN && CurSVN->getOperand(Num: 1).isUndef() && N->isOnlyUserOf(N: CurSVN) &&
25224 all_of(Range: N->ops(), P: [CurSVN](SDValue Op) {
25225 // FIXME: can we allow UNDEF operands?
25226 return !Op.isUndef() &&
25227 (Op.getNode() == CurSVN || is_contained(Range: CurSVN->ops(), Element: Op));
25228 })) {
25229 SVN = CurSVN;
25230 break;
25231 }
25232 }
25233 if (!SVN)
25234 return SDValue();
25235
  // We are going to pad the shuffle operands, so any index that was picking
  // from the second operand must be adjusted.
25238 SmallVector<int, 16> AdjustedMask(SVN->getMask());
25239 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
25240
25241 // Identity masks for the operands of the (padded) shuffle.
25242 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
25243 MutableArrayRef<int> FirstShufOpIdentityMask =
25244 MutableArrayRef<int>(IdentityMask)
25245 .take_front(N: OpVT.getVectorNumElements());
25246 MutableArrayRef<int> SecondShufOpIdentityMask =
25247 MutableArrayRef<int>(IdentityMask).take_back(N: OpVT.getVectorNumElements());
25248 std::iota(first: FirstShufOpIdentityMask.begin(), last: FirstShufOpIdentityMask.end(), value: 0);
25249 std::iota(first: SecondShufOpIdentityMask.begin(), last: SecondShufOpIdentityMask.end(),
25250 value: VT.getVectorNumElements());
25251
25252 // New combined shuffle mask.
25253 SmallVector<int, 32> Mask;
25254 Mask.reserve(N: VT.getVectorNumElements());
25255 for (SDValue Op : N->ops()) {
25256 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
25257 if (Op.getNode() == SVN) {
25258 append_range(C&: Mask, R&: AdjustedMask);
25259 continue;
25260 }
25261 if (Op == SVN->getOperand(Num: 0)) {
25262 append_range(C&: Mask, R&: FirstShufOpIdentityMask);
25263 continue;
25264 }
25265 if (Op == SVN->getOperand(Num: 1)) {
25266 append_range(C&: Mask, R&: SecondShufOpIdentityMask);
25267 continue;
25268 }
25269 llvm_unreachable("Unexpected operand!");
25270 }
25271
25272 // Don't create illegal shuffle masks.
25273 if (!TLI.isShuffleMaskLegal(Mask, VT))
25274 return SDValue();
25275
25276 // Pad the shuffle operands with UNDEF.
25277 SDLoc dl(N);
25278 std::array<SDValue, 2> ShufOps;
25279 for (auto I : zip(t: SVN->ops(), u&: ShufOps)) {
25280 SDValue ShufOp = std::get<0>(t&: I);
25281 SDValue &NewShufOp = std::get<1>(t&: I);
25282 if (ShufOp.isUndef())
25283 NewShufOp = DAG.getUNDEF(VT);
25284 else {
25285 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
25286 DAG.getUNDEF(VT: OpVT));
25287 ShufOpParts[0] = ShufOp;
25288 NewShufOp = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: dl, VT, Ops: ShufOpParts);
25289 }
25290 }
25291 // Finally, create the new wide shuffle.
25292 return DAG.getVectorShuffle(VT, dl, N1: ShufOps[0], N2: ShufOps[1], Mask);
25293}
25294
25295SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
25296 // If we only have one input vector, we don't need to do any concatenation.
25297 if (N->getNumOperands() == 1)
25298 return N->getOperand(Num: 0);
25299
25300 // Check if all of the operands are undefs.
25301 EVT VT = N->getValueType(ResNo: 0);
25302 if (ISD::allOperandsUndef(N))
25303 return DAG.getUNDEF(VT);
25304
25305 // Optimize concat_vectors where all but the first of the vectors are undef.
25306 if (all_of(Range: drop_begin(RangeOrContainer: N->ops()),
25307 P: [](const SDValue &Op) { return Op.isUndef(); })) {
25308 SDValue In = N->getOperand(Num: 0);
25309 assert(In.getValueType().isVector() && "Must concat vectors");
25310
25311 // If the input is a concat_vectors, just make a larger concat by padding
25312 // with smaller undefs.
25313 //
25314 // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
25315 // here could cause an infinite loop. That legalizing happens when LegalDAG
    // is true and the input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
25317 // scalable.
25318 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
25319 !(LegalDAG && In.getValueType().isScalableVector())) {
25320 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
25321 SmallVector<SDValue, 4> Ops(In->ops());
25322 Ops.resize(N: NumOps, NV: DAG.getUNDEF(VT: Ops[0].getValueType()));
25323 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
25324 }
25325
25326 SDValue Scalar = peekThroughOneUseBitcasts(V: In);
25327
25328 // concat_vectors(scalar_to_vector(scalar), undef) ->
25329 // scalar_to_vector(scalar)
25330 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
25331 Scalar.hasOneUse()) {
25332 EVT SVT = Scalar.getValueType().getVectorElementType();
25333 if (SVT == Scalar.getOperand(i: 0).getValueType())
25334 Scalar = Scalar.getOperand(i: 0);
25335 }
25336
25337 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
25338 if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
25339 // If the bitcast type isn't legal, it might be a trunc of a legal type;
25340 // look through the trunc so we can still do the transform:
25341 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
25342 if (Scalar->getOpcode() == ISD::TRUNCATE &&
25343 !TLI.isTypeLegal(VT: Scalar.getValueType()) &&
25344 TLI.isTypeLegal(VT: Scalar->getOperand(Num: 0).getValueType()))
25345 Scalar = Scalar->getOperand(Num: 0);
25346
25347 EVT SclTy = Scalar.getValueType();
25348
25349 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
25350 return SDValue();
25351
25352 // Bail out if the vector size is not a multiple of the scalar size.
25353 if (VT.getSizeInBits() % SclTy.getSizeInBits())
25354 return SDValue();
25355
25356 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
25357 if (VNTNumElms < 2)
25358 return SDValue();
25359
25360 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SclTy, NumElements: VNTNumElms);
25361 if (!TLI.isTypeLegal(VT: NVT) || !TLI.isTypeLegal(VT: Scalar.getValueType()))
25362 return SDValue();
25363
25364 SDValue Res = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT: NVT, Operand: Scalar);
25365 return DAG.getBitcast(VT, V: Res);
25366 }
25367 }
25368
25369 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
25370 // We have already tested above for an UNDEF only concatenation.
25371 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
25372 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
25373 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
25374 return Op.isUndef() || ISD::BUILD_VECTOR == Op.getOpcode();
25375 };
25376 if (llvm::all_of(Range: N->ops(), P: IsBuildVectorOrUndef)) {
25377 SmallVector<SDValue, 8> Opnds;
25378 EVT SVT = VT.getScalarType();
25379
25380 EVT MinVT = SVT;
25381 if (!SVT.isFloatingPoint()) {
      // If the BUILD_VECTORs are built from integers, they may have different
      // operand types. Get the smallest type and truncate all operands to it.
25384 bool FoundMinVT = false;
25385 for (const SDValue &Op : N->ops())
25386 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
25387 EVT OpSVT = Op.getOperand(i: 0).getValueType();
25388 MinVT = (!FoundMinVT || OpSVT.bitsLE(VT: MinVT)) ? OpSVT : MinVT;
25389 FoundMinVT = true;
25390 }
25391 assert(FoundMinVT && "Concat vector type mismatch");
25392 }
25393
25394 for (const SDValue &Op : N->ops()) {
25395 EVT OpVT = Op.getValueType();
25396 unsigned NumElts = OpVT.getVectorNumElements();
25397
25398 if (Op.isUndef())
25399 Opnds.append(NumInputs: NumElts, Elt: DAG.getUNDEF(VT: MinVT));
25400
25401 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
25402 if (SVT.isFloatingPoint()) {
25403 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
25404 Opnds.append(in_start: Op->op_begin(), in_end: Op->op_begin() + NumElts);
25405 } else {
25406 for (unsigned i = 0; i != NumElts; ++i)
25407 Opnds.push_back(
25408 Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT: MinVT, Operand: Op.getOperand(i)));
25409 }
25410 }
25411 }
25412
25413 assert(VT.getVectorNumElements() == Opnds.size() &&
25414 "Concat vector type mismatch");
25415 return DAG.getBuildVector(VT, DL: SDLoc(N), Ops: Opnds);
25416 }
25417
25418 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
25419 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
25420 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
25421 return V;
25422
25423 if (Level <= AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
25424 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
25425 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
25426 return V;
25427
25428 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
25429 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
25430 return V;
25431 }
25432
25433 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
25434 return V;
25435
25436 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
25437 N, DAG, TLI, LegalTypes, LegalOperations))
25438 return V;
25439
25440 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
25441 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
  // operands and look for CONCAT operations that place the incoming vectors
25443 // at the exact same location.
25444 //
25445 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
25446 SDValue SingleSource = SDValue();
25447 unsigned PartNumElem =
25448 N->getOperand(Num: 0).getValueType().getVectorMinNumElements();
25449
25450 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
25451 SDValue Op = N->getOperand(Num: i);
25452
25453 if (Op.isUndef())
25454 continue;
25455
25456 // Check if this is the identity extract:
25457 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
25458 return SDValue();
25459
25460 // Find the single incoming vector for the extract_subvector.
25461 if (SingleSource.getNode()) {
25462 if (Op.getOperand(i: 0) != SingleSource)
25463 return SDValue();
25464 } else {
25465 SingleSource = Op.getOperand(i: 0);
25466
25467 // Check the source type is the same as the type of the result.
25468 // If not, this concat may extend the vector, so we can not
25469 // optimize it away.
25470 if (SingleSource.getValueType() != N->getValueType(ResNo: 0))
25471 return SDValue();
25472 }
25473
25474 // Check that we are reading from the identity index.
25475 unsigned IdentityIndex = i * PartNumElem;
25476 if (Op.getConstantOperandAPInt(i: 1) != IdentityIndex)
25477 return SDValue();
25478 }
25479
25480 if (SingleSource.getNode())
25481 return SingleSource;
25482
25483 return SDValue();
25484}
25485
25486// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
25487// if the subvector can be sourced for free.
25488static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
25489 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
25490 V.getOperand(i: 1).getValueType() == SubVT &&
25491 V.getConstantOperandAPInt(i: 2) == Index) {
25492 return V.getOperand(i: 1);
25493 }
25494 if (V.getOpcode() == ISD::CONCAT_VECTORS &&
25495 V.getOperand(i: 0).getValueType() == SubVT &&
25496 (Index % SubVT.getVectorMinNumElements()) == 0) {
25497 uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
25498 return V.getOperand(i: SubIdx);
25499 }
25500 return SDValue();
25501}
25502
25503static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
25504 unsigned Index, const SDLoc &DL,
25505 SelectionDAG &DAG,
25506 bool LegalOperations) {
25507 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25508 unsigned BinOpcode = BinOp.getOpcode();
25509 if (!TLI.isBinOp(Opcode: BinOpcode) || BinOp->getNumValues() != 1)
25510 return SDValue();
25511
25512 EVT VecVT = BinOp.getValueType();
25513 SDValue Bop0 = BinOp.getOperand(i: 0), Bop1 = BinOp.getOperand(i: 1);
25514 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
25515 return SDValue();
25516 if (!TLI.isOperationLegalOrCustom(Op: BinOpcode, VT: SubVT, LegalOnly: LegalOperations))
25517 return SDValue();
25518
25519 SDValue Sub0 = getSubVectorSrc(V: Bop0, Index, SubVT);
25520 SDValue Sub1 = getSubVectorSrc(V: Bop1, Index, SubVT);
25521
25522 // TODO: We could handle the case where only 1 operand is being inserted by
25523 // creating an extract of the other operand, but that requires checking
25524 // number of uses and/or costs.
25525 if (!Sub0 || !Sub1)
25526 return SDValue();
25527
25528 // We are inserting both operands of the wide binop only to extract back
25529 // to the narrow vector size. Eliminate all of the insert/extract:
25530 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
25531 return DAG.getNode(Opcode: BinOpcode, DL, VT: SubVT, N1: Sub0, N2: Sub1, Flags: BinOp->getFlags());
25532}
25533
25534/// If we are extracting a subvector produced by a wide binary operator try
25535/// to use a narrow binary operator and/or avoid concatenation and extraction.
25536static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
25537 const SDLoc &DL, SelectionDAG &DAG,
25538 bool LegalOperations) {
25539 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
25540 // some of these bailouts with other transforms.
25541
25542 if (SDValue V = narrowInsertExtractVectorBinOp(SubVT: VT, BinOp: Src, Index, DL, DAG,
25543 LegalOperations))
25544 return V;
25545
25546 // We are looking for an optionally bitcasted wide vector binary operator
25547 // feeding an extract subvector.
25548 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25549 SDValue BinOp = peekThroughBitcasts(V: Src);
25550 unsigned BOpcode = BinOp.getOpcode();
25551 if (!TLI.isBinOp(Opcode: BOpcode) || BinOp->getNumValues() != 1)
25552 return SDValue();
25553
25554 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
25555 // reduced to the unary fneg when it is visited, and we probably want to deal
25556 // with fneg in a target-specific way.
25557 if (BOpcode == ISD::FSUB) {
25558 auto *C = isConstOrConstSplatFP(N: BinOp.getOperand(i: 0), /*AllowUndefs*/ true);
25559 if (C && C->getValueAPF().isNegZero())
25560 return SDValue();
25561 }
25562
25563 // The binop must be a vector type, so we can extract some fraction of it.
25564 EVT WideBVT = BinOp.getValueType();
25565 // The optimisations below currently assume we are dealing with fixed length
25566 // vectors. It is possible to add support for scalable vectors, but at the
25567 // moment we've done no analysis to prove whether they are profitable or not.
25568 if (!WideBVT.isFixedLengthVector())
25569 return SDValue();
25570
25571 assert((Index % VT.getVectorNumElements()) == 0 &&
25572 "Extract index is not a multiple of the vector length.");
25573
25574 // Bail out if this is not a proper multiple width extraction.
25575 unsigned WideWidth = WideBVT.getSizeInBits();
25576 unsigned NarrowWidth = VT.getSizeInBits();
25577 if (WideWidth % NarrowWidth != 0)
25578 return SDValue();
25579
25580 // Bail out if we are extracting a fraction of a single operation. This can
25581 // occur because we potentially looked through a bitcast of the binop.
25582 unsigned NarrowingRatio = WideWidth / NarrowWidth;
25583 unsigned WideNumElts = WideBVT.getVectorNumElements();
25584 if (WideNumElts % NarrowingRatio != 0)
25585 return SDValue();
25586
25587 // Bail out if the target does not support a narrower version of the binop.
25588 EVT NarrowBVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: WideBVT.getScalarType(),
25589 NumElements: WideNumElts / NarrowingRatio);
25590 if (!TLI.isOperationLegalOrCustomOrPromote(Op: BOpcode, VT: NarrowBVT,
25591 LegalOnly: LegalOperations))
25592 return SDValue();
25593
25594 // If extraction is cheap, we don't need to look at the binop operands
25595 // for concat ops. The narrow binop alone makes this transform profitable.
25596 // We can't just reuse the original extract index operand because we may have
25597 // bitcasted.
25598 unsigned ConcatOpNum = Index / VT.getVectorNumElements();
25599 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
25600 if (TLI.isExtractSubvectorCheap(ResVT: NarrowBVT, SrcVT: WideBVT, Index: ExtBOIdx) &&
25601 BinOp.hasOneUse() && Src->hasOneUse()) {
25602 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
25603 SDValue NewExtIndex = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
25604 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
25605 N1: BinOp.getOperand(i: 0), N2: NewExtIndex);
25606 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
25607 N1: BinOp.getOperand(i: 1), N2: NewExtIndex);
25608 SDValue NarrowBinOp =
25609 DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y, Flags: BinOp->getFlags());
25610 return DAG.getBitcast(VT, V: NarrowBinOp);
25611 }
25612
25613 // Only handle the case where we are doubling and then halving. A larger ratio
25614 // may require more than two narrow binops to replace the wide binop.
25615 if (NarrowingRatio != 2)
25616 return SDValue();
25617
25618 // TODO: The motivating case for this transform is an x86 AVX1 target. That
25619 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
25620 // flavors, but no other 256-bit integer support. This could be extended to
25621 // handle any binop, but that may require fixing/adding other folds to avoid
25622 // codegen regressions.
25623 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
25624 return SDValue();
25625
25626 // We need at least one concatenation operation of a binop operand to make
25627 // this transform worthwhile. The concat must double the input vector sizes.
25628 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
25629 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
25630 return V.getOperand(i: ConcatOpNum);
25631 return SDValue();
25632 };
25633 SDValue SubVecL = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 0)));
25634 SDValue SubVecR = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 1)));
25635
25636 if (SubVecL || SubVecR) {
25637 // If a binop operand was not the result of a concat, we must extract a
25638 // half-sized operand for our new narrow binop:
25639 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
25640 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
25641 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
25642 SDValue IndexC = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
25643 SDValue X = SubVecL ? DAG.getBitcast(VT: NarrowBVT, V: SubVecL)
25644 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
25645 N1: BinOp.getOperand(i: 0), N2: IndexC);
25646
25647 SDValue Y = SubVecR ? DAG.getBitcast(VT: NarrowBVT, V: SubVecR)
25648 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
25649 N1: BinOp.getOperand(i: 1), N2: IndexC);
25650
25651 SDValue NarrowBinOp = DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y);
25652 return DAG.getBitcast(VT, V: NarrowBinOp);
25653 }
25654
25655 return SDValue();
25656}
25657
25658/// If we are extracting a subvector from a wide vector load, convert to a
25659/// narrow load to eliminate the extraction:
25660/// (extract_subvector (load wide vector)) --> (load narrow vector)
25661static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
25662 const SDLoc &DL, SelectionDAG &DAG) {
25663 // TODO: Add support for big-endian. The offset calculation must be adjusted.
25664 if (DAG.getDataLayout().isBigEndian())
25665 return SDValue();
25666
25667 auto *Ld = dyn_cast<LoadSDNode>(Val&: Src);
25668 if (!Ld || !ISD::isNormalLoad(N: Ld) || !Ld->isSimple())
25669 return SDValue();
25670
25671 // We can only create byte sized loads.
25672 if (!VT.isByteSized())
25673 return SDValue();
25674
25675 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25676 if (!TLI.isOperationLegalOrCustomOrPromote(Op: ISD::LOAD, VT))
25677 return SDValue();
25678
25679 unsigned NumElts = VT.getVectorMinNumElements();
25680 // A fixed length vector being extracted from a scalable vector
25681 // may not be any *smaller* than the scalable one.
25682 if (Index == 0 && NumElts >= Ld->getValueType(ResNo: 0).getVectorMinNumElements())
25683 return SDValue();
25684
25685 // The definition of EXTRACT_SUBVECTOR states that the index must be a
25686 // multiple of the minimum number of elements in the result type.
25687 assert(Index % NumElts == 0 && "The extract subvector index is not a "
25688 "multiple of the result's element count");
25689
25690 // It's fine to use TypeSize here as we know the offset will not be negative.
25691 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
25692 std::optional<unsigned> ByteOffset;
25693 if (Offset.isFixed())
25694 ByteOffset = Offset.getFixedValue();
25695
25696 if (!TLI.shouldReduceLoadWidth(Load: Ld, ExtTy: Ld->getExtensionType(), NewVT: VT, ByteOffset))
25697 return SDValue();
25698
25699 // The narrow load will be offset from the base address of the old load if
25700 // we are extracting from something besides index 0 (little-endian).
25701 // TODO: Use "BaseIndexOffset" to make this more effective.
25702 SDValue NewAddr = DAG.getMemBasePlusOffset(Base: Ld->getBasePtr(), Offset, DL);
25703
25704 MachineFunction &MF = DAG.getMachineFunction();
25705 MachineMemOperand *MMO;
25706 if (Offset.isScalable()) {
25707 MachinePointerInfo MPI =
25708 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
25709 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), PtrInfo: MPI, Size: VT.getStoreSize());
25710 } else
25711 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), Offset: Offset.getFixedValue(),
25712 Size: VT.getStoreSize());
25713
25714 SDValue NewLd = DAG.getLoad(VT, dl: DL, Chain: Ld->getChain(), Ptr: NewAddr, MMO);
25715 DAG.makeEquivalentMemoryOrdering(OldLoad: Ld, NewMemOp: NewLd);
25716 return NewLd;
25717}
25718
25719/// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
25720/// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
25721/// EXTRACT_SUBVECTOR(Op?, ?),
25722/// Mask'))
25723/// iff it is legal and profitable to do so. Notably, the trimmed mask
25724/// (containing only the elements that are extracted)
25725/// must reference at most two subvectors.
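/// For example, extracting elements [0,1] of a wide v8i32 shuffle of X and Y
/// whose mask slice for those elements is <8,1> could produce:
///   (v2i32 vector_shuffle (extract_subvector Y, 0),
///                         (extract_subvector X, 0), <0,3>)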
25726static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
25727 unsigned Index,
25728 const SDLoc &DL,
25729 SelectionDAG &DAG,
25730 bool LegalOperations) {
25731 // Only deal with non-scalable vectors.
25732 EVT WideVT = Src.getValueType();
25733 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
25734 return SDValue();
25735
25736 // The operand must be a shufflevector.
25737 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Val&: Src);
25738 if (!WideShuffleVector)
25739 return SDValue();
25740
25741  // The old shuffle needs to go away.
25742 if (!WideShuffleVector->hasOneUse())
25743 return SDValue();
25744
25745 // And the narrow shufflevector that we'll form must be legal.
25746 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25747 if (LegalOperations &&
25748 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: NarrowVT))
25749 return SDValue();
25750
25751 int NumEltsExtracted = NarrowVT.getVectorNumElements();
25752 assert((Index % NumEltsExtracted) == 0 &&
25753 "Extract index is not a multiple of the output vector length.");
25754
25755 int WideNumElts = WideVT.getVectorNumElements();
25756
25757 SmallVector<int, 16> NewMask;
25758 NewMask.reserve(N: NumEltsExtracted);
25759 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
25760 DemandedSubvectors;
25761
25762 // Try to decode the wide mask into narrow mask from at most two subvectors.
25763 for (int M : WideShuffleVector->getMask().slice(N: Index, M: NumEltsExtracted)) {
25764 assert((M >= -1) && (M < (2 * WideNumElts)) &&
25765 "Out-of-bounds shuffle mask?");
25766
25767 if (M < 0) {
25768 // Does not depend on operands, does not require adjustment.
25769 NewMask.emplace_back(Args&: M);
25770 continue;
25771 }
25772
25773 // From which operand of the shuffle does this shuffle mask element pick?
25774 int WideShufOpIdx = M / WideNumElts;
25775 // Which element of that operand is picked?
25776 int OpEltIdx = M % WideNumElts;
25777
25778 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
25779 "Shuffle mask vector decomposition failure.");
25780
25781 // And which NumEltsExtracted-sized subvector of that operand is that?
25782 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
25783 // And which element within that subvector of that operand is that?
25784 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
25785
25786 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
25787 "Shuffle mask subvector decomposition failure.");
25788
25789 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
25790 WideShufOpIdx * WideNumElts) == M &&
25791 "Shuffle mask full decomposition failure.");
25792
25793 SDValue Op = WideShuffleVector->getOperand(Num: WideShufOpIdx);
25794
25795 if (Op.isUndef()) {
25796 // Picking from an undef operand. Let's adjust mask instead.
25797 NewMask.emplace_back(Args: -1);
25798 continue;
25799 }
25800
25801 const std::pair<SDValue, int> DemandedSubvector =
25802 std::make_pair(x&: Op, y&: OpSubvecIdx);
25803
25804 if (DemandedSubvectors.insert(X: DemandedSubvector)) {
25805 if (DemandedSubvectors.size() > 2)
25806 return SDValue(); // We can't handle more than two subvectors.
25807 // How many elements into the WideVT does this subvector start?
25808 int Index = NumEltsExtracted * OpSubvecIdx;
25809 // Bail out if the extraction isn't going to be cheap.
25810 if (!TLI.isExtractSubvectorCheap(ResVT: NarrowVT, SrcVT: WideVT, Index))
25811 return SDValue();
25812 }
25813
25814 // Ok, but from which operand of the new shuffle will this element pick?
25815 int NewOpIdx =
25816 getFirstIndexOf(Range: DemandedSubvectors.getArrayRef(), Val: DemandedSubvector);
25817 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
25818
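    // Rebase the mask element: it keeps its position within the subvector but
    // now indexes into operand NewOpIdx of the narrow shuffle.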
25819 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
25820 NewMask.emplace_back(Args&: AdjM);
25821 }
25822 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
25823 assert(DemandedSubvectors.size() <= 2 &&
25824 "Should have ended up demanding at most two subvectors.");
25825
25826 // Did we discover that the shuffle does not actually depend on operands?
25827 if (DemandedSubvectors.empty())
25828 return DAG.getUNDEF(VT: NarrowVT);
25829
25830 // Profitability check: only deal with extractions from the first subvector
25831  // unless the new mask becomes an identity mask with no undef elements.
25832 if (!ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts: NewMask.size()) ||
25833 any_of(Range&: NewMask, P: [](int M) { return M < 0; }))
25834 for (auto &DemandedSubvector : DemandedSubvectors)
25835 if (DemandedSubvector.second != 0)
25836 return SDValue();
25837
25838 // We still perform the exact same EXTRACT_SUBVECTOR, just on different
25839  // operand[s]/index[es], so there is no point in checking its legality.
25840
25841 // Do not turn a legal shuffle into an illegal one.
25842 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
25843 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
25844 return SDValue();
25845
25846 SmallVector<SDValue, 2> NewOps;
25847 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
25848 &DemandedSubvector : DemandedSubvectors) {
25849 // How many elements into the WideVT does this subvector start?
25850 int Index = NumEltsExtracted * DemandedSubvector.second;
25851 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index, DL);
25852 NewOps.emplace_back(Args: DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowVT,
25853 N1: DemandedSubvector.first, N2: IndexC));
25854 }
25855 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
25856 "Should end up with either one or two ops");
25857
25858 // If we ended up with only one operand, pad with an undef.
25859 if (NewOps.size() == 1)
25860 NewOps.emplace_back(Args: DAG.getUNDEF(VT: NarrowVT));
25861
25862 return DAG.getVectorShuffle(VT: NarrowVT, dl: DL, N1: NewOps[0], N2: NewOps[1], Mask: NewMask);
25863}
25864
25865SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
25866 EVT NVT = N->getValueType(ResNo: 0);
25867 SDValue V = N->getOperand(Num: 0);
25868 uint64_t ExtIdx = N->getConstantOperandVal(Num: 1);
25869 SDLoc DL(N);
25870
25871 // Extract from UNDEF is UNDEF.
25872 if (V.isUndef())
25873 return DAG.getUNDEF(VT: NVT);
25874
25875 if (SDValue NarrowLoad = narrowExtractedVectorLoad(VT: NVT, Src: V, Index: ExtIdx, DL, DAG))
25876 return NarrowLoad;
25877
25878 // Combine an extract of an extract into a single extract_subvector.
25879 // ext (ext X, C), 0 --> ext X, C
25880 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
25881 if (TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: V.getOperand(i: 0).getValueType(),
25882 Index: V.getConstantOperandVal(i: 1)) &&
25883 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NVT)) {
25884 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: V.getOperand(i: 0),
25885 N2: V.getOperand(i: 1));
25886 }
25887 }
25888
25889 // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
25890 if (V.getOpcode() == ISD::SPLAT_VECTOR)
25891 if (DAG.isConstantValueOfAnyType(N: V.getOperand(i: 0)) || V.hasOneUse())
25892 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT: NVT))
25893 return DAG.getSplatVector(VT: NVT, DL, Op: V.getOperand(i: 0));
25894
25895 // extract_subvector(insert_subvector(x,y,c1),c2)
25896 // --> extract_subvector(y,c2-c1)
25897 // iff we're just extracting from the inserted subvector.
25898 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
25899 SDValue InsSub = V.getOperand(i: 1);
25900 EVT InsSubVT = InsSub.getValueType();
25901 unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
25902 unsigned InsIdx = V.getConstantOperandVal(i: 2);
25903 unsigned NumSubElts = NVT.getVectorMinNumElements();
25904 if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
25905 TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: InsSubVT, Index: ExtIdx - InsIdx) &&
25906 InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
25907 V.getValueType().isFixedLengthVector())
25908 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: InsSub,
25909 N2: DAG.getVectorIdxConstant(Val: ExtIdx - InsIdx, DL));
25910 }
25911
25912 // Try to move vector bitcast after extract_subv by scaling extraction index:
25913 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
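  // For example: v2i64 extract_subv (v4i64 (bitcast (v8i32 X))), 2
  //                --> v2i64 bitcast (v4i32 extract_subv X, 4)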
25914 if (V.getOpcode() == ISD::BITCAST &&
25915 V.getOperand(i: 0).getValueType().isVector() &&
25916 (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))) {
25917 SDValue SrcOp = V.getOperand(i: 0);
25918 EVT SrcVT = SrcOp.getValueType();
25919 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
25920 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
25921 if ((SrcNumElts % DestNumElts) == 0) {
25922 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
25923 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
25924 EVT NewExtVT =
25925 EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcVT.getScalarType(), EC: NewExtEC);
25926 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
25927 SDValue NewIndex = DAG.getVectorIdxConstant(Val: ExtIdx * SrcDestRatio, DL);
25928 SDValue NewExtract = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
25929 N1: V.getOperand(i: 0), N2: NewIndex);
25930 return DAG.getBitcast(VT: NVT, V: NewExtract);
25931 }
25932 }
25933 if ((DestNumElts % SrcNumElts) == 0) {
25934 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
25935 if (NVT.getVectorElementCount().isKnownMultipleOf(RHS: DestSrcRatio)) {
25936 ElementCount NewExtEC =
25937 NVT.getVectorElementCount().divideCoefficientBy(RHS: DestSrcRatio);
25938 EVT ScalarVT = SrcVT.getScalarType();
25939 if ((ExtIdx % DestSrcRatio) == 0) {
25940 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
25941 EVT NewExtVT =
25942 EVT::getVectorVT(Context&: *DAG.getContext(), VT: ScalarVT, EC: NewExtEC);
25943 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
25944 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
25945 SDValue NewExtract =
25946 DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
25947 N1: V.getOperand(i: 0), N2: NewIndex);
25948 return DAG.getBitcast(VT: NVT, V: NewExtract);
25949 }
25950 if (NewExtEC.isScalar() &&
25951 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: ScalarVT)) {
25952 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
25953 SDValue NewExtract =
25954 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT,
25955 N1: V.getOperand(i: 0), N2: NewIndex);
25956 return DAG.getBitcast(VT: NVT, V: NewExtract);
25957 }
25958 }
25959 }
25960 }
25961 }
25962
25963 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
25964 unsigned ExtNumElts = NVT.getVectorMinNumElements();
25965 EVT ConcatSrcVT = V.getOperand(i: 0).getValueType();
25966 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
25967 "Concat and extract subvector do not change element type");
25968 assert((ExtIdx % ExtNumElts) == 0 &&
25969 "Extract index is not a multiple of the input vector length.");
25970
25971 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
25972 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
25973
25974 // If the concatenated source types match this extract, it's a direct
25975 // simplification:
25976 // extract_subvec (concat V1, V2, ...), i --> Vi
25977 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
25978 return V.getOperand(i: ConcatOpIdx);
25979
25980    // If each concatenated source vector's length is a multiple of this
25981    // extract's length, then extract a fraction of one of those source
25982    // vectors directly from a concat operand. Example:
25983    //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
25984    //   v2i8 extract_subvec (v8i8 Y), 6
25985 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
25986 ConcatSrcNumElts % ExtNumElts == 0) {
25987 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
25988 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
25989 "Trying to extract from >1 concat operand?");
25990 assert(NewExtIdx % ExtNumElts == 0 &&
25991 "Extract index is not a multiple of the input vector length.");
25992 SDValue NewIndexC = DAG.getVectorIdxConstant(Val: NewExtIdx, DL);
25993 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
25994 N1: V.getOperand(i: ConcatOpIdx), N2: NewIndexC);
25995 }
25996 }
25997
25998 if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
25999 NarrowVT: NVT, Src: V, Index: ExtIdx, DL, DAG, LegalOperations))
26000 return Shuffle;
26001
26002 if (SDValue NarrowBOp =
26003 narrowExtractedVectorBinOp(VT: NVT, Src: V, Index: ExtIdx, DL, DAG, LegalOperations))
26004 return NarrowBOp;
26005
26006 V = peekThroughBitcasts(V);
26007
26008  // If the input is a build vector, try to make a smaller build vector.
26009 if (V.getOpcode() == ISD::BUILD_VECTOR) {
26010 EVT InVT = V.getValueType();
26011 unsigned ExtractSize = NVT.getSizeInBits();
26012 unsigned EltSize = InVT.getScalarSizeInBits();
26013 // Only do this if we won't split any elements.
26014 if (ExtractSize % EltSize == 0) {
26015 unsigned NumElems = ExtractSize / EltSize;
26016 EVT EltVT = InVT.getVectorElementType();
26017 EVT ExtractVT =
26018 NumElems == 1 ? EltVT
26019 : EVT::getVectorVT(Context&: *DAG.getContext(), VT: EltVT, NumElements: NumElems);
26020 if ((Level < AfterLegalizeDAG ||
26021 (NumElems == 1 ||
26022 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: ExtractVT))) &&
26023 (!LegalTypes || TLI.isTypeLegal(VT: ExtractVT))) {
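        // Rescale the extract index from NVT's element size to the
        // build_vector's element size; a peeked-through bitcast may have
        // changed the element width.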
26024 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
26025
26026 if (NumElems == 1) {
26027 SDValue Src = V->getOperand(Num: IdxVal);
26028 if (EltVT != Src.getValueType())
26029 Src = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Src);
26030 return DAG.getBitcast(VT: NVT, V: Src);
26031 }
26032
26033 // Extract the pieces from the original build_vector.
26034 SDValue BuildVec =
26035 DAG.getBuildVector(VT: ExtractVT, DL, Ops: V->ops().slice(N: IdxVal, M: NumElems));
26036 return DAG.getBitcast(VT: NVT, V: BuildVec);
26037 }
26038 }
26039 }
26040
26041 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
26042 // Handle only simple case where vector being inserted and vector
26043 // being extracted are of same size.
26044 EVT SmallVT = V.getOperand(i: 1).getValueType();
26045 if (NVT.bitsEq(VT: SmallVT)) {
26046 // Combine:
26047 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
26048 // Into:
26049 // indices are equal or bit offsets are equal => V1
26050 // otherwise => (extract_subvec V1, ExtIdx)
26051 uint64_t InsIdx = V.getConstantOperandVal(i: 2);
26052 if (InsIdx * SmallVT.getScalarSizeInBits() ==
26053 ExtIdx * NVT.getScalarSizeInBits()) {
26054 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))
26055 return DAG.getBitcast(VT: NVT, V: V.getOperand(i: 1));
26056 } else {
26057 return DAG.getNode(
26058 Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
26059 N1: DAG.getBitcast(VT: N->getOperand(Num: 0).getValueType(), V: V.getOperand(i: 0)),
26060 N2: N->getOperand(Num: 1));
26061 }
26062 }
26063 }
26064
26065 // If only EXTRACT_SUBVECTOR nodes use the source vector we can
26066 // simplify it based on the (valid) extractions.
26067 if (!V.getValueType().isScalableVector() &&
26068 llvm::all_of(Range: V->users(), P: [&](SDNode *Use) {
26069 return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26070 Use->getOperand(Num: 0) == V;
26071 })) {
26072 unsigned NumElts = V.getValueType().getVectorNumElements();
26073 APInt DemandedElts = APInt::getZero(numBits: NumElts);
26074 for (SDNode *User : V->users()) {
26075 unsigned ExtIdx = User->getConstantOperandVal(Num: 1);
26076 unsigned NumSubElts = User->getValueType(ResNo: 0).getVectorNumElements();
26077 DemandedElts.setBits(loBit: ExtIdx, hiBit: ExtIdx + NumSubElts);
26078 }
26079 if (SimplifyDemandedVectorElts(Op: V, DemandedElts, /*AssumeSingleUse=*/true)) {
26080 // We simplified the vector operand of this extract subvector. If this
26081 // extract is not dead, visit it again so it is folded properly.
26082 if (N->getOpcode() != ISD::DELETED_NODE)
26083 AddToWorklist(N);
26084 return SDValue(N, 0);
26085 }
26086 } else {
26087 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
26088 return SDValue(N, 0);
26089 }
26090
26091 return SDValue();
26092}
26093
26094/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
26095/// followed by concatenation. Narrow vector ops may have better performance
26096/// than wide ops, and this can unlock further narrowing of other vector ops.
26097/// Targets can invert this transform later if it is not profitable.
26098static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
26099 SelectionDAG &DAG) {
26100 SDValue N0 = Shuf->getOperand(Num: 0), N1 = Shuf->getOperand(Num: 1);
26101 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
26102 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
26103 !N0.getOperand(i: 1).isUndef() || !N1.getOperand(i: 1).isUndef())
26104 return SDValue();
26105
26106 // Split the wide shuffle mask into halves. Any mask element that is accessing
26107 // operand 1 is offset down to account for narrowing of the vectors.
26108 ArrayRef<int> Mask = Shuf->getMask();
26109 EVT VT = Shuf->getValueType(ResNo: 0);
26110 unsigned NumElts = VT.getVectorNumElements();
26111 unsigned HalfNumElts = NumElts / 2;
26112 SmallVector<int, 16> Mask0(HalfNumElts, -1);
26113 SmallVector<int, 16> Mask1(HalfNumElts, -1);
26114 for (unsigned i = 0; i != NumElts; ++i) {
26115 if (Mask[i] == -1)
26116 continue;
26117 // If we reference the upper (undef) subvector then the element is undef.
26118 if ((Mask[i] % NumElts) >= HalfNumElts)
26119 continue;
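    // Elements taken from the second concat (Y) shift down by HalfNumElts
    // because Y becomes a half-width operand of the narrow shuffles.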
26120 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
26121 if (i < HalfNumElts)
26122 Mask0[i] = M;
26123 else
26124 Mask1[i - HalfNumElts] = M;
26125 }
26126
26127 // Ask the target if this is a valid transform.
26128 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26129 EVT HalfVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: VT.getScalarType(),
26130 NumElements: HalfNumElts);
26131 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
26132 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
26133 return SDValue();
26134
26135 // shuffle (concat X, undef), (concat Y, undef), Mask -->
26136 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
26137 SDValue X = N0.getOperand(i: 0), Y = N1.getOperand(i: 0);
26138 SDLoc DL(Shuf);
26139 SDValue Shuf0 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask0);
26140 SDValue Shuf1 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask1);
26141 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, N1: Shuf0, N2: Shuf1);
26142}
26143
26144// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
26145// or turn a shuffle of a single concat into a simpler shuffle then a concat.
26146static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
26147 EVT VT = N->getValueType(ResNo: 0);
26148 unsigned NumElts = VT.getVectorNumElements();
26149
26150 SDValue N0 = N->getOperand(Num: 0);
26151 SDValue N1 = N->getOperand(Num: 1);
26152 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
26153 ArrayRef<int> Mask = SVN->getMask();
26154
26155 SmallVector<SDValue, 4> Ops;
26156 EVT ConcatVT = N0.getOperand(i: 0).getValueType();
26157 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
26158 unsigned NumConcats = NumElts / NumElemsPerConcat;
26159
26160 auto IsUndefMaskElt = [](int i) { return i == -1; };
26161
26162 // Special case: shuffle(concat(A,B)) can be more efficiently represented
26163 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
26164 // half vector elements.
26165 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
26166 llvm::all_of(Range: Mask.slice(N: NumElemsPerConcat, M: NumElemsPerConcat),
26167 P: IsUndefMaskElt)) {
26168 N0 = DAG.getVectorShuffle(VT: ConcatVT, dl: SDLoc(N), N1: N0.getOperand(i: 0),
26169 N2: N0.getOperand(i: 1),
26170 Mask: Mask.slice(N: 0, M: NumElemsPerConcat));
26171 N1 = DAG.getUNDEF(VT: ConcatVT);
26172 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, N1: N0, N2: N1);
26173 }
26174
26175 // Look at every vector that's inserted. We're looking for exact
26176  // subvector-sized copies from a concatenated vector.
26177 for (unsigned I = 0; I != NumConcats; ++I) {
26178 unsigned Begin = I * NumElemsPerConcat;
26179 ArrayRef<int> SubMask = Mask.slice(N: Begin, M: NumElemsPerConcat);
26180
26181 // Make sure we're dealing with a copy.
26182 if (llvm::all_of(Range&: SubMask, P: IsUndefMaskElt)) {
26183 Ops.push_back(Elt: DAG.getUNDEF(VT: ConcatVT));
26184 continue;
26185 }
26186
26187 int OpIdx = -1;
26188 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
26189 if (IsUndefMaskElt(SubMask[i]))
26190 continue;
26191 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
26192 return SDValue();
26193 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
26194 if (0 <= OpIdx && EltOpIdx != OpIdx)
26195 return SDValue();
26196 OpIdx = EltOpIdx;
26197 }
26198 assert(0 <= OpIdx && "Unknown concat_vectors op");
26199
26200 if (OpIdx < (int)N0.getNumOperands())
26201 Ops.push_back(Elt: N0.getOperand(i: OpIdx));
26202 else
26203 Ops.push_back(Elt: N1.getOperand(i: OpIdx - N0.getNumOperands()));
26204 }
26205
26206 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
26207}
26208
26209// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
26210// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
26211//
26212// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
26213// a simplification in some sense, but it isn't appropriate in general: some
26214// BUILD_VECTORs are substantially cheaper than others. The general case
26215// of a BUILD_VECTOR requires inserting each element individually (or
26216// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
26217// all constants is a single constant pool load. A BUILD_VECTOR where each
26218// element is identical is a splat. A BUILD_VECTOR where most of the operands
26219// are undef lowers to a small number of element insertions.
26220//
26221// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
26222// We don't fold shuffles where one side is a non-zero constant, and we don't
26223// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
26224// non-constant operands. This seems to work out reasonably well in practice.
26225static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
26226 SelectionDAG &DAG,
26227 const TargetLowering &TLI) {
26228 EVT VT = SVN->getValueType(ResNo: 0);
26229 unsigned NumElts = VT.getVectorNumElements();
26230 SDValue N0 = SVN->getOperand(Num: 0);
26231 SDValue N1 = SVN->getOperand(Num: 1);
26232
26233 if (!N0->hasOneUse())
26234 return SDValue();
26235
26236  // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
26237 // discussed above.
26238 if (!N1.isUndef()) {
26239 if (!N1->hasOneUse())
26240 return SDValue();
26241
26242 bool N0AnyConst = isAnyConstantBuildVector(V: N0);
26243 bool N1AnyConst = isAnyConstantBuildVector(V: N1);
26244 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N: N0.getNode()))
26245 return SDValue();
26246 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N: N1.getNode()))
26247 return SDValue();
26248 }
26249
26250 // If both inputs are splats of the same value then we can safely merge this
26251 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
26252 bool IsSplat = false;
26253 auto *BV0 = dyn_cast<BuildVectorSDNode>(Val&: N0);
26254 auto *BV1 = dyn_cast<BuildVectorSDNode>(Val&: N1);
26255 if (BV0 && BV1)
26256 if (SDValue Splat0 = BV0->getSplatValue())
26257 IsSplat = (Splat0 == BV1->getSplatValue());
26258
26259 SmallVector<SDValue, 8> Ops;
26260 SmallSet<SDValue, 16> DuplicateOps;
26261 for (int M : SVN->getMask()) {
26262 SDValue Op = DAG.getUNDEF(VT: VT.getScalarType());
26263 if (M >= 0) {
26264 int Idx = M < (int)NumElts ? M : M - NumElts;
26265 SDValue &S = (M < (int)NumElts ? N0 : N1);
26266 if (S.getOpcode() == ISD::BUILD_VECTOR) {
26267 Op = S.getOperand(i: Idx);
26268 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
26269 SDValue Op0 = S.getOperand(i: 0);
26270 Op = Idx == 0 ? Op0 : DAG.getUNDEF(VT: Op0.getValueType());
26271 } else {
26272 // Operand can't be combined - bail out.
26273 return SDValue();
26274 }
26275 }
26276
26277 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
26278 // generating a splat; semantically, this is fine, but it's likely to
26279 // generate low-quality code if the target can't reconstruct an appropriate
26280 // shuffle.
26281 if (!Op.isUndef() && !isIntOrFPConstant(V: Op))
26282 if (!IsSplat && !DuplicateOps.insert(V: Op).second)
26283 return SDValue();
26284
26285 Ops.push_back(Elt: Op);
26286 }
26287
26288 // BUILD_VECTOR requires all inputs to be of the same type, find the
26289 // maximum type and extend them all.
26290 EVT SVT = VT.getScalarType();
26291 if (SVT.isInteger())
26292 for (SDValue &Op : Ops)
26293 SVT = (SVT.bitsLT(VT: Op.getValueType()) ? Op.getValueType() : SVT);
26294 if (SVT != VT.getScalarType())
26295 for (SDValue &Op : Ops)
26296 Op = Op.isUndef() ? DAG.getUNDEF(VT: SVT)
26297 : (TLI.isZExtFree(FromTy: Op.getValueType(), ToTy: SVT)
26298 ? DAG.getZExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT)
26299 : DAG.getSExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT));
26300 return DAG.getBuildVector(VT, DL: SDLoc(SVN), Ops);
26301}
26302
26303// Match shuffles that can be converted to *_vector_extend_in_reg.
26304// This is often generated during legalization.
26305// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
26306// and returns the EVT to which the extension should be performed.
26307// NOTE: this assumes that the src is the first operand of the shuffle.
26308static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
26309 unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
26310 SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
26311 bool LegalOperations) {
26312 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26313
26314 // TODO Add support for big-endian when we have a test case.
26315 if (!VT.isInteger() || IsBigEndian)
26316 return std::nullopt;
26317
26318 unsigned NumElts = VT.getVectorNumElements();
26319 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26320
26321  // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
26322  // power-of-2 extensions as they are the most likely.
26323  // FIXME: should try the Scale == NumElts case too.
26324 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
26325 // The vector width must be a multiple of Scale.
26326 if (NumElts % Scale != 0)
26327 continue;
26328
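    // e.g. for v8i16 with Scale == 2, the candidate extended type is v4i32.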
26329 EVT OutSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits * Scale);
26330 EVT OutVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: OutSVT, NumElements: NumElts / Scale);
26331
26332 if ((LegalTypes && !TLI.isTypeLegal(VT: OutVT)) ||
26333 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: Opcode, VT: OutVT)))
26334 continue;
26335
26336 if (Match(Scale))
26337 return OutVT;
26338 }
26339
26340 return std::nullopt;
26341}
26342
26343// Match shuffles that can be converted to any_vector_extend_in_reg.
26344// This is often generated during legalization.
26345// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
26346static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
26347 SelectionDAG &DAG,
26348 const TargetLowering &TLI,
26349 bool LegalOperations) {
26350 EVT VT = SVN->getValueType(ResNo: 0);
26351 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26352
26353 // TODO Add support for big-endian when we have a test case.
26354 if (!VT.isInteger() || IsBigEndian)
26355 return SDValue();
26356
26357 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
26358 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
26359 Mask = SVN->getMask()](unsigned Scale) {
26360 for (unsigned i = 0; i != NumElts; ++i) {
26361 if (Mask[i] < 0)
26362 continue;
26363 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
26364 continue;
26365 return false;
26366 }
26367 return true;
26368 };
26369
26370 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
26371 SDValue N0 = SVN->getOperand(Num: 0);
26372 // Never create an illegal type. Only create unsupported operations if we
26373 // are pre-legalization.
26374 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
26375 Opcode, VT, Match: isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
26376 if (!OutVT)
26377 return SDValue();
26378 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT, Operand: N0));
26379}
26380
26381// Match shuffles that can be converted to zero_extend_vector_inreg.
26382// This is often generated during legalization.
26383// e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
26384static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
26385 SelectionDAG &DAG,
26386 const TargetLowering &TLI,
26387 bool LegalOperations) {
26388 bool LegalTypes = true;
26389 EVT VT = SVN->getValueType(ResNo: 0);
26390 assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
26391 unsigned NumElts = VT.getVectorNumElements();
26392 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26393
26394 // TODO: add support for big-endian when we have a test case.
26395 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26396 if (!VT.isInteger() || IsBigEndian)
26397 return SDValue();
26398
26399 SmallVector<int, 16> Mask(SVN->getMask());
26400 auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
26401 for (int &Indice : Mask) {
26402 if (Indice < 0)
26403 continue;
26404 int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
26405 int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
26406 Fn(Indice, OpIdx, OpEltIdx);
26407 }
26408 };
26409
26410 // Which elements of which operand does this shuffle demand?
26411 std::array<APInt, 2> OpsDemandedElts;
26412 for (APInt &OpDemandedElts : OpsDemandedElts)
26413 OpDemandedElts = APInt::getZero(numBits: NumElts);
26414 ForEachDecomposedIndice(
26415 [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
26416 OpsDemandedElts[OpIdx].setBit(OpEltIdx);
26417 });
26418
26419  // Element-wise(!), which of these demanded elements are known to be zero?
26420 std::array<APInt, 2> OpsKnownZeroElts;
26421 for (auto I : zip(t: SVN->ops(), u&: OpsDemandedElts, args&: OpsKnownZeroElts))
26422 std::get<2>(t&: I) =
26423 DAG.computeVectorKnownZeroElements(Op: std::get<0>(t&: I), DemandedElts: std::get<1>(t&: I));
26424
26425 // Manifest zeroable element knowledge in the shuffle mask.
26426  // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG;
26427  // this is a local invention, but it won't leak into the DAG.
26428 // FIXME: should we not manifest them, but just check when matching?
26429 bool HadZeroableElts = false;
26430 ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
26431 int &Indice, int OpIdx, int OpEltIdx) {
26432 if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
26433 Indice = -2; // Zeroable element.
26434 HadZeroableElts = true;
26435 }
26436 });
26437
26438  // Don't proceed unless we've refined at least one zeroable mask index.
26439 // If we didn't, then we are still trying to match the same shuffle mask
26440 // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
26441 // and evidently failed. Proceeding will lead to endless combine loops.
26442 if (!HadZeroableElts)
26443 return SDValue();
26444
26445 // The shuffle may be more fine-grained than we want. Widen elements first.
26446 // FIXME: should we do this before manifesting zeroable shuffle mask indices?
26447 SmallVector<int, 16> ScaledMask;
26448 getShuffleMaskWithWidestElts(Mask, ScaledMask);
26449 assert(Mask.size() >= ScaledMask.size() &&
26450 Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
26451 int Prescale = Mask.size() / ScaledMask.size();
26452
26453 NumElts = ScaledMask.size();
26454 EltSizeInBits *= Prescale;
26455
26456 EVT PrescaledVT = EVT::getVectorVT(
26457 Context&: *DAG.getContext(), VT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits),
26458 NumElements: NumElts);
26459
26460 if (LegalTypes && !TLI.isTypeLegal(VT: PrescaledVT) && TLI.isTypeLegal(VT))
26461 return SDValue();
26462
26463 // For example,
26464 // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
26465 // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
26466 auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
26467 assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
26468 "Unexpected mask scaling factor.");
26469 ArrayRef<int> Mask = ScaledMask;
26470 for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
26471 SrcElt != NumSrcElts; ++SrcElt) {
26472 // Analyze the shuffle mask in Scale-sized chunks.
26473 ArrayRef<int> MaskChunk = Mask.take_front(N: Scale);
26474 assert(MaskChunk.size() == Scale && "Unexpected mask size.");
26475 Mask = Mask.drop_front(N: MaskChunk.size());
26476      // The first index in this chunk must be SrcElt, but not zero!
26477 // FIXME: undef should be fine, but that results in more-defined result.
26478 if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
26479 return false;
26480 // The rest of the indices in this chunk must be zeros.
26481 // FIXME: undef should be fine, but that results in more-defined result.
26482 if (!all_of(Range: MaskChunk.drop_front(N: 1),
26483 P: [](int Indice) { return Indice == -2; }))
26484 return false;
26485 }
26486 assert(Mask.empty() && "Did not process the whole mask?");
26487 return true;
26488 };
26489
26490 unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
26491 for (bool Commuted : {false, true}) {
26492 SDValue Op = SVN->getOperand(Num: !Commuted ? 0 : 1);
26493 if (Commuted)
26494 ShuffleVectorSDNode::commuteMask(Mask: ScaledMask);
26495 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
26496 Opcode, VT: PrescaledVT, Match: isZeroExtend, DAG, TLI, LegalTypes,
26497 LegalOperations);
26498 if (OutVT)
26499 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT,
26500 Operand: DAG.getBitcast(VT: PrescaledVT, V: Op)));
26501 }
26502 return SDValue();
26503}
26504
26505// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
26506// each source element of a large type into the lowest elements of a smaller
26507// destination type. This is often generated during legalization.
26508// If the source node itself was a '*_extend_vector_inreg' node then we should
26509// then be able to remove it.
26510static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
26511 SelectionDAG &DAG) {
26512 EVT VT = SVN->getValueType(ResNo: 0);
26513 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26514
26515 // TODO Add support for big-endian when we have a test case.
26516 if (!VT.isInteger() || IsBigEndian)
26517 return SDValue();
26518
26519 SDValue N0 = peekThroughBitcasts(V: SVN->getOperand(Num: 0));
26520
26521 unsigned Opcode = N0.getOpcode();
26522 if (!ISD::isExtVecInRegOpcode(Opcode))
26523 return SDValue();
26524
26525 SDValue N00 = N0.getOperand(i: 0);
26526 ArrayRef<int> Mask = SVN->getMask();
26527 unsigned NumElts = VT.getVectorNumElements();
26528 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26529 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
26530 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
26531
26532 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
26533 return SDValue();
26534 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
26535
26536 // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2-1,-1>
26537 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
26538 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
26539 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
26540 for (unsigned i = 0; i != NumElts; ++i) {
26541 if (Mask[i] < 0)
26542 continue;
26543 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
26544 continue;
26545 return false;
26546 }
26547 return true;
26548 };
26549
26550 // At the moment we just handle the case where we've truncated back to the
26551 // same size as before the extension.
26552 // TODO: handle more extension/truncation cases as cases arise.
26553 if (EltSizeInBits != ExtSrcSizeInBits)
26554 return SDValue();
26555
26556 // We can remove *extend_vector_inreg only if the truncation happens at
26557 // the same scale as the extension.
26558 if (isTruncate(ExtScale))
26559 return DAG.getBitcast(VT, V: N00);
26560
26561 return SDValue();
26562}
26563
26564// Combine shuffles of splat-shuffles of the form:
26565// shuffle (shuffle V, undef, splat-mask), undef, M
26566// If splat-mask contains undef elements, we need to be careful about
26567// introducing undefs in the folded mask which are not the result of composing
26568// the masks of the shuffles.
26569static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
26570 SelectionDAG &DAG) {
26571 EVT VT = Shuf->getValueType(ResNo: 0);
26572 unsigned NumElts = VT.getVectorNumElements();
26573
26574 if (!Shuf->getOperand(Num: 1).isUndef())
26575 return SDValue();
26576
26577 // See if this unary non-splat shuffle actually *is* a splat shuffle,
26578 // in disguise, with all demanded elements being identical.
26579 // FIXME: this can be done per-operand.
26580 if (!Shuf->isSplat()) {
26581 APInt DemandedElts(NumElts, 0);
26582 for (int Idx : Shuf->getMask()) {
26583 if (Idx < 0)
26584 continue; // Ignore sentinel indices.
26585 assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
26586 DemandedElts.setBit(Idx);
26587 }
26588 assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
26589 APInt UndefElts;
26590 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), DemandedElts, UndefElts)) {
26591 // Even if all demanded elements are splat, some of them could be undef.
26592 // Which lowest demanded element is *not* known-undef?
26593 std::optional<unsigned> MinNonUndefIdx;
26594 for (int Idx : Shuf->getMask()) {
26595 if (Idx < 0 || UndefElts[Idx])
26596 continue; // Ignore sentinel indices, and undef elements.
26597 MinNonUndefIdx = std::min<unsigned>(a: Idx, b: MinNonUndefIdx.value_or(u: ~0U));
26598 }
26599 if (!MinNonUndefIdx)
26600 return DAG.getUNDEF(VT); // All undef - result is undef.
26601 assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
26602 SmallVector<int, 8> SplatMask(Shuf->getMask());
26603 for (int &Idx : SplatMask) {
26604 if (Idx < 0)
26605 continue; // Passthrough sentinel indices.
26606 // Otherwise, just pick the lowest demanded non-undef element.
26607 // Or sentinel undef, if we know we'd pick a known-undef element.
26608 Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
26609 }
26610 assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
26611 return DAG.getVectorShuffle(VT, dl: SDLoc(Shuf), N1: Shuf->getOperand(Num: 0),
26612 N2: Shuf->getOperand(Num: 1), Mask: SplatMask);
26613 }
26614 }
26615
26616  // If the inner operand is a known splat with no undefs, return it directly.
26617 // TODO: Create DemandedElts mask from Shuf's mask.
26618 // TODO: Allow undef elements and merge with the shuffle code below.
26619 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), /*AllowUndefs*/ false))
26620 return Shuf->getOperand(Num: 0);
26621
26622 auto *Splat = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
26623 if (!Splat || !Splat->isSplat())
26624 return SDValue();
26625
26626 ArrayRef<int> ShufMask = Shuf->getMask();
26627 ArrayRef<int> SplatMask = Splat->getMask();
26628 assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
26629
26630 // Prefer simplifying to the splat-shuffle, if possible. This is legal if
26631 // every undef mask element in the splat-shuffle has a corresponding undef
26632 // element in the user-shuffle's mask or if the composition of mask elements
26633 // would result in undef.
26634 // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
26635 // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
26636 // In this case it is not legal to simplify to the splat-shuffle because we
26637  // may be exposing to the users of the shuffle an undef element at index 1
26638 // which was not there before the combine.
26639 // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
26640 // In this case the composition of masks yields SplatMask, so it's ok to
26641 // simplify to the splat-shuffle.
26642 // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
26643 // In this case the composed mask includes all undef elements of SplatMask
26644 // and in addition sets element zero to undef. It is safe to simplify to
26645 // the splat-shuffle.
26646 auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
26647 ArrayRef<int> SplatMask) {
26648 for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
26649 if (UserMask[i] != -1 && SplatMask[i] == -1 &&
26650 SplatMask[UserMask[i]] != -1)
26651 return false;
26652 return true;
26653 };
26654 if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
26655 return Shuf->getOperand(Num: 0);
26656
26657 // Create a new shuffle with a mask that is composed of the two shuffles'
26658 // masks.
26659 SmallVector<int, 32> NewMask;
26660 for (int Idx : ShufMask)
26661 NewMask.push_back(Elt: Idx == -1 ? -1 : SplatMask[Idx]);
26662
26663 return DAG.getVectorShuffle(VT: Splat->getValueType(ResNo: 0), dl: SDLoc(Splat),
26664 N1: Splat->getOperand(Num: 0), N2: Splat->getOperand(Num: 1),
26665 Mask: NewMask);
26666}
26667
26668// Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
26669// the mask can be treated as a larger type.
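// For example (types chosen purely for illustration):
//   v8i16 shuffle (bitcast (v4i32 X)), (bitcast (v4i32 Y)),
//                 <0,1,8,9,2,3,10,11>
//     --> bitcast (v4i32 shuffle X, Y, <0,4,1,5>)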
26670static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
26671 SelectionDAG &DAG,
26672 const TargetLowering &TLI,
26673 bool LegalOperations) {
26674 SDValue Op0 = SVN->getOperand(Num: 0);
26675 SDValue Op1 = SVN->getOperand(Num: 1);
26676 EVT VT = SVN->getValueType(ResNo: 0);
26677 if (Op0.getOpcode() != ISD::BITCAST)
26678 return SDValue();
26679 EVT InVT = Op0.getOperand(i: 0).getValueType();
26680 if (!InVT.isVector() ||
26681 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
26682 Op1.getOperand(i: 0).getValueType() != InVT)))
26683 return SDValue();
26684 if (isAnyConstantBuildVector(V: Op0.getOperand(i: 0)) &&
26685 (Op1.isUndef() || isAnyConstantBuildVector(V: Op1.getOperand(i: 0))))
26686 return SDValue();
26687
26688 int VTLanes = VT.getVectorNumElements();
26689 int InLanes = InVT.getVectorNumElements();
26690 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
26691 (LegalOperations &&
26692 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: InVT)))
26693 return SDValue();
26694 int Factor = VTLanes / InLanes;
26695
26696  // Check that each group of lanes in the mask is either undef or makes a
26697  // valid mask for the wider lane type.
26698 ArrayRef<int> Mask = SVN->getMask();
26699 SmallVector<int> NewMask;
26700 if (!widenShuffleMaskElts(Scale: Factor, Mask, ScaledMask&: NewMask))
26701 return SDValue();
26702
26703 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
26704 return SDValue();
26705
26706 // Create the new shuffle with the new mask and bitcast it back to the
26707 // original type.
26708 SDLoc DL(SVN);
26709 Op0 = Op0.getOperand(i: 0);
26710 Op1 = Op1.isUndef() ? DAG.getUNDEF(VT: InVT) : Op1.getOperand(i: 0);
26711 SDValue NewShuf = DAG.getVectorShuffle(VT: InVT, dl: DL, N1: Op0, N2: Op1, Mask: NewMask);
26712 return DAG.getBitcast(VT, V: NewShuf);
26713}
26714
26715/// Combine shuffle of shuffle of the form:
26716/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
26717static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
26718 SelectionDAG &DAG) {
26719 if (!OuterShuf->getOperand(Num: 1).isUndef())
26720 return SDValue();
26721 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(Val: OuterShuf->getOperand(Num: 0));
26722 if (!InnerShuf || !InnerShuf->getOperand(Num: 1).isUndef())
26723 return SDValue();
26724
26725 ArrayRef<int> OuterMask = OuterShuf->getMask();
26726 ArrayRef<int> InnerMask = InnerShuf->getMask();
26727 unsigned NumElts = OuterMask.size();
26728 assert(NumElts == InnerMask.size() && "Mask length mismatch");
26729 SmallVector<int, 32> CombinedMask(NumElts, -1);
26730 int SplatIndex = -1;
26731 for (unsigned i = 0; i != NumElts; ++i) {
26732 // Undef lanes remain undef.
26733 int OuterMaskElt = OuterMask[i];
26734 if (OuterMaskElt == -1)
26735 continue;
26736
26737 // Peek through the shuffle masks to get the underlying source element.
26738 int InnerMaskElt = InnerMask[OuterMaskElt];
26739 if (InnerMaskElt == -1)
26740 continue;
26741
26742 // Initialize the splatted element.
26743 if (SplatIndex == -1)
26744 SplatIndex = InnerMaskElt;
26745
26746 // Non-matching index - this is not a splat.
26747 if (SplatIndex != InnerMaskElt)
26748 return SDValue();
26749
26750 CombinedMask[i] = InnerMaskElt;
26751 }
26752 assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
26753 getSplatIndex(CombinedMask) != -1) &&
26754 "Expected a splat mask");
26755
26756 // TODO: The transform may be a win even if the mask is not legal.
26757 EVT VT = OuterShuf->getValueType(ResNo: 0);
26758 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
26759 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
26760 return SDValue();
26761
26762 return DAG.getVectorShuffle(VT, dl: SDLoc(OuterShuf), N1: InnerShuf->getOperand(Num: 0),
26763 N2: InnerShuf->getOperand(Num: 1), Mask: CombinedMask);
26764}
26765
26766/// If the shuffle mask is taking exactly one element from the first vector
26767/// operand and passing through all other elements from the second vector
26768/// operand, return the index of the mask element that is choosing an element
26769/// from the first operand. Otherwise, return -1.
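/// For example, with 4-element operands, the mask <4,5,0,7> returns 2: result
/// element 2 takes element 0 from operand 0, and all other elements pass
/// through from operand 1.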
26770static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
26771 int MaskSize = Mask.size();
26772 int EltFromOp0 = -1;
26773 // TODO: This does not match if there are undef elements in the shuffle mask.
26774 // Should we ignore undefs in the shuffle mask instead? The trade-off is
26775 // removing an instruction (a shuffle), but losing the knowledge that some
26776 // vector lanes are not needed.
26777 for (int i = 0; i != MaskSize; ++i) {
26778 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
26779 // We're looking for a shuffle of exactly one element from operand 0.
26780 if (EltFromOp0 != -1)
26781 return -1;
26782 EltFromOp0 = i;
26783 } else if (Mask[i] != i + MaskSize) {
26784 // Nothing from operand 1 can change lanes.
26785 return -1;
26786 }
26787 }
26788 return EltFromOp0;
26789}
26790
26791/// If a shuffle inserts exactly one element from a source vector operand into
26792/// another vector operand and we can access the specified element as a scalar,
26793/// then we can eliminate the shuffle.
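/// For example:
///   shuffle (insert_elt undef, X, 0), V, <4,5,0,7> --> insert_elt V, X, 2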
26794SDValue DAGCombiner::replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf) {
26795 // First, check if we are taking one element of a vector and shuffling that
26796 // element into another vector.
26797 ArrayRef<int> Mask = Shuf->getMask();
26798 SmallVector<int, 16> CommutedMask(Mask);
26799 SDValue Op0 = Shuf->getOperand(Num: 0);
26800 SDValue Op1 = Shuf->getOperand(Num: 1);
26801 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
26802 if (ShufOp0Index == -1) {
26803 // Commute mask and check again.
26804 ShuffleVectorSDNode::commuteMask(Mask: CommutedMask);
26805 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask: CommutedMask);
26806 if (ShufOp0Index == -1)
26807 return SDValue();
26808 // Commute operands to match the commuted shuffle mask.
26809 std::swap(a&: Op0, b&: Op1);
26810 Mask = CommutedMask;
26811 }
26812
26813 // The shuffle inserts exactly one element from operand 0 into operand 1.
26814 // Now see if we can access that element as a scalar via a real insert element
26815 // instruction.
26816 // TODO: We can try harder to locate the element as a scalar. Examples: it
26817 // could be an operand of BUILD_VECTOR, or a constant.
26818 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
26819 "Shuffle mask value must be from operand 0");
26820
26821 SDValue Elt;
26822 if (sd_match(N: Op0, P: m_InsertElt(Vec: m_Value(), Val: m_Value(N&: Elt),
26823 Idx: m_SpecificInt(V: Mask[ShufOp0Index])))) {
26824 // There's an existing insertelement with constant insertion index, so we
26825 // don't need to check the legality/profitability of a replacement operation
26826 // that differs at most in the constant value. The target should be able to
26827 // lower any of those in a similar way. If not, legalization will expand
26828 // this to a scalar-to-vector plus shuffle.
26829 //
26830 // Note that the shuffle may move the scalar from the position that the
26831 // insert element used. Therefore, our new insert element occurs at the
26832 // shuffle's mask index value, not the insert's index value.
26833 //
26834 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
26835 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
26836 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
26837 N1: Op1, N2: Elt, N3: NewInsIndex);
26838 }
26839
26840 if (!hasOperation(Opcode: ISD::INSERT_VECTOR_ELT, VT: Op0.getValueType()))
26841 return SDValue();
26842
26843 if (sd_match(N: Op0, P: m_UnaryOp(Opc: ISD::SCALAR_TO_VECTOR, Op: m_Value(N&: Elt))) &&
26844 Mask[ShufOp0Index] == 0) {
26845 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
26846 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
26847 N1: Op1, N2: Elt, N3: NewInsIndex);
26848 }
26849
26850 return SDValue();
26851}
26852
26853/// If we have a unary shuffle of a shuffle, see if it can be folded away
26854/// completely. This has the potential to lose undef knowledge because the first
26855/// shuffle may not have an undef mask element where the second one does. So
26856/// only call this after doing simplifications based on demanded elements.
26857static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
26858 // shuf (shuf0 X, Y, Mask0), undef, Mask
26859 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
26860 if (!Shuf0 || !Shuf->getOperand(Num: 1).isUndef())
26861 return SDValue();
26862
26863 ArrayRef<int> Mask = Shuf->getMask();
26864 ArrayRef<int> Mask0 = Shuf0->getMask();
26865 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
26866 // Ignore undef elements.
26867 if (Mask[i] == -1)
26868 continue;
26869 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
26870
26871 // Is the element of the shuffle operand chosen by this shuffle the same as
26872 // the element chosen by the shuffle operand itself?
26873 if (Mask0[Mask[i]] != Mask0[i])
26874 return SDValue();
26875 }
26876 // Every element of this shuffle is identical to the result of the previous
26877 // shuffle, so we can replace this value.
26878 return Shuf->getOperand(Num: 0);
26879}
26880
26881SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
26882 EVT VT = N->getValueType(ResNo: 0);
26883 unsigned NumElts = VT.getVectorNumElements();
26884
26885 SDValue N0 = N->getOperand(Num: 0);
26886 SDValue N1 = N->getOperand(Num: 1);
26887
26888 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
26889
26890 // Canonicalize shuffle undef, undef -> undef
26891 if (N0.isUndef() && N1.isUndef())
26892 return DAG.getUNDEF(VT);
26893
26894 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
26895
26896 // Canonicalize shuffle v, v -> v, undef
26897 if (N0 == N1)
26898 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: DAG.getUNDEF(VT),
26899 Mask: createUnaryMask(Mask: SVN->getMask(), NumElts));
26900
26901 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
26902 if (N0.isUndef())
26903 return DAG.getCommutedVectorShuffle(SV: *SVN);
26904
26905 // Remove references to rhs if it is undef
26906 if (N1.isUndef()) {
26907 bool Changed = false;
26908 SmallVector<int, 8> NewMask;
26909 for (unsigned i = 0; i != NumElts; ++i) {
26910 int Idx = SVN->getMaskElt(Idx: i);
26911 if (Idx >= (int)NumElts) {
26912 Idx = -1;
26913 Changed = true;
26914 }
26915 NewMask.push_back(Elt: Idx);
26916 }
26917 if (Changed)
26918 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: N1, Mask: NewMask);
26919 }
26920
26921 if (SDValue InsElt = replaceShuffleOfInsert(Shuf: SVN))
26922 return InsElt;
26923
26924 // A shuffle of a single vector that is a splatted value can always be folded.
26925 if (SDValue V = combineShuffleOfSplatVal(Shuf: SVN, DAG))
26926 return V;
26927
26928 if (SDValue V = formSplatFromShuffles(OuterShuf: SVN, DAG))
26929 return V;
26930
26931 // If it is a splat, check if the argument vector is another splat or a
26932 // build_vector.
26933 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
26934 int SplatIndex = SVN->getSplatIndex();
26935 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, Index: SplatIndex) &&
26936 TLI.isBinOp(Opcode: N0.getOpcode()) && N0->getNumValues() == 1) {
26937 // splat (vector_bo L, R), Index -->
26938 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
26939 SDValue L = N0.getOperand(i: 0), R = N0.getOperand(i: 1);
26940 SDLoc DL(N);
26941 EVT EltVT = VT.getScalarType();
26942 SDValue Index = DAG.getVectorIdxConstant(Val: SplatIndex, DL);
26943 SDValue ExtL = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: L, N2: Index);
26944 SDValue ExtR = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: R, N2: Index);
26945 SDValue NewBO =
26946 DAG.getNode(Opcode: N0.getOpcode(), DL, VT: EltVT, N1: ExtL, N2: ExtR, Flags: N0->getFlags());
26947 SDValue Insert = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL, VT, Operand: NewBO);
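      // Broadcast the scalar result from lane 0 into every lane with an
      // all-zero shuffle mask.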
26948 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
26949 return DAG.getVectorShuffle(VT, dl: DL, N1: Insert, N2: DAG.getUNDEF(VT), Mask: ZeroMask);
26950 }
26951
26952 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
26953 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
26954 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) &&
26955 N0.hasOneUse()) {
26956 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
26957 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 0));
26958
26959 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
26960 if (auto *Idx = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 2)))
26961 if (Idx->getAPIntValue() == SplatIndex)
26962 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 1));
26963
26964 // Look through a bitcast if LE and splatting lane 0, through to a
26965 // scalar_to_vector or a build_vector.
26966 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(i: 0).hasOneUse() &&
26967 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
26968 (N0.getOperand(i: 0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
26969 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR)) {
26970 EVT N00VT = N0.getOperand(i: 0).getValueType();
26971 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
26972 VT.isInteger() && N00VT.isInteger()) {
26973 EVT InVT =
26974 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: VT.getScalarType());
26975 SDValue Op = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0),
26976 DL: SDLoc(N), VT: InVT);
26977 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op);
26978 }
26979 }
26980 }
26981
26982 // If this is a bit convert that changes the element type of the vector but
26983 // not the number of vector elements, look through it. Be careful not to
26984    // look through conversions that change things like v4f32 to v2f64.
26985 SDNode *V = N0.getNode();
26986 if (V->getOpcode() == ISD::BITCAST) {
26987 SDValue ConvInput = V->getOperand(Num: 0);
26988 if (ConvInput.getValueType().isVector() &&
26989 ConvInput.getValueType().getVectorNumElements() == NumElts)
26990 V = ConvInput.getNode();
26991 }
26992
26993 if (V->getOpcode() == ISD::BUILD_VECTOR) {
26994 assert(V->getNumOperands() == NumElts &&
26995 "BUILD_VECTOR has wrong number of operands");
26996 SDValue Base;
26997 bool AllSame = true;
26998 for (unsigned i = 0; i != NumElts; ++i) {
26999 if (!V->getOperand(Num: i).isUndef()) {
27000 Base = V->getOperand(Num: i);
27001 break;
27002 }
27003 }
27004 // Splat of <u, u, u, u>, return <u, u, u, u>
27005 if (!Base.getNode())
27006 return N0;
27007 for (unsigned i = 0; i != NumElts; ++i) {
27008 if (V->getOperand(Num: i) != Base) {
27009 AllSame = false;
27010 break;
27011 }
27012 }
27013 // Splat of <x, x, x, x>, return <x, x, x, x>
27014 if (AllSame)
27015 return N0;
27016
27017 // Canonicalize any other splat as a build_vector, but avoid defining any
27018 // undefined elements in the mask.
27019 SDValue Splatted = V->getOperand(Num: SplatIndex);
27020 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
27021 EVT EltVT = Splatted.getValueType();
27022
27023 for (unsigned i = 0; i != NumElts; ++i) {
27024 if (SVN->getMaskElt(Idx: i) < 0)
27025 Ops[i] = DAG.getUNDEF(VT: EltVT);
27026 }
27027
27028 SDValue NewBV = DAG.getBuildVector(VT: V->getValueType(ResNo: 0), DL: SDLoc(N), Ops);
27029
27030 // We may have jumped through bitcasts, so the type of the
27031 // BUILD_VECTOR may not match the type of the shuffle.
27032 if (V->getValueType(ResNo: 0) != VT)
27033 NewBV = DAG.getBitcast(VT, V: NewBV);
27034 return NewBV;
27035 }
27036 }
27037
27038 // Simplify source operands based on shuffle mask.
27039 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
27040 return SDValue(N, 0);
27041
27042 // This is intentionally placed after demanded elements simplification because
27043 // it could eliminate knowledge of undef elements created by this shuffle.
27044 if (SDValue ShufOp = simplifyShuffleOfShuffle(Shuf: SVN))
27045 return ShufOp;
27046
27047 // Match shuffles that can be converted to any_vector_extend_in_reg.
27048 if (SDValue V =
27049 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
27050 return V;
27051
27052 // Combine "truncate_vector_in_reg" style shuffles.
27053 if (SDValue V = combineTruncationShuffle(SVN, DAG))
27054 return V;
27055
27056 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
27057 Level < AfterLegalizeVectorOps &&
27058 (N1.isUndef() ||
27059 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
27060 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType()))) {
27061 if (SDValue V = partitionShuffleOfConcats(N, DAG))
27062 return V;
27063 }
27064
27065 // A shuffle of a concat of the same narrow vector can be reduced to use
27066 // only low-half elements of a concat with undef:
27067 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
27068 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
27069 N0.getNumOperands() == 2 &&
27070 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
27071 int HalfNumElts = (int)NumElts / 2;
27072 SmallVector<int, 8> NewMask;
27073 for (unsigned i = 0; i != NumElts; ++i) {
27074 int Idx = SVN->getMaskElt(Idx: i);
27075 if (Idx >= HalfNumElts) {
27076 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
27077 Idx -= HalfNumElts;
27078 }
27079 NewMask.push_back(Elt: Idx);
27080 }
27081 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
27082 SDValue UndefVec = DAG.getUNDEF(VT: N0.getOperand(i: 0).getValueType());
27083 SDValue NewCat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT,
27084 N1: N0.getOperand(i: 0), N2: UndefVec);
27085 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: NewCat, N2: N1, Mask: NewMask);
27086 }
27087 }
27088
27089 // See if we can replace a shuffle with an insert_subvector.
27090 // e.g. v2i32 into v8i32:
27091 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
27092 // --> insert_subvector(lhs,rhs1,4).
27093 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
27094 TLI.isOperationLegalOrCustom(Op: ISD::INSERT_SUBVECTOR, VT)) {
27095 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
27096 // Ensure RHS subvectors are legal.
27097 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
27098 EVT SubVT = RHS.getOperand(i: 0).getValueType();
27099 int NumSubVecs = RHS.getNumOperands();
27100 int NumSubElts = SubVT.getVectorNumElements();
27101 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
27102 if (!TLI.isTypeLegal(VT: SubVT))
27103 return SDValue();
27104
27105      // Don't bother if we have a unary shuffle (matches undef + LHS elts).
27106 if (all_of(Range&: Mask, P: [NumElts](int M) { return M < (int)NumElts; }))
27107 return SDValue();
27108
27109 // Search [NumSubElts] spans for RHS sequence.
27110 // TODO: Can we avoid nested loops to increase performance?
27111 SmallVector<int> InsertionMask(NumElts);
27112 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
27113 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
27114 // Reset mask to identity.
27115 std::iota(first: InsertionMask.begin(), last: InsertionMask.end(), value: 0);
27116
27117 // Add subvector insertion.
27118 std::iota(first: InsertionMask.begin() + SubIdx,
27119 last: InsertionMask.begin() + SubIdx + NumSubElts,
27120 value: NumElts + (SubVec * NumSubElts));
27121
27122 // See if the shuffle mask matches the reference insertion mask.
27123 bool MatchingShuffle = true;
27124 for (int i = 0; i != (int)NumElts; ++i) {
27125 int ExpectIdx = InsertionMask[i];
27126 int ActualIdx = Mask[i];
27127 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
27128 MatchingShuffle = false;
27129 break;
27130 }
27131 }
27132
27133 if (MatchingShuffle)
27134 return DAG.getInsertSubvector(DL: SDLoc(N), Vec: LHS, SubVec: RHS.getOperand(i: SubVec),
27135 Idx: SubIdx);
27136 }
27137 }
27138 return SDValue();
27139 };
27140 ArrayRef<int> Mask = SVN->getMask();
27141 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
27142 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
27143 return InsertN1;
27144 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
27145 SmallVector<int> CommuteMask(Mask);
27146 ShuffleVectorSDNode::commuteMask(Mask: CommuteMask);
27147 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
27148 return InsertN0;
27149 }
27150 }
27151
27152  // If we're not performing a select/blend shuffle, see if we can convert the
27153  // shuffle into an AND node, where all the out-of-lane elements are known zero.
27154 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
27155 bool IsInLaneMask = true;
27156 ArrayRef<int> Mask = SVN->getMask();
27157 SmallVector<int, 16> ClearMask(NumElts, -1);
27158 APInt DemandedLHS = APInt::getZero(numBits: NumElts);
27159 APInt DemandedRHS = APInt::getZero(numBits: NumElts);
27160 for (int I = 0; I != (int)NumElts; ++I) {
27161 int M = Mask[I];
27162 if (M < 0)
27163 continue;
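      // Keep lane I of N0 if the shuffle leaves it in place; otherwise select
      // lane I of the second operand, which will be an all-zero vector below.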
27164 ClearMask[I] = M == I ? I : (I + NumElts);
27165 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
27166 if (M != I) {
27167 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
27168 Demanded.setBit(M % NumElts);
27169 }
27170 }
27171 // TODO: Should we try to mask with N1 as well?
27172 if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
27173 (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(Op: N0, DemandedElts: DemandedLHS)) &&
27174 (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(Op: N1, DemandedElts: DemandedRHS))) {
27175 SDLoc DL(N);
27176 EVT IntVT = VT.changeVectorElementTypeToInteger();
27177 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
27178 // Transform the type to a legal type so that the buildvector constant
27179      // elements are not illegal. Make sure that the result is not smaller than
27180      // the original type, in case the value is split into two (e.g. i64->i32).
27181 if (!TLI.isTypeLegal(VT: IntSVT) && LegalTypes)
27182 IntSVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: IntSVT);
27183 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
27184 SDValue ZeroElt = DAG.getConstant(Val: 0, DL, VT: IntSVT);
27185 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, VT: IntSVT);
27186 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(VT: IntSVT));
27187 for (int I = 0; I != (int)NumElts; ++I)
27188 if (0 <= Mask[I])
27189 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
27190
27191 // See if a clear mask is legal instead of going via
27192 // XformToShuffleWithZero which loses UNDEF mask elements.
27193 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
27194 return DAG.getBitcast(
27195 VT, V: DAG.getVectorShuffle(VT: IntVT, dl: DL, N1: DAG.getBitcast(VT: IntVT, V: N0),
27196 N2: DAG.getConstant(Val: 0, DL, VT: IntVT), Mask: ClearMask));
27197
27198 if (TLI.isOperationLegalOrCustom(Op: ISD::AND, VT: IntVT))
27199 return DAG.getBitcast(
27200 VT, V: DAG.getNode(Opcode: ISD::AND, DL, VT: IntVT, N1: DAG.getBitcast(VT: IntVT, V: N0),
27201 N2: DAG.getBuildVector(VT: IntVT, DL, Ops: AndMask)));
27202 }
27203 }
27204 }
27205
27206 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
27207 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
27208 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
27209 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
27210 return Res;
27211
27212 // If this shuffle only has a single input that is a bitcasted shuffle,
27213 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
27214 // back to their original types.
27215 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
27216 N1.isUndef() && Level < AfterLegalizeVectorOps &&
27217 TLI.isTypeLegal(VT)) {
27218
27219 SDValue BC0 = peekThroughOneUseBitcasts(V: N0);
27220 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
27221 EVT SVT = VT.getScalarType();
27222 EVT InnerVT = BC0->getValueType(ResNo: 0);
27223 EVT InnerSVT = InnerVT.getScalarType();
27224
27225 // Determine which shuffle works with the smaller scalar type.
27226 EVT ScaleVT = SVT.bitsLT(VT: InnerSVT) ? VT : InnerVT;
27227 EVT ScaleSVT = ScaleVT.getScalarType();
27228
27229 if (TLI.isTypeLegal(VT: ScaleVT) &&
27230 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
27231 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
27232 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
27233 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
27234
27235 // Scale the shuffle masks to the smaller scalar type.
27236 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(Val&: BC0);
27237 SmallVector<int, 8> InnerMask;
27238 SmallVector<int, 8> OuterMask;
27239 narrowShuffleMaskElts(Scale: InnerScale, Mask: InnerSVN->getMask(), ScaledMask&: InnerMask);
27240 narrowShuffleMaskElts(Scale: OuterScale, Mask: SVN->getMask(), ScaledMask&: OuterMask);
27241
27242 // Merge the shuffle masks.
27243 SmallVector<int, 8> NewMask;
27244 for (int M : OuterMask)
27245 NewMask.push_back(Elt: M < 0 ? -1 : InnerMask[M]);
27246
27247 // Test for shuffle mask legality over both commutations.
27248 SDValue SV0 = BC0->getOperand(Num: 0);
27249 SDValue SV1 = BC0->getOperand(Num: 1);
27250 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
27251 if (!LegalMask) {
27252 std::swap(a&: SV0, b&: SV1);
27253 ShuffleVectorSDNode::commuteMask(Mask: NewMask);
27254 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
27255 }
27256
27257 if (LegalMask) {
27258 SV0 = DAG.getBitcast(VT: ScaleVT, V: SV0);
27259 SV1 = DAG.getBitcast(VT: ScaleVT, V: SV1);
27260 return DAG.getBitcast(
27261 VT, V: DAG.getVectorShuffle(VT: ScaleVT, dl: SDLoc(N), N1: SV0, N2: SV1, Mask: NewMask));
27262 }
27263 }
27264 }
27265 }
27266
27267 // Match shuffles of bitcasts, so long as the mask can be treated as the
27268 // larger type.
27269 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
27270 return V;
27271
27272 // Compute the combined shuffle mask for a shuffle with SV0 as the first
27273 // operand, and SV1 as the second operand.
27274 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
27275 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
27276 auto MergeInnerShuffle =
27277 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
27278 ShuffleVectorSDNode *OtherSVN, SDValue N1,
27279 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
27280 SmallVectorImpl<int> &Mask) -> bool {
27281 // Don't try to fold splats; they're likely to simplify somehow, or they
27282 // might be free.
27283 if (OtherSVN->isSplat())
27284 return false;
27285
27286 SV0 = SV1 = SDValue();
27287 Mask.clear();
27288
27289 for (unsigned i = 0; i != NumElts; ++i) {
27290 int Idx = SVN->getMaskElt(Idx: i);
27291 if (Idx < 0) {
27292 // Propagate Undef.
27293 Mask.push_back(Elt: Idx);
27294 continue;
27295 }
27296
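      // When the operands are commuted, remap the index so that (as in the
      // non-commuted case below) indices < NumElts refer to the inner shuffle
      // and indices >= NumElts refer to N1.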
27297 if (Commute)
27298 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
27299
27300 SDValue CurrentVec;
27301 if (Idx < (int)NumElts) {
27302        // This shuffle index refers to the inner shuffle N0. Look up the inner
27303 // shuffle mask to identify which vector is actually referenced.
27304 Idx = OtherSVN->getMaskElt(Idx);
27305 if (Idx < 0) {
27306 // Propagate Undef.
27307 Mask.push_back(Elt: Idx);
27308 continue;
27309 }
27310 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(Num: 0)
27311 : OtherSVN->getOperand(Num: 1);
27312 } else {
27313 // This shuffle index references an element within N1.
27314 CurrentVec = N1;
27315 }
27316
27317 // Simple case where 'CurrentVec' is UNDEF.
27318 if (CurrentVec.isUndef()) {
27319 Mask.push_back(Elt: -1);
27320 continue;
27321 }
27322
27323 // Canonicalize the shuffle index. We don't know yet if CurrentVec
27324 // will be the first or second operand of the combined shuffle.
27325 Idx = Idx % NumElts;
27326 if (!SV0.getNode() || SV0 == CurrentVec) {
27327 // Ok. CurrentVec is the left hand side.
27328 // Update the mask accordingly.
27329 SV0 = CurrentVec;
27330 Mask.push_back(Elt: Idx);
27331 continue;
27332 }
27333 if (!SV1.getNode() || SV1 == CurrentVec) {
27334 // Ok. CurrentVec is the right hand side.
27335 // Update the mask accordingly.
27336 SV1 = CurrentVec;
27337 Mask.push_back(Elt: Idx + NumElts);
27338 continue;
27339 }
27340
27341 // Last chance - see if the vector is another shuffle and if it
27342 // uses one of the existing candidate shuffle ops.
27343 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(Val&: CurrentVec)) {
27344 int InnerIdx = CurrentSVN->getMaskElt(Idx);
27345 if (InnerIdx < 0) {
27346 Mask.push_back(Elt: -1);
27347 continue;
27348 }
27349 SDValue InnerVec = (InnerIdx < (int)NumElts)
27350 ? CurrentSVN->getOperand(Num: 0)
27351 : CurrentSVN->getOperand(Num: 1);
27352 if (InnerVec.isUndef()) {
27353 Mask.push_back(Elt: -1);
27354 continue;
27355 }
27356 InnerIdx %= NumElts;
27357 if (InnerVec == SV0) {
27358 Mask.push_back(Elt: InnerIdx);
27359 continue;
27360 }
27361 if (InnerVec == SV1) {
27362 Mask.push_back(Elt: InnerIdx + NumElts);
27363 continue;
27364 }
27365 }
27366
27367 // Bail out if we cannot convert the shuffle pair into a single shuffle.
27368 return false;
27369 }
27370
27371 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
27372 return true;
27373
27374 // Avoid introducing shuffles with illegal mask.
27375 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
27376 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
27377 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
27378 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
27379 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
27380 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
27381 if (TLI.isShuffleMaskLegal(Mask, VT))
27382 return true;
27383
27384 std::swap(a&: SV0, b&: SV1);
27385 ShuffleVectorSDNode::commuteMask(Mask);
27386 return TLI.isShuffleMaskLegal(Mask, VT);
27387 };
27388
27389 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
27390 // Canonicalize shuffles according to rules:
27391 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
27392 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
27393 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
27394 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
27395 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
27396 // The incoming shuffle must be of the same type as the result of the
27397 // current shuffle.
27398 assert(N1->getOperand(0).getValueType() == VT &&
27399 "Shuffle types don't match");
27400
27401 SDValue SV0 = N1->getOperand(Num: 0);
27402 SDValue SV1 = N1->getOperand(Num: 1);
27403 bool HasSameOp0 = N0 == SV0;
27404 bool IsSV1Undef = SV1.isUndef();
27405 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
27406 // Commute the operands of this shuffle so merging below will trigger.
27407 return DAG.getCommutedVectorShuffle(SV: *SVN);
27408 }
27409
27410 // Canonicalize splat shuffles to the RHS to improve merging below.
27411 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
27412 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
27413 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
27414 cast<ShuffleVectorSDNode>(Val&: N0)->isSplat() &&
27415 !cast<ShuffleVectorSDNode>(Val&: N1)->isSplat()) {
27416 return DAG.getCommutedVectorShuffle(SV: *SVN);
27417 }
27418
27419 // Try to fold according to rules:
27420 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
27421 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
27422 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
27423 // Don't try to fold shuffles with illegal type.
27424 // Only fold if this shuffle is the only user of the other shuffle.
27425    // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
27426 for (int i = 0; i != 2; ++i) {
27427 if (N->getOperand(Num: i).getOpcode() == ISD::VECTOR_SHUFFLE &&
27428 N->isOnlyUserOf(N: N->getOperand(Num: i).getNode())) {
27429 // The incoming shuffle must be of the same type as the result of the
27430 // current shuffle.
27431 auto *OtherSV = cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: i));
27432 assert(OtherSV->getOperand(0).getValueType() == VT &&
27433 "Shuffle types don't match");
27434
27435 SDValue SV0, SV1;
27436 SmallVector<int, 4> Mask;
27437 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(Num: 1 - i), TLI,
27438 SV0, SV1, Mask)) {
27439          // Check if all indices in Mask are Undef. If so, propagate Undef.
27440 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
27441 return DAG.getUNDEF(VT);
27442
27443 return DAG.getVectorShuffle(VT, dl: SDLoc(N),
27444 N1: SV0 ? SV0 : DAG.getUNDEF(VT),
27445 N2: SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
27446 }
27447 }
27448 }
27449
27450    // Merge shuffles through binops if we are able to merge the shuffle with at
27451    // least one other shuffle.
27452 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
27453 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
27454 unsigned SrcOpcode = N0.getOpcode();
27455 if (TLI.isBinOp(Opcode: SrcOpcode) && N->isOnlyUserOf(N: N0.getNode()) &&
27456 (N1.isUndef() ||
27457 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N: N1.getNode())))) {
27458 // Get binop source ops, or just pass on the undef.
27459 SDValue Op00 = N0.getOperand(i: 0);
27460 SDValue Op01 = N0.getOperand(i: 1);
27461 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(i: 0);
27462 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(i: 1);
27463 // TODO: We might be able to relax the VT check but we don't currently
27464 // have any isBinOp() that has different result/ops VTs so play safe until
27465 // we have test coverage.
27466 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
27467 Op01.getValueType() == VT && Op11.getValueType() == VT &&
27468 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
27469 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
27470 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
27471 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
27472 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
27473 SmallVectorImpl<int> &Mask, bool LeftOp,
27474 bool Commute) {
27475 SDValue InnerN = Commute ? N1 : N0;
27476 SDValue Op0 = LeftOp ? Op00 : Op01;
27477 SDValue Op1 = LeftOp ? Op10 : Op11;
27478 if (Commute)
27479 std::swap(a&: Op0, b&: Op1);
27480 // Only accept the merged shuffle if we don't introduce undef elements,
27481 // or the inner shuffle already contained undef elements.
27482 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Val&: Op0);
27483 return SVN0 && InnerN->isOnlyUserOf(N: SVN0) &&
27484 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
27485 Mask) &&
27486 (llvm::any_of(Range: SVN0->getMask(), P: [](int M) { return M < 0; }) ||
27487 llvm::none_of(Range&: Mask, P: [](int M) { return M < 0; }));
27488 };
27489
27490 // Ensure we don't increase the number of shuffles - we must merge a
27491 // shuffle from at least one of the LHS and RHS ops.
27492 bool MergedLeft = false;
27493 SDValue LeftSV0, LeftSV1;
27494 SmallVector<int, 4> LeftMask;
27495 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
27496 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
27497 MergedLeft = true;
27498 } else {
27499 LeftMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
27500 LeftSV0 = Op00, LeftSV1 = Op10;
27501 }
27502
27503 bool MergedRight = false;
27504 SDValue RightSV0, RightSV1;
27505 SmallVector<int, 4> RightMask;
27506 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
27507 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
27508 MergedRight = true;
27509 } else {
27510 RightMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
27511 RightSV0 = Op01, RightSV1 = Op11;
27512 }
27513
27514 if (MergedLeft || MergedRight) {
27515 SDLoc DL(N);
27516 SDValue LHS = DAG.getVectorShuffle(
27517 VT, dl: DL, N1: LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
27518 N2: LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), Mask: LeftMask);
27519 SDValue RHS = DAG.getVectorShuffle(
27520 VT, dl: DL, N1: RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
27521 N2: RightSV1 ? RightSV1 : DAG.getUNDEF(VT), Mask: RightMask);
27522 return DAG.getNode(Opcode: SrcOpcode, DL, VT, N1: LHS, N2: RHS);
27523 }
27524 }
27525 }
27526 }
27527
27528 if (SDValue V = foldShuffleOfConcatUndefs(Shuf: SVN, DAG))
27529 return V;
27530
27531 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
27532 // Perform this really late, because it could eliminate knowledge
27533 // of undef elements created by this shuffle.
27534 if (Level < AfterLegalizeTypes)
27535 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
27536 LegalOperations))
27537 return V;
27538
27539 return SDValue();
27540}
27541
27542SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
27543 EVT VT = N->getValueType(ResNo: 0);
27544 if (!VT.isFixedLengthVector())
27545 return SDValue();
27546
27547 // Try to convert a scalar binop with an extracted vector element to a vector
27548 // binop. This is intended to reduce potentially expensive register moves.
27549 // TODO: Check if both operands are extracted.
27550  // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
27551 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
27552 SDValue Scalar = N->getOperand(Num: 0);
27553 unsigned Opcode = Scalar.getOpcode();
27554 EVT VecEltVT = VT.getScalarType();
27555 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
27556 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
27557 Scalar.getOperand(i: 0).getValueType() == VecEltVT &&
27558 Scalar.getOperand(i: 1).getValueType() == VecEltVT &&
27559 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 0).getNode()) &&
27560 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 1).getNode()) &&
27561 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
27562 // Match an extract element and get a shuffle mask equivalent.
27563 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
27564
27565 for (int i : {0, 1}) {
27566 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
27567 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
27568 SDValue EE = Scalar.getOperand(i);
27569 auto *C = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: i ? 0 : 1));
27570 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
27571 EE.getOperand(i: 0).getValueType() == VT &&
27572 isa<ConstantSDNode>(Val: EE.getOperand(i: 1))) {
27573 // Mask = {ExtractIndex, undef, undef....}
27574 ShufMask[0] = EE.getConstantOperandVal(i: 1);
27575 // Make sure the shuffle is legal if we are crossing lanes.
27576 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
27577 SDLoc DL(N);
27578 SDValue V[] = {EE.getOperand(i: 0),
27579 DAG.getConstant(Val: C->getAPIntValue(), DL, VT)};
27580 SDValue VecBO = DAG.getNode(Opcode, DL, VT, N1: V[i], N2: V[1 - i]);
27581 return DAG.getVectorShuffle(VT, dl: DL, N1: VecBO, N2: DAG.getUNDEF(VT),
27582 Mask: ShufMask);
27583 }
27584 }
27585 }
27586 }
27587
27588 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
27589 // with a VECTOR_SHUFFLE and possible truncate.
27590 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
27591 !Scalar.getOperand(i: 0).getValueType().isFixedLengthVector())
27592 return SDValue();
27593
27594 // If we have an implicit truncate, truncate here if it is legal.
27595 if (VecEltVT != Scalar.getValueType() &&
27596 Scalar.getValueType().isScalarInteger() && isTypeLegal(VT: VecEltVT)) {
27597 SDValue Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Scalar), VT: VecEltVT, Operand: Scalar);
27598 return DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT, Operand: Val);
27599 }
27600
27601 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: 1));
27602 if (!ExtIndexC)
27603 return SDValue();
27604
27605 SDValue SrcVec = Scalar.getOperand(i: 0);
27606 EVT SrcVT = SrcVec.getValueType();
27607 unsigned SrcNumElts = SrcVT.getVectorNumElements();
27608 unsigned VTNumElts = VT.getVectorNumElements();
27609 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
27610 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
27611 SmallVector<int, 8> Mask(SrcNumElts, -1);
27612 Mask[0] = ExtIndexC->getZExtValue();
27613 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
27614 VT: SrcVT, DL: SDLoc(N), N0: SrcVec, N1: DAG.getUNDEF(VT: SrcVT), Mask, DAG);
27615 if (!LegalShuffle)
27616 return SDValue();
27617
27618 // If the initial vector is the same size, the shuffle is the result.
27619 if (VT == SrcVT)
27620 return LegalShuffle;
27621
27622 // If not, shorten the shuffled vector.
27623 if (VTNumElts != SrcNumElts) {
27624 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL: SDLoc(N));
27625 EVT SubVT = EVT::getVectorVT(Context&: *DAG.getContext(),
27626 VT: SrcVT.getVectorElementType(), NumElements: VTNumElts);
27627 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: SubVT, N1: LegalShuffle,
27628 N2: ZeroIdx);
27629 }
27630 }
27631
27632 return SDValue();
27633}
27634
27635SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
27636 EVT VT = N->getValueType(ResNo: 0);
27637 SDValue N0 = N->getOperand(Num: 0);
27638 SDValue N1 = N->getOperand(Num: 1);
27639 SDValue N2 = N->getOperand(Num: 2);
27640 uint64_t InsIdx = N->getConstantOperandVal(Num: 2);
27641
27642 // If inserting an UNDEF, just return the original vector.
27643 if (N1.isUndef())
27644 return N0;
27645
27646 // If this is an insert of an extracted vector into an undef vector, we can
27647 // just use the input to the extract if the types match, and can simplify
27648 // in some cases even if they don't.
27649 if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
27650 N1.getOperand(i: 1) == N2) {
27651 EVT SrcVT = N1.getOperand(i: 0).getValueType();
27652 if (SrcVT == VT)
27653 return N1.getOperand(i: 0);
27654 // TODO: To remove the zero check, need to adjust the offset to
27655 // a multiple of the new src type.
27656 if (isNullConstant(V: N2)) {
27657 if (VT.knownBitsGE(VT: SrcVT) &&
27658 !(VT.isFixedLengthVector() && SrcVT.isScalableVector()))
27659 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
27660 VT, N1: N0, N2: N1.getOperand(i: 0), N3: N2);
27661 else if (VT.knownBitsLE(VT: SrcVT) &&
27662 !(VT.isScalableVector() && SrcVT.isFixedLengthVector()))
27663 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N),
27664 VT, N1: N1.getOperand(i: 0), N2);
27665 }
27666 }
27667
27668 // Handle case where we've ended up inserting back into the source vector
27669 // we extracted the subvector from.
27670 // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
27671 if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(i: 0) == N0 &&
27672 N1.getOperand(i: 1) == N2)
27673 return N0;
27674
27675 // Simplify scalar inserts into an undef vector:
27676 // insert_subvector undef, (splat X), N2 -> splat X
27677 if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
27678 if (DAG.isConstantValueOfAnyType(N: N1.getOperand(i: 0)) || N1.hasOneUse())
27679 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: N1.getOperand(i: 0));
27680
27681 // If we are inserting a bitcast value into an undef, with the same
27682 // number of elements, just use the bitcast input of the extract.
27683 // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
27684 // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
27685 if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
27686 N1.getOperand(i: 0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
27687 N1.getOperand(i: 0).getOperand(i: 1) == N2 &&
27688 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getVectorElementCount() ==
27689 VT.getVectorElementCount() &&
27690 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getSizeInBits() ==
27691 VT.getSizeInBits()) {
27692 return DAG.getBitcast(VT, V: N1.getOperand(i: 0).getOperand(i: 0));
27693 }
27694
27695  // If both N0 and N1 are bitcast values on which insert_subvector
27696  // would make sense, pull the bitcast through.
27697 // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
27698 // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
27699 if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
27700 SDValue CN0 = N0.getOperand(i: 0);
27701 SDValue CN1 = N1.getOperand(i: 0);
27702 EVT CN0VT = CN0.getValueType();
27703 EVT CN1VT = CN1.getValueType();
27704 if (CN0VT.isVector() && CN1VT.isVector() &&
27705 CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
27706 CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
27707 SDValue NewINSERT = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
27708 VT: CN0.getValueType(), N1: CN0, N2: CN1, N3: N2);
27709 return DAG.getBitcast(VT, V: NewINSERT);
27710 }
27711 }
27712
27713 // Combine INSERT_SUBVECTORs where we are inserting to the same index.
27714 // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
27715 // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
27716 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
27717 N0.getOperand(i: 1).getValueType() == N1.getValueType() &&
27718 N0.getOperand(i: 2) == N2)
27719 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
27720 N2: N1, N3: N2);
27721
27722 // Eliminate an intermediate insert into an undef vector:
27723 // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
27724 // insert_subvector undef, X, 0
27725 if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
27726 N1.getOperand(i: 0).isUndef() && isNullConstant(V: N1.getOperand(i: 2)) &&
27727 isNullConstant(V: N2))
27728 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0,
27729 N2: N1.getOperand(i: 1), N3: N2);
27730
27731 // Push subvector bitcasts to the output, adjusting the index as we go.
27732 // insert_subvector(bitcast(v), bitcast(s), c1)
27733 // -> bitcast(insert_subvector(v, s, c2))
27734 if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
27735 N1.getOpcode() == ISD::BITCAST) {
27736 SDValue N0Src = peekThroughBitcasts(V: N0);
27737 SDValue N1Src = peekThroughBitcasts(V: N1);
27738 EVT N0SrcSVT = N0Src.getValueType().getScalarType();
27739 EVT N1SrcSVT = N1Src.getValueType().getScalarType();
27740 if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
27741 N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
27742 EVT NewVT;
27743 SDLoc DL(N);
27744 SDValue NewIdx;
27745 LLVMContext &Ctx = *DAG.getContext();
27746 ElementCount NumElts = VT.getVectorElementCount();
27747 unsigned EltSizeInBits = VT.getScalarSizeInBits();
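      // Rescale the vector type and insertion index to the bitcast source's
      // element size. Rescaling to wider elements is only possible when both
      // the insertion index and the element count are multiples of the scale.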
27748 if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
27749 unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
27750 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT, EC: NumElts * Scale);
27751 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx * Scale, DL);
27752 } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
27753 unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
27754 if (NumElts.isKnownMultipleOf(RHS: Scale) && (InsIdx % Scale) == 0) {
27755 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT,
27756 EC: NumElts.divideCoefficientBy(RHS: Scale));
27757 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx / Scale, DL);
27758 }
27759 }
27760 if (NewIdx && hasOperation(Opcode: ISD::INSERT_SUBVECTOR, VT: NewVT)) {
27761 SDValue Res = DAG.getBitcast(VT: NewVT, V: N0Src);
27762 Res = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: NewVT, N1: Res, N2: N1Src, N3: NewIdx);
27763 return DAG.getBitcast(VT, V: Res);
27764 }
27765 }
27766 }
27767
27768 // Canonicalize insert_subvector dag nodes.
27769 // Example:
27770  // (insert_subvector (insert_subvector A, B, Idx0), C, Idx1)
27771  // -> (insert_subvector (insert_subvector A, C, Idx1), B, Idx0) when Idx1 < Idx0
27772 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
27773 N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
27774 unsigned OtherIdx = N0.getConstantOperandVal(i: 2);
27775 if (InsIdx < OtherIdx) {
27776 // Swap nodes.
27777 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT,
27778 N1: N0.getOperand(i: 0), N2: N1, N3: N2);
27779 AddToWorklist(N: NewOp.getNode());
27780 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N0.getNode()),
27781 VT, N1: NewOp, N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
27782 }
27783 }
27784
27785 // If the input vector is a concatenation, and the insert replaces
27786 // one of the pieces, we can optimize into a single concat_vectors.
27787 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
27788 N0.getOperand(i: 0).getValueType() == N1.getValueType() &&
27789 N0.getOperand(i: 0).getValueType().isScalableVector() ==
27790 N1.getValueType().isScalableVector()) {
27791 unsigned Factor = N1.getValueType().getVectorMinNumElements();
27792 SmallVector<SDValue, 8> Ops(N0->ops());
27793 Ops[InsIdx / Factor] = N1;
27794 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
27795 }
27796
27797 // Simplify source operands based on insertion.
27798 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
27799 return SDValue(N, 0);
27800
27801 return SDValue();
27802}
27803
27804SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
27805 SDValue N0 = N->getOperand(Num: 0);
27806
27807 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
27808 if (N0->getOpcode() == ISD::FP16_TO_FP)
27809 return N0->getOperand(Num: 0);
27810
27811 return SDValue();
27812}
27813
27814SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
27815 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
27816 auto Op = N->getOpcode();
27817 assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
27818 "opcode should be FP16_TO_FP or BF16_TO_FP.");
27819 SDValue N0 = N->getOperand(Num: 0);
27820
27821 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
27822 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
27823 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
27824 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N: N0.getOperand(i: 1));
27825 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
27826 return DAG.getNode(Opcode: Op, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0.getOperand(i: 0));
27827 }
27828 }
27829
27830 if (SDValue CastEliminated = eliminateFPCastPair(N))
27831 return CastEliminated;
27832
27833 // Sometimes constants manage to survive very late in the pipeline, e.g.,
27834 // because they are wrapped inside the <1 x f16> type. Try one last time to
27835 // get rid of them.
27836 SDValue Folded = DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N),
27837 VT: N->getValueType(ResNo: 0), Ops: {N0});
27838 return Folded;
27839}
27840
27841SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
27842 SDValue N0 = N->getOperand(Num: 0);
27843
27844 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
27845 if (N0->getOpcode() == ISD::BF16_TO_FP)
27846 return N0->getOperand(Num: 0);
27847
27848 return SDValue();
27849}
27850
27851SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
27852 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
27853 return visitFP16_TO_FP(N);
27854}
27855
27856SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
27857 SDValue N0 = N->getOperand(Num: 0);
27858 EVT VT = N0.getValueType();
27859 unsigned Opcode = N->getOpcode();
27860
27861  // VECREDUCE over a 1-element vector is just an extract.
27862 if (VT.getVectorElementCount().isScalar()) {
27863 SDLoc dl(N);
27864 SDValue Res =
27865 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: VT.getVectorElementType(), N1: N0,
27866 N2: DAG.getVectorIdxConstant(Val: 0, DL: dl));
27867 if (Res.getValueType() != N->getValueType(ResNo: 0))
27868 Res = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: dl, VT: N->getValueType(ResNo: 0), Operand: Res);
27869 return Res;
27870 }
27871
27872  // On a boolean vector an and/or reduction is the same as a umin/umax
27873 // reduction. Convert them if the latter is legal while the former isn't.
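  // (The sign-bit check below guarantees every element is all-ones or zero, so
  // the umin reduction matches "and" and the umax reduction matches "or".)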
27874 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
27875 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
27876 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
27877 if (!TLI.isOperationLegalOrCustom(Op: Opcode, VT) &&
27878 TLI.isOperationLegalOrCustom(Op: NewOpcode, VT) &&
27879 DAG.ComputeNumSignBits(Op: N0) == VT.getScalarSizeInBits())
27880 return DAG.getNode(Opcode: NewOpcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0);
27881 }
27882
27883 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
27884 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
27885 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
27886 TLI.isTypeLegal(VT: N0.getOperand(i: 1).getValueType())) {
27887 SDValue Vec = N0.getOperand(i: 0);
27888 SDValue Subvec = N0.getOperand(i: 1);
27889 if ((Opcode == ISD::VECREDUCE_OR &&
27890 (N0.getOperand(i: 0).isUndef() || isNullOrNullSplat(V: Vec))) ||
27891 (Opcode == ISD::VECREDUCE_AND &&
27892 (N0.getOperand(i: 0).isUndef() || isAllOnesOrAllOnesSplat(V: Vec))))
27893 return DAG.getNode(Opcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Subvec);
27894 }
27895
27896 // vecreduce_or(sext(x)) -> sext(vecreduce_or(x))
27897 // Same for zext and anyext, and for and/or/xor reductions.
27898 if ((Opcode == ISD::VECREDUCE_OR || Opcode == ISD::VECREDUCE_AND ||
27899 Opcode == ISD::VECREDUCE_XOR) &&
27900 (N0.getOpcode() == ISD::SIGN_EXTEND ||
27901 N0.getOpcode() == ISD::ZERO_EXTEND ||
27902 N0.getOpcode() == ISD::ANY_EXTEND) &&
27903 TLI.isOperationLegalOrCustom(Op: Opcode, VT: N0.getOperand(i: 0).getValueType())) {
27904 SDValue Red = DAG.getNode(Opcode, DL: SDLoc(N),
27905 VT: N0.getOperand(i: 0).getValueType().getScalarType(),
27906 Operand: N0.getOperand(i: 0));
27907 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Red);
27908 }
27909 return SDValue();
27910}
27911
27912SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
27913 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
27914
27915 // FSUB -> FMA combines:
27916 if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
27917 AddToWorklist(N: Fused.getNode());
27918 return Fused;
27919 }
27920 return SDValue();
27921}
27922
27923SDValue DAGCombiner::visitVPOp(SDNode *N) {
27924
27925 if (N->getOpcode() == ISD::VP_GATHER)
27926 if (SDValue SD = visitVPGATHER(N))
27927 return SD;
27928
27929 if (N->getOpcode() == ISD::VP_SCATTER)
27930 if (SDValue SD = visitVPSCATTER(N))
27931 return SD;
27932
27933 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
27934 if (SDValue SD = visitVP_STRIDED_LOAD(N))
27935 return SD;
27936
27937 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
27938 if (SDValue SD = visitVP_STRIDED_STORE(N))
27939 return SD;
27940
27941 // VP operations in which all vector elements are disabled - either by
27942 // determining that the mask is all false or that the EVL is 0 - can be
27943 // eliminated.
27944 bool AreAllEltsDisabled = false;
27945 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode: N->getOpcode()))
27946 AreAllEltsDisabled |= isNullConstant(V: N->getOperand(Num: *EVLIdx));
27947 if (auto MaskIdx = ISD::getVPMaskIdx(Opcode: N->getOpcode()))
27948 AreAllEltsDisabled |=
27949 ISD::isConstantSplatVectorAllZeros(N: N->getOperand(Num: *MaskIdx).getNode());
27950
27951 // This is the only generic VP combine we support for now.
27952 if (!AreAllEltsDisabled) {
27953 switch (N->getOpcode()) {
27954 case ISD::VP_FADD:
27955 return visitVP_FADD(N);
27956 case ISD::VP_FSUB:
27957 return visitVP_FSUB(N);
27958 case ISD::VP_FMA:
27959 return visitFMA<VPMatchContext>(N);
27960 case ISD::VP_SELECT:
27961 return visitVP_SELECT(N);
27962 case ISD::VP_MUL:
27963 return visitMUL<VPMatchContext>(N);
27964 case ISD::VP_SUB:
27965 return foldSubCtlzNot<VPMatchContext>(N, DAG);
27966 default:
27967 break;
27968 }
27969 return SDValue();
27970 }
27971
27972 // Binary operations can be replaced by UNDEF.
27973 if (ISD::isVPBinaryOp(Opcode: N->getOpcode()))
27974 return DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
27975
27976 // VP Memory operations can be replaced by either the chain (stores) or the
27977 // chain + undef (loads).
27978 if (const auto *MemSD = dyn_cast<MemSDNode>(Val: N)) {
27979 if (MemSD->writeMem())
27980 return MemSD->getChain();
27981 return CombineTo(N, Res0: DAG.getUNDEF(VT: N->getValueType(ResNo: 0)), Res1: MemSD->getChain());
27982 }
27983
27984 // Reduction operations return the start operand when no elements are active.
27985 if (ISD::isVPReduction(Opcode: N->getOpcode()))
27986 return N->getOperand(Num: 0);
27987
27988 return SDValue();
27989}
27990
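/// A GET_FPENV_MEM node writes the FP environment to memory. If that memory is
/// read back only by a single load whose value feeds a single store, write the
/// environment directly to the store's address instead.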
27991SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
27992 SDValue Chain = N->getOperand(Num: 0);
27993 SDValue Ptr = N->getOperand(Num: 1);
27994 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
27995
27996  // Check if the memory where the FP state is written is used only in a single
27997  // load operation.
27998 LoadSDNode *LdNode = nullptr;
27999 for (auto *U : Ptr->users()) {
28000 if (U == N)
28001 continue;
28002 if (auto *Ld = dyn_cast<LoadSDNode>(Val: U)) {
28003 if (LdNode && LdNode != Ld)
28004 return SDValue();
28005 LdNode = Ld;
28006 continue;
28007 }
28008 return SDValue();
28009 }
28010 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
28011 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
28012 !LdNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(N, 0)))
28013 return SDValue();
28014
28015 // Check if the loaded value is used only in a store operation.
28016 StoreSDNode *StNode = nullptr;
28017 for (SDUse &U : LdNode->uses()) {
28018 if (U.getResNo() == 0) {
28019 if (auto *St = dyn_cast<StoreSDNode>(Val: U.getUser())) {
28020 if (StNode)
28021 return SDValue();
28022 StNode = St;
28023 } else {
28024 return SDValue();
28025 }
28026 }
28027 }
28028 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
28029 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
28030 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
28031 return SDValue();
28032
28033  // Create a new GET_FPENV_MEM node that uses the store address to write the FP
28034  // environment.
28035 SDValue Res = DAG.getGetFPEnv(Chain, dl: SDLoc(N), Ptr: StNode->getBasePtr(), MemVT,
28036 MMO: StNode->getMemOperand());
28037 CombineTo(N: StNode, Res, AddTo: false);
28038 return Res;
28039}
28040
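/// A SET_FPENV_MEM node reads the FP environment from memory. If that memory is
/// written only by a single store whose value comes from a single load, read the
/// environment directly from the load's address instead.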
28041SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
28042 SDValue Chain = N->getOperand(Num: 0);
28043 SDValue Ptr = N->getOperand(Num: 1);
28044 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
28045
28046  // Check if the address of the FP state is also used only in a store operation.
28047 StoreSDNode *StNode = nullptr;
28048 for (auto *U : Ptr->users()) {
28049 if (U == N)
28050 continue;
28051 if (auto *St = dyn_cast<StoreSDNode>(Val: U)) {
28052 if (StNode && StNode != St)
28053 return SDValue();
28054 StNode = St;
28055 continue;
28056 }
28057 return SDValue();
28058 }
28059 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
28060 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
28061 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(StNode, 0)))
28062 return SDValue();
28063
28064 // Check if the stored value is loaded from some location and the loaded
28065 // value is used only in the store operation.
28066 SDValue StValue = StNode->getValue();
28067 auto *LdNode = dyn_cast<LoadSDNode>(Val&: StValue);
28068 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
28069 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
28070 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
28071 return SDValue();
28072
28073  // Create a new SET_FPENV_MEM node that uses the load address to read the FP
28074  // environment.
28075 SDValue Res =
28076 DAG.getSetFPEnv(Chain: LdNode->getChain(), dl: SDLoc(N), Ptr: LdNode->getBasePtr(), MemVT,
28077 MMO: LdNode->getMemOperand());
28078 return Res;
28079}
28080
28081/// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
28082/// with the destination vector and a zero vector.
28083/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
28084/// vector_shuffle V, Zero, <0, 4, 2, 4>
28085SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
28086 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
28087
28088 EVT VT = N->getValueType(ResNo: 0);
28089 SDValue LHS = N->getOperand(Num: 0);
28090 SDValue RHS = peekThroughBitcasts(V: N->getOperand(Num: 1));
28091 SDLoc DL(N);
28092
28093 // Make sure we're not running after operation legalization where it
28094 // may have custom lowered the vector shuffles.
28095 if (LegalOperations)
28096 return SDValue();
28097
28098 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
28099 return SDValue();
28100
28101 EVT RVT = RHS.getValueType();
28102 unsigned NumElts = RHS.getNumOperands();
28103
28104  // Attempt to create a valid clear mask by splitting the mask into
28105  // sub-elements and checking whether each one is all zeros or all ones -
28106  // suitable for shuffle masking.
28107 auto BuildClearMask = [&](int Split) {
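    // Each of the NumElts constant elements is viewed as Split sub-elements of
    // NumSubBits bits; the clear shuffle then operates at that granularity.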
28108 int NumSubElts = NumElts * Split;
28109 int NumSubBits = RVT.getScalarSizeInBits() / Split;
28110
28111 SmallVector<int, 8> Indices;
28112 for (int i = 0; i != NumSubElts; ++i) {
28113 int EltIdx = i / Split;
28114 int SubIdx = i % Split;
28115 SDValue Elt = RHS.getOperand(i: EltIdx);
28116 // X & undef --> 0 (not undef). So this lane must be converted to choose
28117 // from the zero constant vector (same as if the element had all 0-bits).
28118 if (Elt.isUndef()) {
28119 Indices.push_back(Elt: i + NumSubElts);
28120 continue;
28121 }
28122
28123 std::optional<APInt> Bits = Elt->bitcastToAPInt();
28124 if (!Bits)
28125 return SDValue();
28126
28127 // Extract the sub element from the constant bit mask.
28128 if (DAG.getDataLayout().isBigEndian())
28129 *Bits =
28130 Bits->extractBits(numBits: NumSubBits, bitPosition: (Split - SubIdx - 1) * NumSubBits);
28131 else
28132 *Bits = Bits->extractBits(numBits: NumSubBits, bitPosition: SubIdx * NumSubBits);
28133
28134 if (Bits->isAllOnes())
28135 Indices.push_back(Elt: i);
28136 else if (*Bits == 0)
28137 Indices.push_back(Elt: i + NumSubElts);
28138 else
28139 return SDValue();
28140 }
28141
28142 // Let's see if the target supports this vector_shuffle.
28143 EVT ClearSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumSubBits);
28144 EVT ClearVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: ClearSVT, NumElements: NumSubElts);
28145 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
28146 return SDValue();
28147
28148 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: ClearVT);
28149 return DAG.getBitcast(VT, V: DAG.getVectorShuffle(VT: ClearVT, dl: DL,
28150 N1: DAG.getBitcast(VT: ClearVT, V: LHS),
28151 N2: Zero, Mask: Indices));
28152 };
28153
28154 // Determine maximum split level (byte level masking).
28155 int MaxSplit = 1;
28156 if (RVT.getScalarSizeInBits() % 8 == 0)
28157 MaxSplit = RVT.getScalarSizeInBits() / 8;
28158
28159 for (int Split = 1; Split <= MaxSplit; ++Split)
28160 if (RVT.getScalarSizeInBits() % Split == 0)
28161 if (SDValue S = BuildClearMask(Split))
28162 return S;
28163
28164 return SDValue();
28165}
28166
28167/// If a vector binop is performed on splat values, it may be profitable to
28168/// extract, scalarize, and insert/splat.
28169static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
28170 const SDLoc &DL, bool LegalTypes) {
28171 SDValue N0 = N->getOperand(Num: 0);
28172 SDValue N1 = N->getOperand(Num: 1);
28173 unsigned Opcode = N->getOpcode();
28174 EVT VT = N->getValueType(ResNo: 0);
28175 EVT EltVT = VT.getVectorElementType();
28176 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28177
28178 // TODO: Remove/replace the extract cost check? If the elements are available
28179 // as scalars, then there may be no extract cost. Should we ask if
28180 // inserting a scalar back into a vector is cheap instead?
28181 int Index0, Index1;
28182 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
28183 SDValue Src1 = DAG.getSplatSourceVector(V: N1, SplatIndex&: Index1);
28184 // Extract element from splat_vector should be free.
28185 // TODO: use DAG.isSplatValue instead?
28186 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
28187 N1.getOpcode() == ISD::SPLAT_VECTOR;
28188 if (!Src0 || !Src1 || Index0 != Index1 ||
28189 Src0.getValueType().getVectorElementType() != EltVT ||
28190 Src1.getValueType().getVectorElementType() != EltVT ||
28191 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index: Index0)) ||
28192 // If before type legalization, allow scalar types that will eventually be
28193 // made legal.
28194 !TLI.isOperationLegalOrCustom(
28195 Op: Opcode, VT: LegalTypes
28196 ? EltVT
28197 : TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: EltVT)))
28198 return SDValue();
28199
28200 // FIXME: Type legalization can't handle illegal MULHS/MULHU.
28201 if ((Opcode == ISD::MULHS || Opcode == ISD::MULHU) && !TLI.isTypeLegal(VT: EltVT))
28202 return SDValue();
28203
28204 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode()) {
28205 // All but one element should have an undef input, which will fold to a
28206 // constant or undef. Avoid splatting which would over-define potentially
28207 // undefined elements.
28208
28209 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
28210 // build_vec ..undef, (bo X, Y), undef...
28211 SmallVector<SDValue, 16> EltsX, EltsY, EltsResult;
28212 DAG.ExtractVectorElements(Op: Src0, Args&: EltsX);
28213 DAG.ExtractVectorElements(Op: Src1, Args&: EltsY);
28214
28215 for (auto [X, Y] : zip(t&: EltsX, u&: EltsY))
28216 EltsResult.push_back(Elt: DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags()));
28217 return DAG.getBuildVector(VT, DL, Ops: EltsResult);
28218 }
28219
28220 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
28221 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src0, N2: IndexC);
28222 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src1, N2: IndexC);
28223 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags());
28224
28225 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
28226 return DAG.getSplat(VT, DL, Op: ScalarBO);
28227}
28228
28229/// Visit a vector cast operation, like FP_EXTEND.
28230SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
28231 EVT VT = N->getValueType(ResNo: 0);
28232 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
28233 EVT EltVT = VT.getVectorElementType();
28234 unsigned Opcode = N->getOpcode();
28235
28236 SDValue N0 = N->getOperand(Num: 0);
28237 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28238
28239  // TODO: promoting the operation might also be good here?
28240 int Index0;
28241 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
28242 if (Src0 &&
28243 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
28244 TLI.isExtractVecEltCheap(VT, Index: Index0)) &&
28245 TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT) &&
28246 TLI.preferScalarizeSplat(N)) {
28247 EVT SrcVT = N0.getValueType();
28248 EVT SrcEltVT = SrcVT.getVectorElementType();
28249 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
28250 SDValue Elt =
28251 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: SrcEltVT, N1: Src0, N2: IndexC);
28252 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, Operand: Elt, Flags: N->getFlags());
28253 if (VT.isScalableVector())
28254 return DAG.getSplatVector(VT, DL, Op: ScalarBO);
28255 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
28256 return DAG.getBuildVector(VT, DL, Ops);
28257 }
28258
28259 return SDValue();
28260}
28261
28262/// Visit a binary vector operation, like ADD.
28263SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
28264 EVT VT = N->getValueType(ResNo: 0);
28265 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
28266
28267 SDValue LHS = N->getOperand(Num: 0);
28268 SDValue RHS = N->getOperand(Num: 1);
28269 unsigned Opcode = N->getOpcode();
28270 SDNodeFlags Flags = N->getFlags();
28271
28272 // Move unary shuffles with identical masks after a vector binop:
28273   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask)
28274 // --> shuffle (VBinOp A, B), Undef, Mask
28275 // This does not require type legality checks because we are creating the
28276 // same types of operations that are in the original sequence. We do have to
28277   // restrict ops like integer div that have immediate UB (e.g., div-by-zero)
28278 // though. This code is adapted from the identical transform in instcombine.
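        // For example (uniform mask on both operands, second operands undef):
        //   add (shuffle A, undef, <1,1,1,1>), (shuffle B, undef, <1,1,1,1>)
        //     --> shuffle (add A, B), undef, <1,1,1,1>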
28279 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
28280 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val&: LHS);
28281 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(Val&: RHS);
28282 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(RHS: Shuf1->getMask()) &&
28283 LHS.getOperand(i: 1).isUndef() && RHS.getOperand(i: 1).isUndef() &&
28284 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
28285 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS.getOperand(i: 0),
28286 N2: RHS.getOperand(i: 0), Flags);
28287 SDValue UndefV = LHS.getOperand(i: 1);
28288 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: UndefV, Mask: Shuf0->getMask());
28289 }
28290
28291 // Try to sink a splat shuffle after a binop with a uniform constant.
28292 // This is limited to cases where neither the shuffle nor the constant have
28293 // undefined elements because that could be poison-unsafe or inhibit
28294 // demanded elements analysis. It is further limited to not change a splat
28295 // of an inserted scalar because that may be optimized better by
28296 // load-folding or other target-specific behaviors.
28297 if (isConstOrConstSplat(N: RHS) && Shuf0 && all_equal(Range: Shuf0->getMask()) &&
28298 Shuf0->hasOneUse() && Shuf0->getOperand(Num: 1).isUndef() &&
28299 Shuf0->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
28300 // binop (splat X), (splat C) --> splat (binop X, C)
28301 SDValue X = Shuf0->getOperand(Num: 0);
28302 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: X, N2: RHS, Flags);
28303 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
28304 Mask: Shuf0->getMask());
28305 }
28306 if (isConstOrConstSplat(N: LHS) && Shuf1 && all_equal(Range: Shuf1->getMask()) &&
28307 Shuf1->hasOneUse() && Shuf1->getOperand(Num: 1).isUndef() &&
28308 Shuf1->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
28309 // binop (splat C), (splat X) --> splat (binop C, X)
28310 SDValue X = Shuf1->getOperand(Num: 0);
28311 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS, N2: X, Flags);
28312 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
28313 Mask: Shuf1->getMask());
28314 }
28315 }
28316
28317 // The following pattern is likely to emerge with vector reduction ops. Moving
28318 // the binary operation ahead of insertion may allow using a narrower vector
28319 // instruction that has better performance than the wide version of the op:
28320 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
28321 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(i: 0).isUndef() &&
28322 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(i: 0).isUndef() &&
28323 LHS.getOperand(i: 2) == RHS.getOperand(i: 2) &&
28324 (LHS.hasOneUse() || RHS.hasOneUse())) {
28325 SDValue X = LHS.getOperand(i: 1);
28326 SDValue Y = RHS.getOperand(i: 1);
28327 SDValue Z = LHS.getOperand(i: 2);
28328 EVT NarrowVT = X.getValueType();
28329 if (NarrowVT == Y.getValueType() &&
28330 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT,
28331 LegalOnly: LegalOperations)) {
28332 // (binop undef, undef) may not return undef, so compute that result.
28333 SDValue VecC =
28334 DAG.getNode(Opcode, DL, VT, N1: DAG.getUNDEF(VT), N2: DAG.getUNDEF(VT));
28335 SDValue NarrowBO = DAG.getNode(Opcode, DL, VT: NarrowVT, N1: X, N2: Y);
28336 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT, N1: VecC, N2: NarrowBO, N3: Z);
28337 }
28338 }
28339
28340 // Make sure all but the first op are undef or constant.
28341 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
28342 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
28343 all_of(Range: drop_begin(RangeOrContainer: Concat->ops()), P: [](const SDValue &Op) {
28344 return Op.isUndef() ||
28345 ISD::isBuildVectorOfConstantSDNodes(N: Op.getNode());
28346 });
28347 };
28348
28349 // The following pattern is likely to emerge with vector reduction ops. Moving
28350 // the binary operation ahead of the concat may allow using a narrower vector
28351 // instruction that has better performance than the wide version of the op:
28352 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
28353 // concat (VBinOp X, Y), VecC
28354 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
28355 (LHS.hasOneUse() || RHS.hasOneUse())) {
28356 EVT NarrowVT = LHS.getOperand(i: 0).getValueType();
28357 if (NarrowVT == RHS.getOperand(i: 0).getValueType() &&
28358 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT)) {
28359 unsigned NumOperands = LHS.getNumOperands();
28360 SmallVector<SDValue, 4> ConcatOps;
28361 for (unsigned i = 0; i != NumOperands; ++i) {
28362         // This constant-folds for operands 1 and up.
28363 ConcatOps.push_back(Elt: DAG.getNode(Opcode, DL, VT: NarrowVT, N1: LHS.getOperand(i),
28364 N2: RHS.getOperand(i)));
28365 }
28366
28367 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
28368 }
28369 }
28370
28371 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL, LegalTypes))
28372 return V;
28373
28374 return SDValue();
28375}
28376
28377SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
28378 SDValue N2) {
28379 assert(N0.getOpcode() == ISD::SETCC &&
28380 "First argument must be a SetCC node!");
28381
28382 SDValue SCC = SimplifySelectCC(DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: N1, N3: N2,
28383 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
28384
28385 // If we got a simplified select_cc node back from SimplifySelectCC, then
28386 // break it down into a new SETCC node, and a new SELECT node, and then return
28387 // the SELECT node, since we were called with a SELECT node.
28388 if (SCC.getNode()) {
28389 // Check to see if we got a select_cc back (to turn into setcc/select).
28390 // Otherwise, just return whatever node we got back, like fabs.
28391 if (SCC.getOpcode() == ISD::SELECT_CC) {
28392 const SDNodeFlags Flags = N0->getFlags();
28393 SDValue SETCC = DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N0),
28394 VT: N0.getValueType(),
28395 N1: SCC.getOperand(i: 0), N2: SCC.getOperand(i: 1),
28396 N3: SCC.getOperand(i: 4), Flags);
28397 AddToWorklist(N: SETCC.getNode());
28398 SDValue SelectNode = DAG.getSelect(DL: SDLoc(SCC), VT: SCC.getValueType(), Cond: SETCC,
28399 LHS: SCC.getOperand(i: 2), RHS: SCC.getOperand(i: 3));
28400 SelectNode->setFlags(Flags);
28401 return SelectNode;
28402 }
28403
28404 return SCC;
28405 }
28406 return SDValue();
28407}
28408
28409/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
28410/// being selected between, see if we can simplify the select. Callers of this
28411/// should assume that TheSelect is deleted if this returns true. As such, they
28412/// should return the appropriate thing (e.g. the node) back to the top-level of
28413/// the DAG combiner loop to avoid it being looked at.
28414bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
28415 SDValue RHS) {
28416   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x)) -> (fsqrt x)
28417   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
28418 if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(N: LHS)) {
28419 if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
28420 // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
28421 SDValue Sqrt = RHS;
28422 ISD::CondCode CC;
28423 SDValue CmpLHS;
28424 const ConstantFPSDNode *Zero = nullptr;
28425
28426 if (TheSelect->getOpcode() == ISD::SELECT_CC) {
28427 CC = cast<CondCodeSDNode>(Val: TheSelect->getOperand(Num: 4))->get();
28428 CmpLHS = TheSelect->getOperand(Num: 0);
28429 Zero = isConstOrConstSplatFP(N: TheSelect->getOperand(Num: 1));
28430 } else {
28431 // SELECT or VSELECT
28432 SDValue Cmp = TheSelect->getOperand(Num: 0);
28433 if (Cmp.getOpcode() == ISD::SETCC) {
28434 CC = cast<CondCodeSDNode>(Val: Cmp.getOperand(i: 2))->get();
28435 CmpLHS = Cmp.getOperand(i: 0);
28436 Zero = isConstOrConstSplatFP(N: Cmp.getOperand(i: 1));
28437 }
28438 }
28439 if (Zero && Zero->isZero() &&
28440 Sqrt.getOperand(i: 0) == CmpLHS && (CC == ISD::SETOLT ||
28441 CC == ISD::SETULT || CC == ISD::SETLT)) {
28442 // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
28443 CombineTo(N: TheSelect, Res: Sqrt);
28444 return true;
28445 }
28446 }
28447 }
28448 // Cannot simplify select with vector condition
28449 if (TheSelect->getOperand(Num: 0).getValueType().isVector()) return false;
28450
28451 // If this is a select from two identical things, try to pull the operation
28452 // through the select.
28453 if (LHS.getOpcode() != RHS.getOpcode() ||
28454 !LHS.hasOneUse() || !RHS.hasOneUse())
28455 return false;
28456
28457 // If this is a load and the token chain is identical, replace the select
28458 // of two loads with a load through a select of the address to load from.
28459 // This triggers in things like "select bool X, 10.0, 123.0" after the FP
28460 // constants have been dropped into the constant pool.
28461 if (LHS.getOpcode() == ISD::LOAD) {
28462 LoadSDNode *LLD = cast<LoadSDNode>(Val&: LHS);
28463 LoadSDNode *RLD = cast<LoadSDNode>(Val&: RHS);
28464
28465 // Token chains must be identical.
28466 if (LHS.getOperand(i: 0) != RHS.getOperand(i: 0) ||
28467 // Do not let this transformation reduce the number of volatile loads.
28468 // Be conservative for atomics for the moment
28469 // TODO: This does appear to be legal for unordered atomics (see D66309)
28470 !LLD->isSimple() || !RLD->isSimple() ||
28471 // FIXME: If either is a pre/post inc/dec load,
28472 // we'd need to split out the address adjustment.
28473 LLD->isIndexed() || RLD->isIndexed() ||
28474 // If this is an EXTLOAD, the VT's must match.
28475 LLD->getMemoryVT() != RLD->getMemoryVT() ||
28476 // If this is an EXTLOAD, the kind of extension must match.
28477 (LLD->getExtensionType() != RLD->getExtensionType() &&
28478 // The only exception is if one of the extensions is anyext.
28479 LLD->getExtensionType() != ISD::EXTLOAD &&
28480 RLD->getExtensionType() != ISD::EXTLOAD) ||
28481 // FIXME: this discards src value information. This is
28482 // over-conservative. It would be beneficial to be able to remember
28483 // both potential memory locations. Since we are discarding
28484 // src value info, don't do the transformation if the memory
28485 // locations are not in the default address space.
28486 LLD->getPointerInfo().getAddrSpace() != 0 ||
28487 RLD->getPointerInfo().getAddrSpace() != 0 ||
28488 // We can't produce a CMOV of a TargetFrameIndex since we won't
28489 // generate the address generation required.
28490 LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
28491 RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
28492 !TLI.isOperationLegalOrCustom(Op: TheSelect->getOpcode(),
28493 VT: LLD->getBasePtr().getValueType()))
28494 return false;
28495
28496 // The loads must not depend on one another.
28497 if (LLD->isPredecessorOf(N: RLD) || RLD->isPredecessorOf(N: LLD))
28498 return false;
28499
28500 // Check that the select condition doesn't reach either load. If so,
28501 // folding this will induce a cycle into the DAG. If not, this is safe to
28502 // xform, so create a select of the addresses.
28503
28504 SmallPtrSet<const SDNode *, 32> Visited;
28505 SmallVector<const SDNode *, 16> Worklist;
28506
28507 // Always fail if LLD and RLD are not independent. TheSelect is a
28508 // predecessor to all Nodes in question so we need not search past it.
28509
28510 Visited.insert(Ptr: TheSelect);
28511 Worklist.push_back(Elt: LLD);
28512 Worklist.push_back(Elt: RLD);
28513
28514 if (SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist) ||
28515 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist))
28516 return false;
28517
28518 SDValue Addr;
28519 if (TheSelect->getOpcode() == ISD::SELECT) {
28520 // We cannot do this optimization if any pair of {RLD, LLD} is a
28521 // predecessor to {RLD, LLD, CondNode}. As we've already compared the
28522 // Loads, we only need to check if CondNode is a successor to one of the
28523 // loads. We can further avoid this if there's no use of their chain
28524 // value.
28525 SDNode *CondNode = TheSelect->getOperand(Num: 0).getNode();
28526 Worklist.push_back(Elt: CondNode);
28527
28528 if ((LLD->hasAnyUseOfValue(Value: 1) &&
28529 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
28530 (RLD->hasAnyUseOfValue(Value: 1) &&
28531 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
28532 return false;
28533
28534 Addr = DAG.getSelect(DL: SDLoc(TheSelect),
28535 VT: LLD->getBasePtr().getValueType(),
28536 Cond: TheSelect->getOperand(Num: 0), LHS: LLD->getBasePtr(),
28537 RHS: RLD->getBasePtr());
28538 } else { // Otherwise SELECT_CC
28539 // We cannot do this optimization if any pair of {RLD, LLD} is a
28540 // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
28541 // the Loads, we only need to check if CondLHS/CondRHS is a successor to
28542 // one of the loads. We can further avoid this if there's no use of their
28543 // chain value.
28544
28545 SDNode *CondLHS = TheSelect->getOperand(Num: 0).getNode();
28546 SDNode *CondRHS = TheSelect->getOperand(Num: 1).getNode();
28547 Worklist.push_back(Elt: CondLHS);
28548 Worklist.push_back(Elt: CondRHS);
28549
28550 if ((LLD->hasAnyUseOfValue(Value: 1) &&
28551 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
28552 (RLD->hasAnyUseOfValue(Value: 1) &&
28553 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
28554 return false;
28555
28556 Addr = DAG.getNode(Opcode: ISD::SELECT_CC, DL: SDLoc(TheSelect),
28557 VT: LLD->getBasePtr().getValueType(),
28558 N1: TheSelect->getOperand(Num: 0),
28559 N2: TheSelect->getOperand(Num: 1),
28560 N3: LLD->getBasePtr(), N4: RLD->getBasePtr(),
28561 N5: TheSelect->getOperand(Num: 4));
28562 }
28563
28564 SDValue Load;
28565 // It is safe to replace the two loads if they have different alignments,
28566 // but the new load must be the minimum (most restrictive) alignment of the
28567 // inputs.
28568 Align Alignment = std::min(a: LLD->getAlign(), b: RLD->getAlign());
28569 MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
28570 if (!RLD->isInvariant())
28571 MMOFlags &= ~MachineMemOperand::MOInvariant;
28572 if (!RLD->isDereferenceable())
28573 MMOFlags &= ~MachineMemOperand::MODereferenceable;
28574 if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
28575 // FIXME: Discards pointer and AA info.
28576 Load = DAG.getLoad(VT: TheSelect->getValueType(ResNo: 0), dl: SDLoc(TheSelect),
28577 Chain: LLD->getChain(), Ptr: Addr, PtrInfo: MachinePointerInfo(), Alignment,
28578 MMOFlags);
28579 } else {
28580 // FIXME: Discards pointer and AA info.
28581 Load = DAG.getExtLoad(
28582 ExtType: LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
28583 : LLD->getExtensionType(),
28584 dl: SDLoc(TheSelect), VT: TheSelect->getValueType(ResNo: 0), Chain: LLD->getChain(), Ptr: Addr,
28585 PtrInfo: MachinePointerInfo(), MemVT: LLD->getMemoryVT(), Alignment, MMOFlags);
28586 }
28587
28588 // Users of the select now use the result of the load.
28589 CombineTo(N: TheSelect, Res: Load);
28590
28591 // Users of the old loads now use the new load's chain. We know the
28592 // old-load value is dead now.
28593 CombineTo(N: LHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
28594 CombineTo(N: RHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
28595 return true;
28596 }
28597
28598 return false;
28599}
28600
28601/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
28602/// bitwise 'and'.
28603SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
28604 SDValue N1, SDValue N2, SDValue N3,
28605 ISD::CondCode CC) {
28606 // If this is a select where the false operand is zero and the compare is a
28607 // check of the sign bit, see if we can perform the "gzip trick":
28608 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
28609 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
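        // For example, with i32 X and A == 8, the single-bit case below gives:
        //   select_cc setlt X, 0, 8, 0 --> and (srl X, 28), 8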
28610 EVT XType = N0.getValueType();
28611 EVT AType = N2.getValueType();
28612 if (!isNullConstant(V: N3) || !XType.bitsGE(VT: AType))
28613 return SDValue();
28614
28615 // If the comparison is testing for a positive value, we have to invert
28616 // the sign bit mask, so only do that transform if the target has a bitwise
28617 // 'and not' instruction (the invert is free).
28618 if (CC == ISD::SETGT && TLI.hasAndNot(X: N2)) {
28619 // (X > -1) ? A : 0
28620 // (X > 0) ? X : 0 <-- This is canonical signed max.
28621 if (!(isAllOnesConstant(V: N1) || (isNullConstant(V: N1) && N0 == N2)))
28622 return SDValue();
28623 } else if (CC == ISD::SETLT) {
28624 // (X < 0) ? A : 0
28625 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
28626 if (!(isNullConstant(V: N1) || (isOneConstant(V: N1) && N0 == N2)))
28627 return SDValue();
28628 } else {
28629 return SDValue();
28630 }
28631
28632 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
28633 // constant.
28634 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
28635 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
28636 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
28637 if (!TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt)) {
28638 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: ShCt, VT: XType, DL);
28639 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT: XType, N1: N0, N2: ShiftAmt);
28640 AddToWorklist(N: Shift.getNode());
28641
28642 if (XType.bitsGT(VT: AType)) {
28643 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
28644 AddToWorklist(N: Shift.getNode());
28645 }
28646
28647 if (CC == ISD::SETGT)
28648 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
28649
28650 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
28651 }
28652 }
28653
28654 unsigned ShCt = XType.getSizeInBits() - 1;
28655 if (TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt))
28656 return SDValue();
28657
28658 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: ShCt, VT: XType, DL);
28659 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT: XType, N1: N0, N2: ShiftAmt);
28660 AddToWorklist(N: Shift.getNode());
28661
28662 if (XType.bitsGT(VT: AType)) {
28663 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
28664 AddToWorklist(N: Shift.getNode());
28665 }
28666
28667 if (CC == ISD::SETGT)
28668 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
28669
28670 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
28671}
28672
28673// Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
28674SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
28675 SDValue N0 = N->getOperand(Num: 0);
28676 SDValue N1 = N->getOperand(Num: 1);
28677 SDValue N2 = N->getOperand(Num: 2);
28678 SDLoc DL(N);
28679
28680 unsigned BinOpc = N1.getOpcode();
28681 if (!TLI.isBinOp(Opcode: BinOpc) || (N2.getOpcode() != BinOpc) ||
28682 (N1.getResNo() != N2.getResNo()))
28683 return SDValue();
28684
28685 // The use checks are intentionally on SDNode because we may be dealing
28686 // with opcodes that produce more than one SDValue.
28687 // TODO: Do we really need to check N0 (the condition operand of the select)?
28688 // But removing that clause could cause an infinite loop...
28689 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
28690 return SDValue();
28691
28692 // Binops may include opcodes that return multiple values, so all values
28693 // must be created/propagated from the newly created binops below.
28694 SDVTList OpVTs = N1->getVTList();
28695
28696 // Fold select(cond, binop(x, y), binop(z, y))
28697 // --> binop(select(cond, x, z), y)
28698 if (N1.getOperand(i: 1) == N2.getOperand(i: 1)) {
28699 SDValue N10 = N1.getOperand(i: 0);
28700 SDValue N20 = N2.getOperand(i: 0);
28701 SDValue NewSel = DAG.getSelect(DL, VT: N10.getValueType(), Cond: N0, LHS: N10, RHS: N20);
28702 SDValue NewBinOp = DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: NewSel, N2: N1.getOperand(i: 1));
28703 NewBinOp->setFlags(N1->getFlags());
28704 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
28705 return SDValue(NewBinOp.getNode(), N1.getResNo());
28706 }
28707
28708 // Fold select(cond, binop(x, y), binop(x, z))
28709 // --> binop(x, select(cond, y, z))
28710 if (N1.getOperand(i: 0) == N2.getOperand(i: 0)) {
28711 SDValue N11 = N1.getOperand(i: 1);
28712 SDValue N21 = N2.getOperand(i: 1);
28713 // Second op VT might be different (e.g. shift amount type)
28714 if (N11.getValueType() == N21.getValueType()) {
28715 SDValue NewSel = DAG.getSelect(DL, VT: N11.getValueType(), Cond: N0, LHS: N11, RHS: N21);
28716 SDValue NewBinOp =
28717 DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: N1.getOperand(i: 0), N2: NewSel);
28718 NewBinOp->setFlags(N1->getFlags());
28719 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
28720 return SDValue(NewBinOp.getNode(), N1.getResNo());
28721 }
28722 }
28723
28724 // TODO: Handle isCommutativeBinOp patterns as well?
28725 return SDValue();
28726}
28727
28728// Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
28729SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
28730 SDValue N0 = N->getOperand(Num: 0);
28731 EVT VT = N->getValueType(ResNo: 0);
28732 bool IsFabs = N->getOpcode() == ISD::FABS;
28733 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
28734
28735 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
28736 return SDValue();
28737
28738 SDValue Int = N0.getOperand(i: 0);
28739 EVT IntVT = Int.getValueType();
28740
28741   // The operand of the bitcast must be a scalar integer.
28742 if (!IntVT.isInteger() || IntVT.isVector())
28743 return SDValue();
28744
28745 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
28746 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
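        // For example, for an f32 result bitcast from an i32 x:
        //   (fneg (bitconvert x)) -> (bitconvert (xor x, 0x80000000))
        //   (fabs (bitconvert x)) -> (bitconvert (and x, 0x7fffffff))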
28747 APInt SignMask;
28748 if (N0.getValueType().isVector()) {
28749 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
28750 // 0x7f...) per element and splat it.
28751 SignMask = APInt::getSignMask(BitWidth: N0.getScalarValueSizeInBits());
28752 if (IsFabs)
28753 SignMask = ~SignMask;
28754 SignMask = APInt::getSplat(NewLen: IntVT.getSizeInBits(), V: SignMask);
28755 } else {
28756 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
28757 SignMask = APInt::getSignMask(BitWidth: IntVT.getSizeInBits());
28758 if (IsFabs)
28759 SignMask = ~SignMask;
28760 }
28761 SDLoc DL(N0);
28762 Int = DAG.getNode(Opcode: IsFabs ? ISD::AND : ISD::XOR, DL, VT: IntVT, N1: Int,
28763 N2: DAG.getConstant(Val: SignMask, DL, VT: IntVT));
28764 AddToWorklist(N: Int.getNode());
28765 return DAG.getBitcast(VT, V: Int);
28766}
28767
28768 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
28769/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
28770/// in it. This may be a win when the constant is not otherwise available
28771/// because it replaces two constant pool loads with one.
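      /// For example, for the f32 case above this builds the constant pool array
      /// {2.0f, 1.0f} and replaces the select with a load from
      /// tmp + ((a cond b) ? 4 : 0).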
28772SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
28773 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
28774 ISD::CondCode CC) {
28775 if (!TLI.reduceSelectOfFPConstantLoads(CmpOpVT: N0.getValueType()))
28776 return SDValue();
28777
28778 // If we are before legalize types, we want the other legalization to happen
28779 // first (for example, to avoid messing with soft float).
28780 auto *TV = dyn_cast<ConstantFPSDNode>(Val&: N2);
28781 auto *FV = dyn_cast<ConstantFPSDNode>(Val&: N3);
28782 EVT VT = N2.getValueType();
28783 if (!TV || !FV || !TLI.isTypeLegal(VT))
28784 return SDValue();
28785
28786 // If a constant can be materialized without loads, this does not make sense.
28787 if (TLI.getOperationAction(Op: ISD::ConstantFP, VT) == TargetLowering::Legal ||
28788 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(ResNo: 0), ForCodeSize) ||
28789 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(ResNo: 0), ForCodeSize))
28790 return SDValue();
28791
28792 // If both constants have multiple uses, then we won't need to do an extra
28793 // load. The values are likely around in registers for other users.
28794 if (!TV->hasOneUse() && !FV->hasOneUse())
28795 return SDValue();
28796
28797 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
28798 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
28799 Type *FPTy = Elts[0]->getType();
28800 const DataLayout &TD = DAG.getDataLayout();
28801
28802 // Create a ConstantArray of the two constants.
28803 Constant *CA = ConstantArray::get(T: ArrayType::get(ElementType: FPTy, NumElements: 2), V: Elts);
28804 SDValue CPIdx = DAG.getConstantPool(C: CA, VT: TLI.getPointerTy(DL: DAG.getDataLayout()),
28805 Align: TD.getPrefTypeAlign(Ty: FPTy));
28806 Align Alignment = cast<ConstantPoolSDNode>(Val&: CPIdx)->getAlign();
28807
28808 // Get offsets to the 0 and 1 elements of the array, so we can select between
28809 // them.
28810 SDValue Zero = DAG.getIntPtrConstant(Val: 0, DL);
28811 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Ty: Elts[0]->getType());
28812 SDValue One = DAG.getIntPtrConstant(Val: EltSize, DL: SDLoc(FV));
28813 SDValue Cond =
28814 DAG.getSetCC(DL, VT: getSetCCResultType(VT: N0.getValueType()), LHS: N0, RHS: N1, Cond: CC);
28815 AddToWorklist(N: Cond.getNode());
28816 SDValue CstOffset = DAG.getSelect(DL, VT: Zero.getValueType(), Cond, LHS: One, RHS: Zero);
28817 AddToWorklist(N: CstOffset.getNode());
28818 CPIdx = DAG.getNode(Opcode: ISD::ADD, DL, VT: CPIdx.getValueType(), N1: CPIdx, N2: CstOffset);
28819 AddToWorklist(N: CPIdx.getNode());
28820 return DAG.getLoad(VT: TV->getValueType(ResNo: 0), dl: DL, Chain: DAG.getEntryNode(), Ptr: CPIdx,
28821 PtrInfo: MachinePointerInfo::getConstantPool(
28822 MF&: DAG.getMachineFunction()), Alignment);
28823}
28824
28825/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
28826/// where 'cond' is the comparison specified by CC.
28827SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
28828 SDValue N2, SDValue N3, ISD::CondCode CC,
28829 bool NotExtCompare) {
28830 // (x ? y : y) -> y.
28831 if (N2 == N3) return N2;
28832
28833 EVT CmpOpVT = N0.getValueType();
28834 EVT CmpResVT = getSetCCResultType(VT: CmpOpVT);
28835 EVT VT = N2.getValueType();
28836 auto *N1C = dyn_cast<ConstantSDNode>(Val: N1.getNode());
28837 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
28838 auto *N3C = dyn_cast<ConstantSDNode>(Val: N3.getNode());
28839
28840 // Determine if the condition we're dealing with is constant.
28841 if (SDValue SCC = DAG.FoldSetCC(VT: CmpResVT, N1: N0, N2: N1, Cond: CC, dl: DL)) {
28842 AddToWorklist(N: SCC.getNode());
28843 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val&: SCC)) {
28844 // fold select_cc true, x, y -> x
28845 // fold select_cc false, x, y -> y
28846 return !(SCCC->isZero()) ? N2 : N3;
28847 }
28848 }
28849
28850 if (SDValue V =
28851 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
28852 return V;
28853
28854 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
28855 return V;
28856
28857 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
28858   // where y has a single bit set.
28859   // A plaintext description would be: we can turn the SELECT_CC into an AND
28860 // when the condition can be materialized as an all-ones register. Any
28861 // single bit-test can be materialized as an all-ones register with
28862 // shift-left and shift-right-arith.
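        // For example, with i32 x and y == 4 (a single bit at position 2):
        //   select_cc seteq (and x, 4), 0, 0, A --> and (sra (shl x, 29), 31), A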
28863 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
28864 N0->getValueType(ResNo: 0) == VT && isNullConstant(V: N1) && isNullConstant(V: N2)) {
28865 SDValue AndLHS = N0->getOperand(Num: 0);
28866 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(Val: N0->getOperand(Num: 1));
28867 if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
28868 // Shift the tested bit over the sign bit.
28869 const APInt &AndMask = ConstAndRHS->getAPIntValue();
28870 if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
28871 unsigned ShCt = AndMask.getBitWidth() - 1;
28872 SDValue ShlAmt = DAG.getShiftAmountConstant(Val: AndMask.countl_zero(), VT,
28873 DL: SDLoc(AndLHS));
28874 SDValue Shl = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: AndLHS, N2: ShlAmt);
28875
28876 // Now arithmetic right shift it all the way over, so the result is
28877 // either all-ones, or zero.
28878 SDValue ShrAmt = DAG.getShiftAmountConstant(Val: ShCt, VT, DL: SDLoc(Shl));
28879 SDValue Shr = DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N0), VT, N1: Shl, N2: ShrAmt);
28880
28881 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shr, N2: N3);
28882 }
28883 }
28884 }
28885
28886 // fold select C, 16, 0 -> shl C, 4
28887 bool Fold = N2C && isNullConstant(V: N3) && N2C->getAPIntValue().isPowerOf2();
28888 bool Swap = N3C && isNullConstant(V: N2) && N3C->getAPIntValue().isPowerOf2();
28889
28890 if ((Fold || Swap) &&
28891 TLI.getBooleanContents(Type: CmpOpVT) ==
28892 TargetLowering::ZeroOrOneBooleanContent &&
28893 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: CmpOpVT)) &&
28894 TLI.convertSelectOfConstantsToMath(VT)) {
28895
28896 if (Swap) {
28897 CC = ISD::getSetCCInverse(Operation: CC, Type: CmpOpVT);
28898 std::swap(a&: N2C, b&: N3C);
28899 }
28900
28901 // If the caller doesn't want us to simplify this into a zext of a compare,
28902 // don't do it.
28903 if (NotExtCompare && N2C->isOne())
28904 return SDValue();
28905
28906 SDValue Temp, SCC;
28907 // zext (setcc n0, n1)
28908 if (LegalTypes) {
28909 SCC = DAG.getSetCC(DL, VT: CmpResVT, LHS: N0, RHS: N1, Cond: CC);
28910 Temp = DAG.getZExtOrTrunc(Op: SCC, DL: SDLoc(N2), VT);
28911 } else {
28912 SCC = DAG.getSetCC(DL: SDLoc(N0), VT: MVT::i1, LHS: N0, RHS: N1, Cond: CC);
28913 Temp = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N2), VT, Operand: SCC);
28914 }
28915
28916 AddToWorklist(N: SCC.getNode());
28917 AddToWorklist(N: Temp.getNode());
28918
28919 if (N2C->isOne())
28920 return Temp;
28921
28922 unsigned ShCt = N2C->getAPIntValue().logBase2();
28923 if (TLI.shouldAvoidTransformToShift(VT, Amount: ShCt))
28924 return SDValue();
28925
28926 // shl setcc result by log2 n2c
28927 return DAG.getNode(
28928 Opcode: ISD::SHL, DL, VT: N2.getValueType(), N1: Temp,
28929 N2: DAG.getShiftAmountConstant(Val: ShCt, VT: N2.getValueType(), DL: SDLoc(Temp)));
28930 }
28931
28932 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
28933 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
28934 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
28935 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
28936 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
28937 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
28938 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
28939 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
28940 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
28941 SDValue ValueOnZero = N2;
28942 SDValue Count = N3;
28943     // If the condition is NE instead of EQ, swap the operands.
28944 if (CC == ISD::SETNE)
28945 std::swap(a&: ValueOnZero, b&: Count);
28946 // Check if the value on zero is a constant equal to the bits in the type.
28947 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(Val&: ValueOnZero)) {
28948 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
28949 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
28950 // legal, combine to just cttz.
28951 if ((Count.getOpcode() == ISD::CTTZ ||
28952 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
28953 N0 == Count.getOperand(i: 0) &&
28954 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ, VT)))
28955 return DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N0);
28956 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
28957 // legal, combine to just ctlz.
28958 if ((Count.getOpcode() == ISD::CTLZ ||
28959 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
28960 N0 == Count.getOperand(i: 0) &&
28961 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ, VT)))
28962 return DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: N0);
28963 }
28964 }
28965 }
28966
28967 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
28968 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
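        // For example, with i32 X and C == 5:
        //   select_cc setlt X, 0, 5, -6 --> xor (sra X, 31), -6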
28969 if (!NotExtCompare && N1C && N2C && N3C &&
28970 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
28971 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
28972 (N1C->isZero() && CC == ISD::SETLT)) &&
28973 !TLI.shouldAvoidTransformToShift(VT, Amount: CmpOpVT.getScalarSizeInBits() - 1)) {
28974 SDValue ASR = DAG.getNode(
28975 Opcode: ISD::SRA, DL, VT: CmpOpVT, N1: N0,
28976 N2: DAG.getConstant(Val: CmpOpVT.getScalarSizeInBits() - 1, DL, VT: CmpOpVT));
28977 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: DAG.getSExtOrTrunc(Op: ASR, DL, VT),
28978 N2: DAG.getSExtOrTrunc(Op: CC == ISD::SETLT ? N3 : N2, DL, VT));
28979 }
28980
28981 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
28982 return S;
28983 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
28984 return S;
28985 if (SDValue ABD = foldSelectToABD(LHS: N0, RHS: N1, True: N2, False: N3, CC, DL))
28986 return ABD;
28987
28988 return SDValue();
28989}
28990
28991/// This is a stub for TargetLowering::SimplifySetCC.
28992SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
28993 ISD::CondCode Cond, const SDLoc &DL,
28994 bool foldBooleans) {
28995 TargetLowering::DAGCombinerInfo
28996 DagCombineInfo(DAG, Level, false, this);
28997 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DCI&: DagCombineInfo, dl: DL);
28998}
28999
29000/// Given an ISD::SDIV node expressing a divide by constant, return
29001/// a DAG expression to select that will generate the same value by multiplying
29002/// by a magic number.
29003/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
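      /// For example, an i32 divide by 3 is typically lowered to a multiply-high
      /// by the magic constant 0x55555556 plus a small shift/add fix-up; the
      /// exact sequence is chosen by TargetLowering::BuildSDIV.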
29004SDValue DAGCombiner::BuildSDIV(SDNode *N) {
29005   // When optimizing for minimum size, we don't want to expand a div to a mul
29006   // and a shift.
29007 if (DAG.getMachineFunction().getFunction().hasMinSize())
29008 return SDValue();
29009
29010 SmallVector<SDNode *, 8> Built;
29011 if (SDValue S = TLI.BuildSDIV(N, DAG, IsAfterLegalization: LegalOperations, IsAfterLegalTypes: LegalTypes, Created&: Built)) {
29012 for (SDNode *N : Built)
29013 AddToWorklist(N);
29014 return S;
29015 }
29016
29017 return SDValue();
29018}
29019
29020/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
29021/// DAG expression that will generate the same value by right shifting.
29022SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
29023 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
29024 if (!C)
29025 return SDValue();
29026
29027 // Avoid division by zero.
29028 if (C->isZero())
29029 return SDValue();
29030
29031 SmallVector<SDNode *, 8> Built;
29032 if (SDValue S = TLI.BuildSDIVPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
29033 for (SDNode *N : Built)
29034 AddToWorklist(N);
29035 return S;
29036 }
29037
29038 return SDValue();
29039}
29040
29041/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
29042/// expression that will generate the same value by multiplying by a magic
29043/// number.
29044/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
29045SDValue DAGCombiner::BuildUDIV(SDNode *N) {
29046   // When optimizing for minimum size, we don't want to expand a div to a mul
29047   // and a shift.
29048 if (DAG.getMachineFunction().getFunction().hasMinSize())
29049 return SDValue();
29050
29051 SmallVector<SDNode *, 8> Built;
29052 if (SDValue S = TLI.BuildUDIV(N, DAG, IsAfterLegalization: LegalOperations, IsAfterLegalTypes: LegalTypes, Created&: Built)) {
29053 for (SDNode *N : Built)
29054 AddToWorklist(N);
29055 return S;
29056 }
29057
29058 return SDValue();
29059}
29060
29061/// Given an ISD::SREM node expressing a remainder by constant power of 2,
29062/// return a DAG expression that will generate the same value.
29063SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
29064 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
29065 if (!C)
29066 return SDValue();
29067
29068 // Avoid division by zero.
29069 if (C->isZero())
29070 return SDValue();
29071
29072 SmallVector<SDNode *, 8> Built;
29073 if (SDValue S = TLI.BuildSREMPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
29074 for (SDNode *N : Built)
29075 AddToWorklist(N);
29076 return S;
29077 }
29078
29079 return SDValue();
29080}
29081
29082// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
29083//
29084// Returns the node that represents `Log2(Op)`. This may create a new node. If
29085 // we are unable to compute `Log2(Op)`, it returns `SDValue()`.
29086//
29087// All nodes will be created at `DL` and the output will be of type `VT`.
29088//
29089// This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
29090 // `AssumeNonZero` if this function should simply assume (rather than prove)
29091 // that `Op` is non-zero.
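      //
      // For example, log2(1 << Y) folds to Y, and log2(select C, 8, 16) folds to
      // select C, 3, 4.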
29092static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
29093 SDValue Op, unsigned Depth,
29094 bool AssumeNonZero) {
29095 assert(VT.isInteger() && "Only integer types are supported!");
29096
29097 auto PeekThroughCastsAndTrunc = [](SDValue V) {
29098 while (true) {
29099 switch (V.getOpcode()) {
29100 case ISD::TRUNCATE:
29101 case ISD::ZERO_EXTEND:
29102 V = V.getOperand(i: 0);
29103 break;
29104 default:
29105 return V;
29106 }
29107 }
29108 };
29109
29110 if (VT.isScalableVector())
29111 return SDValue();
29112
29113 Op = PeekThroughCastsAndTrunc(Op);
29114
29115   // Helper for determining whether a value is a power-of-2 constant scalar or
29116   // a vector of such elements.
29117 SmallVector<APInt> Pow2Constants;
29118 auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
29119 if (C->isZero() || C->isOpaque())
29120 return false;
29121 // TODO: We may also be able to support negative powers of 2 here.
29122 if (C->getAPIntValue().isPowerOf2()) {
29123 Pow2Constants.emplace_back(Args: C->getAPIntValue());
29124 return true;
29125 }
29126 return false;
29127 };
29128
29129 if (ISD::matchUnaryPredicate(Op, Match: IsPowerOfTwo)) {
29130 if (!VT.isVector())
29131 return DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL, VT);
29132 // We need to create a build vector
29133 if (Op.getOpcode() == ISD::SPLAT_VECTOR)
29134 return DAG.getSplat(VT, DL,
29135 Op: DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL,
29136 VT: VT.getScalarType()));
29137 SmallVector<SDValue> Log2Ops;
29138 for (const APInt &Pow2 : Pow2Constants)
29139 Log2Ops.emplace_back(
29140 Args: DAG.getConstant(Val: Pow2.logBase2(), DL, VT: VT.getScalarType()));
29141 return DAG.getBuildVector(VT, DL, Ops: Log2Ops);
29142 }
29143
29144 if (Depth >= DAG.MaxRecursionDepth)
29145 return SDValue();
29146
29147 auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
29148 // Peek through zero extend. We can't peek through truncates since this
29149 // function is called on a shift amount. We must ensure that all of the bits
29150 // above the original shift amount are zeroed by this function.
29151 while (ToCast.getOpcode() == ISD::ZERO_EXTEND)
29152 ToCast = ToCast.getOperand(i: 0);
29153 EVT CurVT = ToCast.getValueType();
29154 if (NewVT == CurVT)
29155 return ToCast;
29156
29157 if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
29158 return DAG.getBitcast(VT: NewVT, V: ToCast);
29159
29160 return DAG.getZExtOrTrunc(Op: ToCast, DL, VT: NewVT);
29161 };
29162
29163 // log2(X << Y) -> log2(X) + Y
29164 if (Op.getOpcode() == ISD::SHL) {
29165 // 1 << Y and X nuw/nsw << Y are all non-zero.
29166 if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
29167 Op->getFlags().hasNoSignedWrap() || isOneConstant(V: Op.getOperand(i: 0)))
29168 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0),
29169 Depth: Depth + 1, AssumeNonZero))
29170 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LogX,
29171 N2: CastToVT(VT, Op.getOperand(i: 1)));
29172 }
29173
29174 // c ? X : Y -> c ? Log2(X) : Log2(Y)
29175 if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
29176 Op.hasOneUse()) {
29177 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1),
29178 Depth: Depth + 1, AssumeNonZero))
29179 if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 2),
29180 Depth: Depth + 1, AssumeNonZero))
29181 return DAG.getSelect(DL, VT, Cond: Op.getOperand(i: 0), LHS: LogX, RHS: LogY);
29182 }
29183
29184 // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
29185 // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
29186 if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
29187 Op.hasOneUse()) {
29188     // Pass AssumeNonZero as false here. Otherwise we can hit a case where
29189     // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
29190 if (SDValue LogX =
29191 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0), Depth: Depth + 1,
29192 /*AssumeNonZero*/ false))
29193 if (SDValue LogY =
29194 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1), Depth: Depth + 1,
29195 /*AssumeNonZero*/ false))
29196 return DAG.getNode(Opcode: Op.getOpcode(), DL, VT, N1: LogX, N2: LogY);
29197 }
29198
29199 return SDValue();
29200}
29201
29202/// Determines the LogBase2 value for a non-null input value using the
29203/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
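      /// For example, for an i32 value V == 16: ctlz(16) == 27, so
      /// LogBase2(V) == 31 - 27 == 4.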
29204SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
29205 bool KnownNonZero, bool InexpensiveOnly,
29206 std::optional<EVT> OutVT) {
29207 EVT VT = OutVT ? *OutVT : V.getValueType();
29208 SDValue InexpensiveLogBase2 =
29209 takeInexpensiveLog2(DAG, DL, VT, Op: V, /*Depth*/ 0, AssumeNonZero: KnownNonZero);
29210 if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(Val: V))
29211 return InexpensiveLogBase2;
29212
29213 SDValue Ctlz = DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: V);
29214 SDValue Base = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
29215 SDValue LogBase2 = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Base, N2: Ctlz);
29216 return LogBase2;
29217}
29218
29219/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29220/// For the reciprocal, we need to find the zero of the function:
29221/// F(X) = 1/X - A [which has a zero at X = 1/A]
29222/// =>
29223/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
29224/// does not require additional intermediate precision]
29225/// For the last iteration, put numerator N into it to gain more precision:
29226/// Result = N X_i + X_i (N - N A X_i)
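      /// For example, with a single refinement step, N / Op is approximated as:
      ///   Est    = estimate(1 / Op)
      ///   MulEst = N * Est
      ///   Result = MulEst + Est * (N - Op * MulEst)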
29227SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
29228 SDNodeFlags Flags) {
29229 if (LegalDAG)
29230 return SDValue();
29231
29232 // TODO: Handle extended types?
29233 EVT VT = Op.getValueType();
29234 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
29235 VT.getScalarType() != MVT::f64)
29236 return SDValue();
29237
29238 // If estimates are explicitly disabled for this function, we're done.
29239 MachineFunction &MF = DAG.getMachineFunction();
29240 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
29241 if (Enabled == TLI.ReciprocalEstimate::Disabled)
29242 return SDValue();
29243
29244 // Estimates may be explicitly enabled for this type with a custom number of
29245 // refinement steps.
29246 int Iterations = TLI.getDivRefinementSteps(VT, MF);
29247 if (SDValue Est = TLI.getRecipEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations)) {
29248 AddToWorklist(N: Est.getNode());
29249
29250 SDLoc DL(Op);
29251 if (Iterations) {
29252 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
29253
29254       // Newton iterations: Est = Est + Est (N - Op * Est)
29255 // If this is the last iteration, also multiply by the numerator.
29256 for (int i = 0; i < Iterations; ++i) {
29257 SDValue MulEst = Est;
29258
29259 if (i == Iterations - 1) {
29260 MulEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N, N2: Est, Flags);
29261 AddToWorklist(N: MulEst.getNode());
29262 }
29263
29264 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Op, N2: MulEst, Flags);
29265 AddToWorklist(N: NewEst.getNode());
29266
29267 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT,
29268 N1: (i == Iterations - 1 ? N : FPOne), N2: NewEst, Flags);
29269 AddToWorklist(N: NewEst.getNode());
29270
29271 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
29272 AddToWorklist(N: NewEst.getNode());
29273
29274 Est = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: MulEst, N2: NewEst, Flags);
29275 AddToWorklist(N: Est.getNode());
29276 }
29277 } else {
29278 // If no iterations are available, multiply with N.
29279 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: N, Flags);
29280 AddToWorklist(N: Est.getNode());
29281 }
29282
29283 return Est;
29284 }
29285
29286 return SDValue();
29287}
29288
29289/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29290/// For the reciprocal sqrt, we need to find the zero of the function:
29291/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
29292/// =>
29293/// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
29294/// As a result, we precompute A/2 prior to the iteration loop.
29295SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
29296 unsigned Iterations,
29297 SDNodeFlags Flags, bool Reciprocal) {
29298 EVT VT = Arg.getValueType();
29299 SDLoc DL(Arg);
29300 SDValue ThreeHalves = DAG.getConstantFP(Val: 1.5, DL, VT);
29301
29302 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
29303 // this entire sequence requires only one FP constant.
29304 SDValue HalfArg = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: ThreeHalves, N2: Arg, Flags);
29305 HalfArg = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: HalfArg, N2: Arg, Flags);
29306
29307 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
29308 for (unsigned i = 0; i < Iterations; ++i) {
29309 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Est, Flags);
29310 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: HalfArg, N2: NewEst, Flags);
29311 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: ThreeHalves, N2: NewEst, Flags);
29312 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
29313 }
29314
29315 // If non-reciprocal square root is requested, multiply the result by Arg.
29316 if (!Reciprocal)
29317 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Arg, Flags);
29318
29319 return Est;
29320}
29321
29322/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29323/// For the reciprocal sqrt, we need to find the zero of the function:
29324/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
29325/// =>
29326/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
29327SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
29328 unsigned Iterations,
29329 SDNodeFlags Flags, bool Reciprocal) {
29330 EVT VT = Arg.getValueType();
29331 SDLoc DL(Arg);
29332 SDValue MinusThree = DAG.getConstantFP(Val: -3.0, DL, VT);
29333 SDValue MinusHalf = DAG.getConstantFP(Val: -0.5, DL, VT);
29334
29335 // This routine must enter the loop below to work correctly
29336 // when (Reciprocal == false).
29337 assert(Iterations > 0);
29338
29339 // Newton iterations for reciprocal square root:
29340 // E = (E * -0.5) * ((A * E) * E + -3.0)
29341 for (unsigned i = 0; i < Iterations; ++i) {
29342 SDValue AE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Arg, N2: Est, Flags);
29343 SDValue AEE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: Est, Flags);
29344 SDValue RHS = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: AEE, N2: MinusThree, Flags);
29345
29346 // When calculating a square root at the last iteration build:
29347 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
29348 // (notice a common subexpression)
29349 SDValue LHS;
29350 if (Reciprocal || (i + 1) < Iterations) {
29351 // RSQRT: LHS = (E * -0.5)
29352 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: MinusHalf, Flags);
29353 } else {
29354 // SQRT: LHS = (A * E) * -0.5
29355 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: MinusHalf, Flags);
29356 }
29357
29358 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: LHS, N2: RHS, Flags);
29359 }
29360
29361 return Est;
29362}
29363
29364/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
29365/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
29366/// Op can be zero.
29367SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
29368 bool Reciprocal) {
29369 if (LegalDAG)
29370 return SDValue();
29371
29372 // TODO: Handle extended types?
29373 EVT VT = Op.getValueType();
29374 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
29375 VT.getScalarType() != MVT::f64)
29376 return SDValue();
29377
29378 // If estimates are explicitly disabled for this function, we're done.
29379 MachineFunction &MF = DAG.getMachineFunction();
29380 int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
29381 if (Enabled == TLI.ReciprocalEstimate::Disabled)
29382 return SDValue();
29383
29384 // Estimates may be explicitly enabled for this type with a custom number of
29385 // refinement steps.
29386 int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
29387
29388 bool UseOneConstNR = false;
29389 if (SDValue Est =
29390 TLI.getSqrtEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations, UseOneConstNR,
29391 Reciprocal)) {
29392 AddToWorklist(N: Est.getNode());
29393
29394 if (Iterations > 0)
29395 Est = UseOneConstNR
29396 ? buildSqrtNROneConst(Arg: Op, Est, Iterations, Flags, Reciprocal)
29397 : buildSqrtNRTwoConst(Arg: Op, Est, Iterations, Flags, Reciprocal);
29398 if (!Reciprocal) {
29399 SDLoc DL(Op);
29400 // Try the target specific test first.
29401 SDValue Test = TLI.getSqrtInputTest(Operand: Op, DAG, Mode: DAG.getDenormalMode(VT));
29402
29403 // The estimate is now completely wrong if the input was exactly 0.0 or
29404       // possibly a denormal. Force the answer to 0.0 or the value provided by
29405       // the target for those cases.
29406 Est = DAG.getSelect(DL, VT, Cond: Test,
29407 LHS: TLI.getSqrtResultForDenormInput(Operand: Op, DAG), RHS: Est);
29408 }
29409 return Est;
29410 }
29411
29412 return SDValue();
29413}
29414
29415SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
29416 return buildSqrtEstimateImpl(Op, Flags, Reciprocal: true);
29417}
29418
29419SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
29420 return buildSqrtEstimateImpl(Op, Flags, Reciprocal: false);
29421}
29422
29423/// Return true if there is any possibility that the two addresses overlap.
29424bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
29425
29426 struct MemUseCharacteristics {
29427 bool IsVolatile;
29428 bool IsAtomic;
29429 SDValue BasePtr;
29430 int64_t Offset;
29431 LocationSize NumBytes;
29432 MachineMemOperand *MMO;
29433 };
29434
29435 auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
29436 if (const auto *LSN = dyn_cast<LSBaseSDNode>(Val: N)) {
29437 int64_t Offset = 0;
29438 if (auto *C = dyn_cast<ConstantSDNode>(Val: LSN->getOffset()))
29439 Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
29440 : (LSN->getAddressingMode() == ISD::PRE_DEC)
29441 ? -1 * C->getSExtValue()
29442 : 0;
29443 TypeSize Size = LSN->getMemoryVT().getStoreSize();
29444 return {.IsVolatile: LSN->isVolatile(), .IsAtomic: LSN->isAtomic(),
29445 .BasePtr: LSN->getBasePtr(), .Offset: Offset /*base offset*/,
29446 .NumBytes: LocationSize::precise(Value: Size), .MMO: LSN->getMemOperand()};
29447 }
29448     if (const auto *LN = dyn_cast<LifetimeSDNode>(Val: N))
29449 return {.IsVolatile: false /*isVolatile*/,
29450 /*isAtomic*/ .IsAtomic: false,
29451 .BasePtr: LN->getOperand(Num: 1),
29452 .Offset: (LN->hasOffset()) ? LN->getOffset() : 0,
29453 .NumBytes: (LN->hasOffset()) ? LocationSize::precise(Value: LN->getSize())
29454 : LocationSize::beforeOrAfterPointer(),
29455 .MMO: (MachineMemOperand *)nullptr};
29456 // Default.
29457 return {.IsVolatile: false /*isvolatile*/,
29458 /*isAtomic*/ .IsAtomic: false,
29459 .BasePtr: SDValue(),
29460 .Offset: (int64_t)0 /*offset*/,
29461 .NumBytes: LocationSize::beforeOrAfterPointer() /*size*/,
29462 .MMO: (MachineMemOperand *)nullptr};
29463 };
29464
29465 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
29466 MUC1 = getCharacteristics(Op1);
29467
29468 // If they are to the same address, then they must be aliases.
29469 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
29470 MUC0.Offset == MUC1.Offset)
29471 return true;
29472
29473 // If they are both volatile then they cannot be reordered.
29474 if (MUC0.IsVolatile && MUC1.IsVolatile)
29475 return true;
29476
29477 // Be conservative about atomics for the moment
29478 // TODO: This is way overconservative for unordered atomics (see D66309)
29479 if (MUC0.IsAtomic && MUC1.IsAtomic)
29480 return true;
29481
29482 if (MUC0.MMO && MUC1.MMO) {
29483 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
29484 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
29485 return false;
29486 }
29487
29488   // If NumBytes is scalable and the offset is not 0, conservatively report
29489   // that they may alias.
29490 if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
29491 MUC0.Offset != 0) ||
29492 (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
29493 MUC1.Offset != 0))
29494 return true;
29495 // Try to prove that there is aliasing, or that there is no aliasing. Either
29496 // way, we can return now. If nothing can be proved, proceed with more tests.
29497 bool IsAlias;
29498 if (BaseIndexOffset::computeAliasing(Op0, NumBytes0: MUC0.NumBytes, Op1, NumBytes1: MUC1.NumBytes,
29499 DAG, IsAlias))
29500 return IsAlias;
29501
29502 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
29503   // either is not known.
29504 if (!MUC0.MMO || !MUC1.MMO)
29505 return true;
29506
29507 // If one operation reads from invariant memory, and the other may store, they
29508 // cannot alias. These should really be checking the equivalent of mayWrite,
29509   // but it only matters for memory nodes other than load/store.
29510 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
29511 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
29512 return false;
29513
29514 // If we know required SrcValue1 and SrcValue2 have relatively large
29515 // alignment compared to the size and offset of the access, we may be able
29516 // to prove they do not alias. This check is conservative for now to catch
29517   // cases created by splitting vector types; it only works when the offsets are
29518 // multiples of the size of the data.
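        // For example, two 4-byte accesses with a common 16-byte aligned base at
        // offsets 0 and 8 cannot overlap, since 0 + 4 <= 8.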
  int64_t SrcValOffset0 = MUC0.MMO->getOffset();
  int64_t SrcValOffset1 = MUC1.MMO->getOffset();
  Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
  Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
  LocationSize Size0 = MUC0.NumBytes;
  LocationSize Size1 = MUC1.NumBytes;

  if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
      Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
      !Size1.isScalable() && Size0 == Size1 &&
      OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
      SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
      SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
    int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
    int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();

    // There is no overlap between these relatively aligned accesses of
    // similar size. Return no alias.
    if ((OffAlign0 + static_cast<int64_t>(
                         Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
        (OffAlign1 + static_cast<int64_t>(
                         Size1.getValue().getKnownMinValue())) <= OffAlign0)
      return false;
  }

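  // If CombinerGlobalAA was given explicitly on the command line it takes
  // precedence; otherwise defer to the subtarget's preference for using AA.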
  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
                   ? CombinerGlobalAA
                   : DAG.getSubtarget().useAA();
#ifndef NDEBUG
  if (CombinerAAOnlyFunc.getNumOccurrences() &&
      CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
    UseAA = false;
#endif

  if (UseAA && BatchAA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
      Size0.hasValue() && Size1.hasValue() &&
      // Can't represent a scalable size + fixed offset in LocationSize
      (!Size0.isScalable() || SrcValOffset0 == 0) &&
      (!Size1.isScalable() || SrcValOffset1 == 0)) {
    // Use alias analysis information.
    int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
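    // AA only sees the MMO values, not the offsets, so query it with sizes
    // extended by (SrcValOffset - MinOffset): if the offset accesses overlap
    // at some address, the extended locations overlap at that address minus
    // MinOffset, so a no-alias answer still proves the accesses disjoint.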
    int64_t Overlap0 =
        Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
    int64_t Overlap1 =
        Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
    LocationSize Loc0 =
        Size0.isScalable() ? Size0 : LocationSize::precise(Overlap0);
    LocationSize Loc1 =
        Size1.isScalable() ? Size1 : LocationSize::precise(Overlap1);
    if (BatchAA->isNoAlias(
            MemoryLocation(MUC0.MMO->getValue(), Loc0,
                           UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
            MemoryLocation(MUC1.MMO->getValue(), Loc1,
                           UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
      return false;
  }

  // Otherwise we have to assume they alias.
  return true;
}

/// Walk up chain skipping non-aliasing memory nodes,
/// looking for aliasing nodes and adding them to the Aliases vector.
void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
                                   SmallVectorImpl<SDValue> &Aliases) {
  SmallVector<SDValue, 8> Chains;     // List of chains to visit.
  SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.

  // Get alias information for node.
  // TODO: relax aliasing for unordered atomics (see D66309)
  const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();

  // Starting off.
  Chains.push_back(OriginalChain);
  unsigned Depth = 0;

  // Attempt to improve chain by a single step
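  // ImproveChain either advances C one step up its chain (returning true,
  // and clearing C once the EntryToken is reached), or returns false when C
  // may alias N (or is not understood) and must be recorded as an alias.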
  auto ImproveChain = [&](SDValue &C) -> bool {
    switch (C.getOpcode()) {
    case ISD::EntryToken:
      // No need to mark EntryToken.
      C = SDValue();
      return true;
    case ISD::LOAD:
    case ISD::STORE: {
      // Get alias information for C.
      // TODO: Relax aliasing for unordered atomics (see D66309)
      bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
                      cast<LSBaseSDNode>(C.getNode())->isSimple();
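      // Two simple loads never conflict with each other, so a load can always
      // be moved past another load; everything else must pass mayAlias.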
      if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      // Alias, so stop here.
      return false;
    }

    case ISD::CopyFromReg:
      // Always forward past CopyFromReg.
      C = C.getOperand(0);
      return true;

    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END: {
      // We can forward past any lifetime start/end that can be proven not to
      // alias the memory access.
      if (!mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      return false;
    }
    default:
      return false;
    }
  };

  // Look at each chain and determine if it is an alias. If so, add it to the
  // aliases list. If not, then continue up the chain looking for the next
  // candidate.
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();

    // Don't bother if we've seen Chain before.
    if (!Visited.insert(Chain.getNode()).second)
      continue;

    // For TokenFactor nodes, look at each operand and only continue up the
    // chain until we reach the depth limit.
    //
    // FIXME: The depth check could be made to return the last non-aliasing
    // chain we found before we hit a tokenfactor rather than the original
    // chain.
    if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
      Aliases.clear();
      Aliases.push_back(OriginalChain);
      return;
    }
    if (Chain.getOpcode() == ISD::TokenFactor) {
      // We have to check each of the operands of the token factor for "small"
      // token factors, so we queue them up. Adding the operands to the queue
      // (stack) in reverse order maintains the original order and increases
      // the likelihood that getNode will find a matching token factor (CSE).
      if (Chain.getNumOperands() > 16) {
        Aliases.push_back(Chain);
        continue;
      }
      for (unsigned n = Chain.getNumOperands(); n;)
        Chains.push_back(Chain.getOperand(--n));
      ++Depth;
      continue;
    }
    // Everything else.
    if (ImproveChain(Chain)) {
      // Updated chain found; consider the new chain if one exists.
      if (Chain.getNode())
        Chains.push_back(Chain);
      ++Depth;
      continue;
    }
    // No improved chain is possible, so treat this as an alias.
    Aliases.push_back(Chain);
  }
}

/// Walk up chain skipping non-aliasing memory nodes, looking for a better
/// chain (aliasing node).
SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
  if (OptLevel == CodeGenOptLevel::None)
    return OldChain;

  // Ops for replacing token factor.
  SmallVector<SDValue, 8> Aliases;

  // Accumulate all the aliases to this node.
  GatherAllAliases(N, OldChain, Aliases);

  // If no operands then chain to entry token.
  if (Aliases.empty())
    return DAG.getEntryNode();

  // If a single operand then chain to it. We don't need to revisit it.
  if (Aliases.size() == 1)
    return Aliases[0];

  // Construct a custom tailored token factor.
  return DAG.getTokenFactor(SDLoc(N), Aliases);
}

// This function tries to collect a bunch of potentially interesting
// nodes to improve the chains of, all at once. This might seem
// redundant, as this function gets called when visiting every store
// node, so why not let the work be done on each store as it's visited?
//
// I believe this is mainly important because mergeConsecutiveStores
// is unable to deal with merging stores of different sizes, so unless
// we improve the chains of all the potential candidates up-front
// before running mergeConsecutiveStores, it might only see some of
// the nodes that will eventually be candidates, and then not be able
// to go from a partially-merged state to the desired final
// fully-merged state.

bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
  SmallVector<StoreSDNode *, 8> ChainedStores;
  StoreSDNode *STChain = St;
  // Intervals records which offsets from BaseIndex have been covered. In
  // the common case, every store writes to the immediately previous address
  // range and is thus merged with the previous interval at insertion time.
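  // For example, a 4-byte store at base offset 0 covers [0, 4); a chained
  // 4-byte store at offset -4 covers [-4, 0), which coalesces with it into a
  // single [-4, 4) interval.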

  using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
                                 IntervalMapHalfOpenInfo<int64_t>>;
  IMap::Allocator A;
  IMap Intervals(A);

  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Do not handle stores to opaque types.
  if (St->getMemoryVT().isZeroSized())
    return false;

  // BaseIndexOffset assumes that offsets are fixed-size, which
  // is not valid for scalable vectors where the offsets are
  // scaled by `vscale`, so bail out early.
  if (St->getMemoryVT().isScalableVT())
    return false;

  // Add ST's interval.
  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
                   std::monostate{});

  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
    if (Chain->getMemoryVT().isScalableVector())
      return false;

    // If the chain has more than one use, then we can't reorder the mem ops.
    if (!SDValue(Chain, 0)->hasOneUse())
      break;
    // TODO: Relax for unordered atomics (see D66309)
    if (!Chain->isSimple() || Chain->isIndexed())
      break;

    // Find the base pointer and offset for this memory node.
    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
    // Check that the base pointer is the same as the original one.
    int64_t Offset;
    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
      break;
    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
    // Make sure we don't overlap with other intervals by checking the ones to
    // the left or right before inserting.
    auto I = Intervals.find(Offset);
    // If there's a next interval, we should end before it.
    if (I != Intervals.end() && I.start() < (Offset + Length))
      break;
    // If there's a previous interval, we should start after it.
    if (I != Intervals.begin() && (--I).stop() <= Offset)
      break;
    Intervals.insert(Offset, Offset + Length, std::monostate{});

    ChainedStores.push_back(Chain);
    STChain = Chain;
  }

  // If we didn't find a chained store, exit.
  if (ChainedStores.empty())
    return false;

  // Improve all chained stores (St and the ChainedStores members) starting
  // from where the store chain ended, and return a single TokenFactor.
  SDValue NewChain = STChain->getChain();
  SmallVector<SDValue, 8> TFOps;
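  // Walk from the store furthest from St back towards St, rewiring each one
  // to the best chain reachable from NewChain.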
  for (unsigned I = ChainedStores.size(); I;) {
    StoreSDNode *S = ChainedStores[--I];
    SDValue BetterChain = FindBetterChain(S, NewChain);
    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
    TFOps.push_back(SDValue(S, 0));
    ChainedStores[I] = S;
  }

  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
  SDValue BetterChain = FindBetterChain(St, NewChain);
  SDValue NewST;
  if (St->isTruncatingStore())
    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
                              St->getBasePtr(), St->getMemoryVT(),
                              St->getMemOperand());
  else
    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
                         St->getBasePtr(), St->getMemOperand());

  TFOps.push_back(NewST);

  // If we improved every element of TFOps, then we've lost the dependence on
  // NewChain to successors of St and we need to add it back to TFOps. Do so at
  // the beginning to keep relative order consistent with FindBetterChains.
  auto hasImprovedChain = [&](SDValue ST) -> bool {
    return ST->getOperand(0) != NewChain;
  };
  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
  if (AddNewChain)
    TFOps.insert(TFOps.begin(), NewChain);

  SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
  CombineTo(St, TF);

  // Add TF and its operands to the worklist.
  AddToWorklist(TF.getNode());
  for (const SDValue &Op : TF->ops())
    AddToWorklist(Op.getNode());
  AddToWorklist(STChain);
  return true;
}

bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
  if (OptLevel == CodeGenOptLevel::None)
    return false;

  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Directly improve a chain of disjoint stores starting at St.
  if (parallelizeChainedStores(St))
    return true;

  // Improve St's chain.
  SDValue BetterChain = FindBetterChain(St, St->getChain());
  if (St->getChain() != BetterChain) {
    replaceStoreChain(St, BetterChain);
    return true;
  }
  return false;
}

/// This is the entry point for the file.
void SelectionDAG::Combine(CombineLevel Level, BatchAAResults *BatchAA,
                           CodeGenOptLevel OptLevel) {
  /// This is the main entry point to this class.
  DAGCombiner(*this, BatchAA, OptLevel).Run(Level);
}
