//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising the
/// code to produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;
  const DataLayout *DL;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
  // scalar base and vector offsets, or else fallback to using a base of 0 and
  // offset of Ptr where possible.
  Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
                      FixedVectorType *Ty, Type *MemoryTy,
                      IRBuilder<> &Builder);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
                      GetElementPtrInst *GEP, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  std::optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Instruction *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                           Instruction *&Root,
                                           IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                         IRBuilder<> &Builder,
                                         int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                           IRBuilder<> &Builder,
                                           int64_t Increment = 0);

  Instruction *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                            IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                          IRBuilder<> &Builder,
                                          int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder,
                                            int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value
  Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value - this creates a writeback QI
  // gather/scatter
  Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                              Value *Ptr, unsigned TypeScale,
                                              IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
                 IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul or shl out of the loop
  void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
                     Value *OffsSecondOperand, unsigned LoopIncrement,
                     IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scatter lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 <= value < 2^TargetElemSize.
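  // For example, for a v8i16 gather (TargetElemCount == 8, so
  // TargetElemSize == 16), only constant offsets in [0, 65535] are accepted.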
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
  }

  // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  // elements.
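  // (32-bit memory types are rejected here; a v4i32 gather/scatter from a
  // vector of pointers can be handled directly by the vector-of-pointers
  // (base) form, which the callers try afterwards.)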
  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0;
  return BasePtr;
}

Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                        << "bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32-bit load/store scaled by 4, a 16-bit load/store scaled by
  // 2, or an 8-bit, 16-bit or 32-bit load/store scaled by 1
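  // (The returned value is the log2 of the byte scale, i.e. the shift amount
  // applied to each offset: a gep over i32 elements feeding a 32-bit
  // gather/scatter returns 2, meaning offsets are multiplied by 4.)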
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C && C->getSplatValue())
    return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return std::optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
      I->getOpcode() == Instruction::Mul ||
      I->getOpcode() == Instruction::Shl) {
    std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return std::optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return std::optional<int64_t>{*Op0 + *Op1};
    if (I->getOpcode() == Instruction::Mul)
      return std::optional<int64_t>{*Op0 * *Op1};
    if (I->getOpcode() == Instruction::Shl)
      return std::optional<int64_t>{*Op0 << *Op1};
    if (I->getOpcode() == Instruction::Or)
      return std::optional<int64_t>{*Op0 | *Op1};
  }
  return std::optional<int64_t>{};
}

// Return true if I is an Or instruction that is equivalent to an add, due to
// the operands having no common bits set.
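// (For example, (X << 2) | 1 is add-like: the shifted value has its low bits
// known to be zero, so the or can never carry.)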
static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
  return I->getOpcode() == Instruction::Or &&
         haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or an
  // add-like-or.
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr ||
      (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
    return ReturnFalse;

  Value *Summand;
  std::optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = *Const << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, Mask, Src0), with the alignment taken from
  // the pointer argument's param attribute.
  // Attempt to turn the masked gather in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = I->getParamAlign(0).valueOrOne();
  Value *Mask = I->getArgOperand(1);
  Value *PassThru = I->getArgOperand(2);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;

  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(1);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(1);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      // If the gather has a single extend of the correct type, use an extending
      // gather and replace the ext. In which case the correct root to replace
      // is not the CallInst itself, but the instruction which extends it.
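      // (For example, a v4i16 gather whose only user is a sext to v4i32 is
      // turned into a single signed extending gather and the sext goes away.)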
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
        Unsigned = 0;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
      }
    }

    // If an extend hasn't been found and the type is an integer, create an
    // extending gather and truncate back to the original type.
    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncating to: "
                        << *ResultTy << "\n");
    }

    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  Root = Extend;
  Value *Mask = I->getArgOperand(1);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}

Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(Data, Ptrs, Mask), with the alignment taken from
  // the pointer argument's param attribute.
  // Attempt to turn the masked scatter in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = I->getParamAlign(1).valueOrOne();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(2);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
                    << "with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(2);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;

  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  bool ExtendInput = false;
  if (InputTy->getPrimitiveSizeInBits() < 128 &&
      InputTy->isIntOrIntVectorTy()) {
    // If we can't find a trunc to incorporate into the instruction, create an
    // implicit one with a zext, so that we can still create a scatter. We know
    // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
    // smaller than 128 bits will divide evenly into a 128bit vector.
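    // (For example, a v4i16 store value is zext'ed to v4i32 here; the scatter
    // itself still writes 16-bit elements, since MemoryTy keeps the original
    // width.)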
    InputTy = InputTy->getWithNewBitWidth(
        128 / cast<FixedVectorType>(InputTy)->getNumElements());
    ExtendInput = true;
    LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
                      << *Input << "\n");
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  if (ExtendInput)
    Input = Builder.CreateZExt(Input, InputTy);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());

  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  // Incrementing gathers are not beneficial outside of a loop
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    return nullptr;

  // Decompose the GEP into Base and Offsets
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  if (!BasePtr)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  int TypeScale =
      computeScale(DL->getTypeSizeInBits(GEP->getSourceElementType()),
                   DL->getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the writeback
    // changes the phi, which would affect other users of the gep (they would
    // still be using the phi in the old way)
    if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                    TypeScale, Builder))
      return Load;
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

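  // Build a vector of absolute addresses by hand: splat the scalar base and
  // add the offsets shifted left by TypeScale. The result feeds the
  // vector-of-pointers form below, with Immediate as its constant increment.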
  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I->getIterator());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I->getIterator());

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  else
    return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || !Phi->hasNUses(2))
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex",
      Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex",
      Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
  // The gather is pre-incrementing
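  // (it adds Immediate before each access), so bias the initial offsets down
  // by Immediate to compensate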
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Instruction *EndResult;
  Instruction *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = ExtractValueInst::Create(Load, 0, "Gather");
    NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
    Builder.Insert(EndResult);
    Builder.Insert(NewInduction);
  } else {
    // Build the incrementing scatter
    EndResult = NewInduction =
        tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  assert(Phi->getNumIncomingValues() == 2);
  BasicBlock *NewIndexBlock = Phi->getIncomingBlock(StartIndex);
  BasicBlock::iterator InsertionPoint = NewIndexBlock->back().getIterator();
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Value *IncrementIndexValue = Phi->getIncomingValue(IncrementIndex);
  BasicBlock *IncrementIndexBlock = Phi->getIncomingBlock(IncrementIndex);
  Phi->setIncomingValue(0, NewIndex);
  Phi->setIncomingBlock(0, NewIndexBlock);
  Phi->setIncomingValue(1, IncrementIndexValue);
  Phi->setIncomingBlock(1, IncrementIndexBlock);
}

void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
                                             Value *IncrementPerRound,
                                             Value *OffsSecondOperand,
                                             unsigned LoopIncrement,
                                             IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
  assert(Phi->getNumIncomingValues() == 2);

  // Hoist the mul/shl out of the loop: apply it once to the start value
  // outside the loop, and pre-multiply the per-round increment so the loop
  // variable can keep being updated with a simple add
  BasicBlock *StartIndexBlock =
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1);
  BasicBlock::iterator InsertionPoint = StartIndexBlock->back().getIterator();

  // Create a new index
  Value *StartIndex =
      BinaryOperator::Create((Instruction::BinaryOps)Opcode,
                             Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
                             OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  BasicBlock *NewIncrementBlock = Phi->getIncomingBlock(LoopIncrement);
  BasicBlock::iterator NewIncrInsertPt =
      NewIncrementBlock->back().getIterator();
  NewIncrInsertPt = std::prev(NewIncrInsertPt);

  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt);

  Phi->setIncomingValue(0, StartIndex);
  Phi->setIncomingBlock(0, StartIndexBlock);
  Phi->setIncomingValue(1, NewIncrement);
  Phi->setIncomingBlock(1, NewIncrementBlock);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
  if (I->use_empty()) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
           OpCode == Instruction::Shl ||
           isAddLikeOr(cast<Instruction>(U), DL)) &&
          hasAllGatScatUsers(cast<Instruction>(U), DL)) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
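  // For example, an offset of the form "phi + loop-invariant" is rewritten so
  // that the invariant add is folded into the phi's start value (see
  // pushOutAdd), leaving only the plain induction variable in the loop body.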
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
      Offs->getOpcode() != Instruction::Mul &&
      Offs->getOpcode() != Instruction::Shl)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs, *DL))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed)
      return false;
    if (isa<PHINode>(Offs->getOperand(0))) {
      Phi = cast<PHINode>(Offs->getOperand(0));
      OffsSecondOp = 1;
    } else if (isa<PHINode>(Offs->getOperand(1))) {
      Phi = cast<PHINode>(Offs->getOperand(1));
      OffsSecondOp = 0;
    } else {
      return false;
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header.
  if (Phi->getParent() != L->getHeader())
    return false;

  // We're looking for a simple add recurrence.
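  // i.e. a phi of the form phi [Start, ...], [IncInstruction, ...], where
  // IncInstruction is an add of the phi and IncrementPerRound.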
  BinaryOperator *IncInstruction;
  Value *Start, *IncrementPerRound;
  if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
      IncInstruction->getOpcode() != Instruction::Add)
    return false;

  int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  if (Phi->hasNUses(2)) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (!IncInstruction->hasOneUse()) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi->getIterator());
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
  case Instruction::Or:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
  case Instruction::Shl:
    pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
                  OffsSecondOperand, IncrementingBlock, Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->use_empty())
    IncInstruction->eraseFromParent();

  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
                                      unsigned ScaleY, IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
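  // (e.g. a scalar i32 constant 3 being added to <8 x i16> offsets is splatted
  // as an <8 x i16> constant, so the element types of the two summands still
  // match)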
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() * ScaleX +
                  ConstYEl->getZExtValue() * ScaleY >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *XScale = Builder.CreateVectorSplat(
      XElType->getNumElements(),
      Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
  Value *YScale = Builder.CreateVectorSplat(
      YElType->getNumElements(),
      Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
  Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
                                 Builder.CreateMul(Y, YScale));

  if (checkOffsetSize(Add, XElType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets, unsigned &Scale,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
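  // (e.g. "gep T2, (gep T1, Base, C1), C2" becomes a single byte-addressed
  // "gep i8, Base, C1*sizeof(T1) + C2*sizeof(T2)" in optimiseAddress, with
  // Scale folded down to 1)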
  if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
    return nullptr;
  if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets = CheckAndCreateOffsetAdd(
        Offsets, Scale, GEP->getOperand(1),
        DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
    if (Offsets == nullptr)
      return nullptr;
    Scale = 1; // Scale is always an i8 at this point.
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    unsigned Scale;
    Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
    // We only want to merge the geps if there is a real chance that they can be
    // used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
      Type *BaseTy = Builder.getPtrTy();
      if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
        BaseTy = FixedVectorType::get(BaseTy, VecTy);
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
          "gep.merged", GEP->getIterator());
      LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
                        << "\n new : " << *NewAddress << "\n");
      GEP->replaceAllUsesWith(
          Builder.CreateBitCast(NewAddress, GEP->getType()));
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  DL = &F.getDataLayout();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (IntrinsicInst *I : Gathers) {
    Instruction *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(L->getParent());
    Changed = true;
  }

  for (IntrinsicInst *I : Scatters) {
    Instruction *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(S->getParent());
    Changed = true;
  }
  return Changed;
}