1//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Performs general IR level optimizations on SVE intrinsics.
10//
11// This pass performs the following optimizations:
12//
13// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g:
14// %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
15// %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
16// ; (%1 can be replaced with a reinterpret of %2)
17//
18// - optimizes ptest intrinsics where the operands are being needlessly
19// converted to and from svbool_t.
20//
21//===----------------------------------------------------------------------===//
22
23#include "AArch64.h"
24#include "Utils/AArch64BaseInfo.h"
25#include "llvm/ADT/PostOrderIterator.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/IR/Constants.h"
28#include "llvm/IR/Dominators.h"
29#include "llvm/IR/IRBuilder.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/IntrinsicInst.h"
32#include "llvm/IR/IntrinsicsAArch64.h"
33#include "llvm/IR/LLVMContext.h"
34#include "llvm/IR/Module.h"
35#include "llvm/IR/PatternMatch.h"
36#include "llvm/InitializePasses.h"
37#include "llvm/Support/Debug.h"
38#include <optional>
39
40using namespace llvm;
41using namespace llvm::PatternMatch;
42
43#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
44
45namespace {
46struct SVEIntrinsicOpts : public ModulePass {
47 static char ID; // Pass identification, replacement for typeid
48 SVEIntrinsicOpts() : ModulePass(ID) {
49 initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
50 }
51
52 bool runOnModule(Module &M) override;
53 void getAnalysisUsage(AnalysisUsage &AU) const override;
54
55private:
56 bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
57 SmallSetVector<IntrinsicInst *, 4> &PTrues);
58 bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
59 bool optimizePredicateStore(Instruction *I);
60 bool optimizePredicateLoad(Instruction *I);
61
62 bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);
63
64 /// Operates at the function-scope. I.e., optimizations are applied local to
65 /// the functions themselves.
66 bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
67};
68} // end anonymous namespace
69
70void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
71 AU.addRequired<DominatorTreeWrapperPass>();
72 AU.setPreservesCFG();
73}
74
75char SVEIntrinsicOpts::ID = 0;
76static const char *name = "SVE intrinsics optimizations";
77INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
78INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
79INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
80
81ModulePass *llvm::createSVEIntrinsicOptsPass() {
82 return new SVEIntrinsicOpts();
83}
84
85/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
86/// ptrue will introduce zeroing. For example:
87///
88/// %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
89/// %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
90/// %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
91///
92/// %1 is promoted, because it is converted:
93///
94/// <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
95///
96/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
97static bool isPTruePromoted(IntrinsicInst *PTrue) {
98 // Find all users of this intrinsic that are calls to convert-to-svbool
99 // reinterpret intrinsics.
100 SmallVector<IntrinsicInst *, 4> ConvertToUses;
101 for (User *User : PTrue->users()) {
102 if (match(V: User, P: m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
103 ConvertToUses.push_back(Elt: cast<IntrinsicInst>(Val: User));
104 }
105 }
106
107 // If no such calls were found, this is ptrue is not promoted.
108 if (ConvertToUses.empty())
109 return false;
110
111 // Otherwise, try to find users of the convert-to-svbool intrinsics that are
112 // calls to the convert-from-svbool intrinsic, and would result in some lanes
113 // being zeroed.
114 const auto *PTrueVTy = cast<ScalableVectorType>(Val: PTrue->getType());
115 for (IntrinsicInst *ConvertToUse : ConvertToUses) {
116 for (User *User : ConvertToUse->users()) {
117 auto *IntrUser = dyn_cast<IntrinsicInst>(Val: User);
118 if (IntrUser && IntrUser->getIntrinsicID() ==
119 Intrinsic::aarch64_sve_convert_from_svbool) {
120 const auto *IntrUserVTy = cast<ScalableVectorType>(Val: IntrUser->getType());
121
122 // Would some lanes become zeroed by the conversion?
123 if (IntrUserVTy->getElementCount().getKnownMinValue() >
124 PTrueVTy->getElementCount().getKnownMinValue())
125 // This is a promoted ptrue.
126 return true;
127 }
128 }
129 }
130
131 // If no matching calls were found, this is not a promoted ptrue.
132 return false;
133}
134
135/// Attempts to coalesce ptrues in a basic block.
136bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
137 BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
138 if (PTrues.size() <= 1)
139 return false;
140
141 // Find the ptrue with the most lanes.
142 auto *MostEncompassingPTrue =
143 *llvm::max_element(Range&: PTrues, C: [](auto *PTrue1, auto *PTrue2) {
144 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
145 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
146 return PTrue1VTy->getElementCount().getKnownMinValue() <
147 PTrue2VTy->getElementCount().getKnownMinValue();
148 });
149
150 // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving
151 // behind only the ptrues to be coalesced.
152 PTrues.remove(X: MostEncompassingPTrue);
153 PTrues.remove_if(P: isPTruePromoted);
154
155 // Hoist MostEncompassingPTrue to the start of the basic block. It is always
156 // safe to do this, since ptrue intrinsic calls are guaranteed to have no
157 // predecessors.
158 MostEncompassingPTrue->moveBefore(BB, I: BB.getFirstInsertionPt());
159
160 LLVMContext &Ctx = BB.getContext();
161 IRBuilder<> Builder(Ctx);
162 Builder.SetInsertPoint(TheBB: &BB, IP: ++MostEncompassingPTrue->getIterator());
163
164 auto *MostEncompassingPTrueVTy =
165 cast<VectorType>(Val: MostEncompassingPTrue->getType());
166 auto *ConvertToSVBool = Builder.CreateIntrinsic(
167 ID: Intrinsic::aarch64_sve_convert_to_svbool, Types: {MostEncompassingPTrueVTy},
168 Args: {MostEncompassingPTrue});
169
170 bool ConvertFromCreated = false;
171 for (auto *PTrue : PTrues) {
172 auto *PTrueVTy = cast<VectorType>(Val: PTrue->getType());
173
174 // Only create the converts if the types are not already the same, otherwise
175 // just use the most encompassing ptrue.
176 if (MostEncompassingPTrueVTy != PTrueVTy) {
177 ConvertFromCreated = true;
178
179 Builder.SetInsertPoint(TheBB: &BB, IP: ++ConvertToSVBool->getIterator());
180 auto *ConvertFromSVBool =
181 Builder.CreateIntrinsic(ID: Intrinsic::aarch64_sve_convert_from_svbool,
182 Types: {PTrueVTy}, Args: {ConvertToSVBool});
183 PTrue->replaceAllUsesWith(V: ConvertFromSVBool);
184 } else
185 PTrue->replaceAllUsesWith(V: MostEncompassingPTrue);
186
187 PTrue->eraseFromParent();
188 }
189
190 // We never used the ConvertTo so remove it
191 if (!ConvertFromCreated)
192 ConvertToSVBool->eraseFromParent();
193
194 return true;
195}
196
197/// The goal of this function is to remove redundant calls to the SVE ptrue
198/// intrinsic in each basic block within the given functions.
199///
200/// SVE ptrues have two representations in LLVM IR:
201/// - a logical representation -- an arbitrary-width scalable vector of i1s,
202/// i.e. <vscale x N x i1>.
203/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
204/// scalable vector of i1s, i.e. <vscale x 16 x i1>.
205///
206/// The SVE ptrue intrinsic is used to create a logical representation of an SVE
207/// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
208/// P1 creates a logical SVE predicate that is at least as wide as the logical
209/// SVE predicate created by P2, then all of the bits that are true in the
210/// physical representation of P2 are necessarily also true in the physical
211/// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to
212/// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
213/// convert.{to,from}.svbool.
214///
215/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
216/// if they match the following conditions:
217///
218/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
219/// SV_ALL indicates that all bits of the predicate vector are to be set to
220/// true. SV_POW2 indicates that all bits of the predicate vector up to the
221/// largest power-of-two are to be set to true.
222/// - the result of the call to the intrinsic is not promoted to a wider
223/// predicate. In this case, keeping the extra ptrue leads to better codegen
224/// -- coalescing here would create an irreducible chain of SVE reinterprets
225/// via convert.{to,from}.svbool.
226///
227/// EXAMPLE:
228///
229/// %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
230/// ; Logical: <1, 1, 1, 1, 1, 1, 1, 1>
231/// ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
232/// ...
233///
234/// %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
235/// ; Logical: <1, 1, 1, 1>
236/// ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
237/// ...
238///
239/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
240///
241/// %1 = <vscale x 8 x i1> ptrue(i32 i31)
242/// %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
243/// %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
244///
245bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
246 SmallSetVector<Function *, 4> &Functions) {
247 bool Changed = false;
248
249 for (auto *F : Functions) {
250 for (auto &BB : *F) {
251 SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
252 SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;
253
254 // For each basic block, collect the used ptrues and try to coalesce them.
255 for (Instruction &I : BB) {
256 if (I.use_empty())
257 continue;
258
259 auto *IntrI = dyn_cast<IntrinsicInst>(Val: &I);
260 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
261 continue;
262
263 const auto PTruePattern =
264 cast<ConstantInt>(Val: IntrI->getOperand(i_nocapture: 0))->getZExtValue();
265
266 if (PTruePattern == AArch64SVEPredPattern::all)
267 SVAllPTrues.insert(X: IntrI);
268 if (PTruePattern == AArch64SVEPredPattern::pow2)
269 SVPow2PTrues.insert(X: IntrI);
270 }
271
272 Changed |= coalescePTrueIntrinsicCalls(BB, PTrues&: SVAllPTrues);
273 Changed |= coalescePTrueIntrinsicCalls(BB, PTrues&: SVPow2PTrues);
274 }
275 }
276
277 return Changed;
278}
279
280// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
281// scalable stores as late as possible
282bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
283 auto *F = I->getFunction();
284 auto Attr = F->getFnAttribute(Kind: Attribute::VScaleRange);
285 if (!Attr.isValid())
286 return false;
287
288 unsigned MinVScale = Attr.getVScaleRangeMin();
289 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
290 // The transform needs to know the exact runtime length of scalable vectors
291 if (!MaxVScale || MinVScale != MaxVScale)
292 return false;
293
294 auto *PredType =
295 ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: I->getContext()), MinNumElts: 16);
296 auto *FixedPredType =
297 FixedVectorType::get(ElementType: Type::getInt8Ty(C&: I->getContext()), NumElts: MinVScale * 2);
298
299 // If we have a store..
300 auto *Store = dyn_cast<StoreInst>(Val: I);
301 if (!Store || !Store->isSimple())
302 return false;
303
304 // ..that is storing a predicate vector sized worth of bits..
305 if (Store->getOperand(i_nocapture: 0)->getType() != FixedPredType)
306 return false;
307
308 // ..where the value stored comes from a vector extract..
309 auto *IntrI = dyn_cast<IntrinsicInst>(Val: Store->getOperand(i_nocapture: 0));
310 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
311 return false;
312
313 // ..that is extracting from index 0..
314 if (!cast<ConstantInt>(Val: IntrI->getOperand(i_nocapture: 1))->isZero())
315 return false;
316
317 // ..where the value being extract from comes from a bitcast
318 auto *BitCast = dyn_cast<BitCastInst>(Val: IntrI->getOperand(i_nocapture: 0));
319 if (!BitCast)
320 return false;
321
322 // ..and the bitcast is casting from predicate type
323 if (BitCast->getOperand(i_nocapture: 0)->getType() != PredType)
324 return false;
325
326 IRBuilder<> Builder(I->getContext());
327 Builder.SetInsertPoint(I);
328
329 Builder.CreateStore(Val: BitCast->getOperand(i_nocapture: 0), Ptr: Store->getPointerOperand());
330
331 Store->eraseFromParent();
332 if (IntrI->getNumUses() == 0)
333 IntrI->eraseFromParent();
334 if (BitCast->getNumUses() == 0)
335 BitCast->eraseFromParent();
336
337 return true;
338}
339
340// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
341// scalable loads as late as possible
342bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
343 auto *F = I->getFunction();
344 auto Attr = F->getFnAttribute(Kind: Attribute::VScaleRange);
345 if (!Attr.isValid())
346 return false;
347
348 unsigned MinVScale = Attr.getVScaleRangeMin();
349 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
350 // The transform needs to know the exact runtime length of scalable vectors
351 if (!MaxVScale || MinVScale != MaxVScale)
352 return false;
353
354 auto *PredType =
355 ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: I->getContext()), MinNumElts: 16);
356 auto *FixedPredType =
357 FixedVectorType::get(ElementType: Type::getInt8Ty(C&: I->getContext()), NumElts: MinVScale * 2);
358
359 // If we have a bitcast..
360 auto *BitCast = dyn_cast<BitCastInst>(Val: I);
361 if (!BitCast || BitCast->getType() != PredType)
362 return false;
363
364 // ..whose operand is a vector_insert..
365 auto *IntrI = dyn_cast<IntrinsicInst>(Val: BitCast->getOperand(i_nocapture: 0));
366 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
367 return false;
368
369 // ..that is inserting into index zero of an undef vector..
370 if (!isa<UndefValue>(Val: IntrI->getOperand(i_nocapture: 0)) ||
371 !cast<ConstantInt>(Val: IntrI->getOperand(i_nocapture: 2))->isZero())
372 return false;
373
374 // ..where the value inserted comes from a load..
375 auto *Load = dyn_cast<LoadInst>(Val: IntrI->getOperand(i_nocapture: 1));
376 if (!Load || !Load->isSimple())
377 return false;
378
379 // ..that is loading a predicate vector sized worth of bits..
380 if (Load->getType() != FixedPredType)
381 return false;
382
383 IRBuilder<> Builder(I->getContext());
384 Builder.SetInsertPoint(Load);
385
386 auto *LoadPred = Builder.CreateLoad(Ty: PredType, Ptr: Load->getPointerOperand());
387
388 BitCast->replaceAllUsesWith(V: LoadPred);
389 BitCast->eraseFromParent();
390 if (IntrI->getNumUses() == 0)
391 IntrI->eraseFromParent();
392 if (Load->getNumUses() == 0)
393 Load->eraseFromParent();
394
395 return true;
396}
397
398bool SVEIntrinsicOpts::optimizeInstructions(
399 SmallSetVector<Function *, 4> &Functions) {
400 bool Changed = false;
401
402 for (auto *F : Functions) {
403 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(F&: *F).getDomTree();
404
405 // Traverse the DT with an rpo walk so we see defs before uses, allowing
406 // simplification to be done incrementally.
407 BasicBlock *Root = DT->getRoot();
408 ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
409 for (auto *BB : RPOT) {
410 for (Instruction &I : make_early_inc_range(Range&: *BB)) {
411 switch (I.getOpcode()) {
412 case Instruction::Store:
413 Changed |= optimizePredicateStore(I: &I);
414 break;
415 case Instruction::BitCast:
416 Changed |= optimizePredicateLoad(I: &I);
417 break;
418 }
419 }
420 }
421 }
422
423 return Changed;
424}
425
426bool SVEIntrinsicOpts::optimizeFunctions(
427 SmallSetVector<Function *, 4> &Functions) {
428 bool Changed = false;
429
430 Changed |= optimizePTrueIntrinsicCalls(Functions);
431 Changed |= optimizeInstructions(Functions);
432
433 return Changed;
434}
435
436bool SVEIntrinsicOpts::runOnModule(Module &M) {
437 bool Changed = false;
438 SmallSetVector<Function *, 4> Functions;
439
440 // Check for SVE intrinsic declarations first so that we only iterate over
441 // relevant functions. Where an appropriate declaration is found, store the
442 // function(s) where it is used so we can target these only.
443 for (auto &F : M.getFunctionList()) {
444 if (!F.isDeclaration())
445 continue;
446
447 switch (F.getIntrinsicID()) {
448 case Intrinsic::vector_extract:
449 case Intrinsic::vector_insert:
450 case Intrinsic::aarch64_sve_ptrue:
451 for (User *U : F.users())
452 Functions.insert(X: cast<Instruction>(Val: U)->getFunction());
453 break;
454 default:
455 break;
456 }
457 }
458
459 if (!Functions.empty())
460 Changed |= optimizeFunctions(Functions);
461
462 return Changed;
463}
464