//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter (and
// their VP counterparts) to strided load/store intrinsics when a constant or
// loop-invariant stride can be recovered from the address computation.
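//
// For example (an illustrative sketch, not taken from an actual test), a
// gather whose lane addresses advance by a constant byte stride:
//
//   %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs,
//            i32 4, <4 x i1> %mask, <4 x i32> %passthru)
//
// can instead be emitted as a strided VP load (followed by a vp.select that
// merges %passthru for the masked form):
//
//   %v = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(
//            ptr %base, i64 4, <4 x i1> %mask, i32 4)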
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/ConstantFold.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

// TODO: Should we consider the mask when looking for a stride?
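// For example, the constant <i32 3, i32 5, i32 7, i32 9> decomposes into
// (start = 3, stride = 2), while <i32 0, i32 1, i32 3> is rejected because
// the element-to-element difference is not uniform.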
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

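// Match Start as a strided sequence, peeling splatted adds, disjoint ors,
// shifts, and muls off of it and materializing the equivalent scalar
// arithmetic as we go. For example, (stepvector + splat(7)) yields
// (start = 7, stride = 1), and ((stepvector + splat(7)) << splat(1)) yields
// (start = 14, stride = 2).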
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) =
      matchStridedStart(BO->getOperand(OtherIndex), Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk back through the use-def chain until we find a phi with a
// strided start value, then build and update a scalar recurrence as we unwind
// the recursion. We also update the Stride as we unwind. The goal is to move
// all of the arithmetic out of the loop.
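// For example (illustrative), a vector induction variable such as
//   %vec.iv      = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
//                                [ %vec.iv.next, %loop ]
//   %vec.iv.next = add <4 x i64> %vec.iv, splat (i64 4)
// is rebuilt as the scalar recurrence
//   %vec.iv.scalar      = phi i64 [ 0, %ph ], [ %vec.iv.next.scalar, %loop ]
//   %vec.iv.next.scalar = add i64 %vec.iv.scalar, 4
// with a Stride of 1.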
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // The phi must be in the loop header, otherwise it cannot be the
    // recurrence we are looking for.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for a binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
    break;
  case Instruction::Shl:
    break;
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

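  // For example, if the recurrence is multiplied by splat(C), then Start,
  // Stride, and Step are all scaled by C; if splat(C) is added, only Start
  // changes.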
  // TODO: Share this switch with matchStridedStart?
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  // If the Step was defined inside the loop, adjust it before its definition
  // instead of in the preheader.
  if (auto *StepI = dyn_cast<Instruction>(Step); StepI && L->contains(StepI))
    Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());

  switch (BO->getOpcode()) {
  default:
    break;
  case Instruction::Mul:
    Step = Builder.CreateMul(Step, SplatOp, "step");
    break;
  case Instruction::Shl:
    Step = Builder.CreateShl(Step, SplatOp, "step");
    break;
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
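  // For example, a gather from splat(%p) reads the same address in every
  // lane, which the strided form expresses as (base = %p, stride = 0).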
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If GEP's offset is scalar then we can add it to the base pointer's base.
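    // For example (illustrative), if the base is itself strided,
    //   %ptrs = getelementptr i32, ptr %p, <4 x i64> %vecidx
    //   %gep  = getelementptr i8, <4 x ptr> %ptrs, i64 4
    // the scalar offset 4 folds into a new scalar base while the stride of
    // %ptrs is kept as-is.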
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // Base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
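  // For example, a <4 x i32> index with 64-bit pointers is only handled if it
  // is a constant we can sign-extend to <4 x i64> on the spot.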
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
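  // For example (illustrative), the vectorizer may emit an index such as
  //   %index = add <4 x i64> %iv.splat, stepvector
  // where %iv.splat splats a scalar IV; matchStridedStart then recovers
  // (start = %iv, stride = 1) and no loop recurrence needs rewriting.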
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
  VectorType *DataType;
  Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
  MaybeAlign MA;
  switch (II->getIntrinsicID()) {
  case Intrinsic::masked_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
    Mask = II->getArgOperand(2);
    break;
  case Intrinsic::vp_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = II->getParamAlign(0).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(1);
    EVL = II->getArgOperand(2);
    break;
  case Intrinsic::masked_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
    Mask = II->getArgOperand(3);
    break;
  case Intrinsic::vp_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = II->getParamAlign(1).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(2);
    EVL = II->getArgOperand(3);
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  // Make sure the operation will be supported by the backend.
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder Builder(Ctx, InstSimplifyFolder(*DL));
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

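  // The masked (non-VP) forms carry no explicit EVL, so use the full element
  // count of the vector type.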
  if (!EVL)
    EVL = Builder.CreateElementCount(
        Builder.getInt32Ty(), cast<VectorType>(DataType)->getElementCount());

  CallInst *Call;

  if (!StoreVal) {
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {BasePtr, Stride, Mask, EVL});

    // Merge llvm.masked.gather's passthru
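    // (vp.select keeps the loaded lanes where the mask is set and takes the
    // passthru lanes elsewhere).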
    if (II->getIntrinsicID() == Intrinsic::masked_gather)
      Call = Builder.CreateIntrinsic(Intrinsic::vp_select, {DataType},
                                     {Mask, Call, II->getArgOperand(3), EVL});
  } else
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {StoreVal, BasePtr, Stride, Mask, EVL});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Worklist;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      switch (II->getIntrinsicID()) {
      case Intrinsic::masked_gather:
      case Intrinsic::masked_scatter:
      case Intrinsic::vp_gather:
      case Intrinsic::vp_scatter:
        Worklist.push_back(II);
        break;
      default:
        break;
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Worklist)
    Changed |= tryCreateStridedLoadStore(II);

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}