//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather, llvm.masked.scatter,
// llvm.vp.gather, and llvm.vp.scatter instructions with strided addresses
// to strided load/store intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined for each GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

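// Try to decompose a constant vector into a scalar base and a scalar stride,
// e.g. <i64 0, i64 2, i64 4, i64 6> has base 0 and stride 2. Returns
// {nullptr, nullptr} unless every element is a ConstantInt and the elements
// form a single arithmetic sequence.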
// TODO: Should we consider the mask when looking for a stride?
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

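// Try to express a vector start value as a scalar (start, stride) pair.
// Handles strided constants, llvm.stepvector, and add/disjoint-or/mul/shl of
// a strided value with a splat. For example, "add (mul stepvector, splat C1),
// splat C2" yields a start of C2 and a stride of C1, with the scalar mul and
// add materialized in front of the matched instructions.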
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) =
      matchStridedStart(BO->getOperand(OtherIndex), Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk up the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
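// For example, a vector induction of the form
//   %vec.ind = phi [ stepvector, %preheader ], [ %vec.ind.next, %loop ]
//   %vec.ind.next = add %vec.ind, splat(%vf)
// becomes a scalar phi starting at 0 and incremented by %vf, with Stride
// reported as 1.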
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // The phi we want to rewrite must be in the loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
    break;
  case Instruction::Shl:
    break;
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  // TODO: Share this switch with matchStridedStart?
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  // If the Step was defined inside the loop, adjust it before its definition
  // instead of in the preheader.
  if (auto *StepI = dyn_cast<Instruction>(Step); StepI && L->contains(StepI))
    Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());

  switch (BO->getOpcode()) {
  default:
    break;
  case Instruction::Mul:
    Step = Builder.CreateMul(Step, SplatOp, "step");
    break;
  case Instruction::Shl:
    Step = Builder.CreateShl(Step, SplatOp, "step");
    break;
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}

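// Try to rewrite the address computation feeding a gather/scatter as a
// scalar base pointer plus a scalar stride in bytes, returning
// {nullptr, nullptr} on failure. Results for GEPs are cached in StridedAddrs
// so that gathers/scatters sharing a GEP reuse the same scalar instructions.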
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If GEP's offset is scalar then we can add it to the base pointer's base.
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // Base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride =
          Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

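// Rewrite a masked.gather/scatter or vp.gather/scatter with a strided
// address into a strided load/store. Schematically (types, alignment, and
// intrinsic name mangling elided), a gather such as
//   %ptrs = getelementptr i32, ptr %p, <4 x i64> %strided.ind
//   %v = call @llvm.masked.gather(%ptrs, i32 4, %mask, %passthru)
// becomes a strided load whose stride is in bytes, plus a vp.select to
// preserve the passthru lanes:
//   %w = call @llvm.experimental.vp.strided.load(ptr %base, i64 %stride,
//                                                %mask, i32 %evl)
//   %v = call @llvm.vp.select(%mask, %w, %passthru, i32 %evl)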
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
  VectorType *DataType;
  Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
  MaybeAlign MA;
  switch (II->getIntrinsicID()) {
  case Intrinsic::masked_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
    Mask = II->getArgOperand(2);
    break;
  case Intrinsic::vp_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = II->getParamAlign(0).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(1);
    EVL = II->getArgOperand(2);
    break;
  case Intrinsic::masked_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
    Mask = II->getArgOperand(3);
    break;
  case Intrinsic::vp_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = II->getParamAlign(1).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(2);
    EVL = II->getArgOperand(3);
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  // Make sure the operation will be supported by the backend.
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder Builder(Ctx, InstSimplifyFolder(*DL));
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  if (!EVL)
    EVL = Builder.CreateElementCount(
        Builder.getInt32Ty(), cast<VectorType>(DataType)->getElementCount());

  CallInst *Call;

  if (!StoreVal) {
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {BasePtr, Stride, Mask, EVL});

    // Merge llvm.masked.gather's passthru.
    if (II->getIntrinsicID() == Intrinsic::masked_gather)
      Call = Builder.CreateIntrinsic(Intrinsic::vp_select, {DataType},
                                     {Mask, Call, II->getArgOperand(3), EVL});
  } else
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {StoreVal, BasePtr, Stride, Mask, EVL});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Worklist;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      switch (II->getIntrinsicID()) {
      case Intrinsic::masked_gather:
      case Intrinsic::masked_scatter:
      case Intrinsic::vp_gather:
      case Intrinsic::vp_scatter:
        Worklist.push_back(II);
        break;
      default:
        break;
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Worklist)
    Changed |= tryCreateStridedLoadStore(II);

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}