//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISC-V strided load/store intrinsics.
//
//===----------------------------------------------------------------------===//
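//
// Illustrative example (value names such as %base and %mask are placeholders,
// not a pattern this pass requires): a gather whose pointer operand is a GEP
// with a constant-strided vector index,
//
//   %ptrs = getelementptr i32, ptr %base,
//                         <4 x i64> <i64 0, i64 2, i64 4, i64 6>
//   %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4,
//                                                      <4 x i1> %mask,
//                                                      <4 x i32> %passthru)
//
// becomes a @llvm.riscv.masked.strided.load of %base with a byte stride of 8
// (index stride 2 scaled by the size of i32).
//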

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

// TODO: Should we consider the mask when looking for a stride?
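//
// Given a constant vector whose integer elements form a single arithmetic
// progression, return the scalar (start, stride) pair of constants; e.g.
// <i32 3, i32 5, i32 7, i32 9> yields (3, 2). Returns (nullptr, nullptr)
// otherwise.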
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

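// Match Start against the supported "strided start" shapes and return the
// scalar (start, stride) pair, materializing scalar arithmetic as needed.
// One illustrative case (not the only shape handled):
//   %step = call <4 x i64> @llvm.experimental.stepvector.v4i64()
//   %idx  = shl <4 x i64> %step, <i64 2, i64 2, i64 2, i64 2>
// matches with a start of 0 and a stride of 4.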
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) =
      matchStridedStart(BO->getOperand(OtherIndex), Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk back the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
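//
// For example (illustrative block and value names), given a vector IV
//   %vec.iv   = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %preheader ],
//                             [ %vec.next, %loop ]
//   %vec.next = add <4 x i64> %vec.iv, <i64 4, i64 4, i64 4, i64 4>
// this builds the scalar recurrence
//   %vec.iv.scalar   = phi i64 [ 0, %preheader ], [ %vec.next.scalar, %loop ]
//   %vec.next.scalar = add i64 %vec.iv.scalar, 4
// and reports a stride of 1.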
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for a binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
  case Instruction::Shl:
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure the other operand is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Step = Builder.CreateMul(Step, SplatOp, "step");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}

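// Try to express the address computed by Ptr as a scalar base pointer plus a
// scalar stride measured in bytes, emitting any scalar arithmetic this
// requires. Returns (nullptr, nullptr) if Ptr cannot be expressed this way.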
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If the GEP's offset is scalar then we can add it to the base pointer's
    // base.
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // The base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
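  // For example, on RV64 a <4 x i32> constant index would be sign extended to
  // <4 x i64> below, while a wider-than-pointer constant index would be
  // truncated.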
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

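// Try to lower a single masked gather/scatter to a RISC-V strided load/store
// intrinsic. Returns true and replaces II if the address could be expressed
// as base + stride and the operation is legal for the target.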
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(II, II->getType(),
                                         II->getArgOperand(0),
                                         II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |= tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                         II->getArgOperand(1),
                                         II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}