//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISC-V intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

// TODO: Should we consider the mask when looking for a stride?
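//
// For example (illustrative values), <4 x i32> <i32 10, i32 13, i32 16,
// i32 19> has a common difference of 3 between adjacent elements, so it
// matches with base 10 and stride 3, while <i32 0, i32 1, i32 3, i32 7>
// has no single stride and is rejected.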
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
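  // For example (illustrative IR), given %step = stepvector we might see
  //   %start = add (mul %step, splat(4)), splat(10)
  // The recursion matches %step as (base 0, stride 1), the mul scales both
  // to (0, 4), and the add shifts the base, giving (10, 4).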
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk back up the use-def chain until we find a Phi with a
// strided start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
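//
// As an illustrative sketch (names hypothetical), a vector induction such as
//   %vec.ind = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
//                            [ %vec.ind.next, %loop ]
//   %vec.ind.next = add <4 x i64> %vec.ind, <i64 4, i64 4, i64 4, i64 4>
// is rewritten into the scalar recurrence
//   %vec.ind.scalar = phi i64 [ 0, %ph ], [ %vec.ind.next.scalar, %loop ]
//   %vec.ind.next.scalar = add i64 %vec.ind.scalar, 4
// with Stride = 1, since adjacent lanes of %vec.ind differ by 1.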
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // Only phis in the loop header can form the recurrence we're looking for.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
    break;
  case Instruction::Shl:
    break;
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

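  // The recurrence built so far computes B(i) = Start + i * Step in each
  // iteration (with Stride between adjacent lanes). Applying the splat C:
  //   add/or: B(i) + C  = (Start + C) + i * Step        (Step, Stride keep)
  //   mul:    B(i) * C  = (Start * C) + i * (Step * C)  (Stride scales too)
  //   shl:    B(i) << C = (Start << C) + i * (Step << C)
  // so it suffices to rewrite Start (and Step/Stride where noted) once,
  // outside the loop.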
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Step = Builder.CreateMul(Step, SplatOp, "step");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
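  // Every lane accesses the same address, which a strided access with stride
  // 0 expresses exactly: a gather from splat(%p), for example, becomes a
  // stride-0 load from %p.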
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If GEP's offset is scalar then we can add it to the base pointer's base.
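    // For example (illustrative IR), if %vbase is strided with base %b and
    // stride %s, then
    //   %ptrs = getelementptr i32, <4 x ptr> %vbase, i64 7
    // is still strided with stride %s; only the base moves, becoming the
    // scalar GEP "getelementptr i32, ptr %b, i64 7".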
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // Base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
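  // For instance, a constant <2 x i8> index can be sign-extended to the
  // pointer width at compile time without changing its value, whereas
  // widening a non-constant index would change where any wraparound happens.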
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
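    // The strided intrinsics take a byte stride, so an index stride of 1
    // over i32 elements, for example, becomes a byte stride of 4
    // (TypeScale = 4).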
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

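// An illustrative sketch of the overall rewrite (names and types are
// hypothetical, intrinsic overload suffixes omitted): a gather such as
//   %v = call <4 x i32> @llvm.masked.gather(<4 x ptr> %ptrs, i32 4,
//                                           <4 x i1> %mask, <4 x i32> %pt)
// whose %ptrs address is strided with base %base and byte stride 8 becomes
//   %v = call <4 x i32> @llvm.riscv.masked.strided.load(<4 x i32> %pt,
//                                                       ptr %base, i64 8,
//                                                       <4 x i1> %mask)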
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}