1 | //===- SeedCollection.cpp - Seed collection pass --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/SeedCollection.h" |
10 | #include "llvm/Analysis/TargetTransformInfo.h" |
11 | #include "llvm/SandboxIR/Module.h" |
12 | #include "llvm/SandboxIR/Region.h" |
13 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h" |
14 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h" |
15 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" |
16 | |
17 | namespace llvm { |
18 | |
19 | static cl::opt<unsigned> |
20 | OverrideVecRegBits("sbvec-vec-reg-bits" , cl::init(Val: 0), cl::Hidden, |
21 | cl::desc("Override the vector register size in bits, " |
22 | "which is otherwise found by querying TTI." )); |
23 | static cl::opt<bool> |
24 | AllowNonPow2("sbvec-allow-non-pow2" , cl::init(Val: false), cl::Hidden, |
25 | cl::desc("Allow non-power-of-2 vectorization." )); |
26 | |
27 | #define LoadSeedsDef "loads" |
28 | #define StoreSeedsDef "stores" |
29 | cl::opt<std::string> CollectSeeds( |
30 | "sbvec-collect-seeds" , cl::init(StoreSeedsDef), cl::Hidden, |
31 | cl::desc("Collect these seeds. Use empty for none or a comma-separated " |
32 | "list of '" StoreSeedsDef "' and '" LoadSeedsDef "'." )); |
33 | |
34 | namespace sandboxir { |
35 | SeedCollection::SeedCollection(StringRef Pipeline) |
36 | : FunctionPass("seed-collection" ), |
37 | RPM("rpm" , Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {} |
38 | |
39 | bool SeedCollection::runOnFunction(Function &F, const Analyses &A) { |
40 | bool Change = false; |
41 | const auto &DL = F.getParent()->getDataLayout(); |
42 | unsigned VecRegBits = |
43 | OverrideVecRegBits != 0 |
44 | ? OverrideVecRegBits |
45 | : A.getTTI() |
46 | .getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector) |
47 | .getFixedValue(); |
48 | bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos; |
49 | bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos; |
50 | |
51 | // TODO: Start from innermost BBs first |
52 | for (auto &BB : F) { |
53 | SeedCollector SC(&BB, A.getScalarEvolution(), CollectStores, CollectLoads); |
54 | for (SeedBundle &Seeds : SC.getStoreSeeds()) { |
55 | unsigned ElmBits = |
56 | Utils::getNumBits(Ty: VecUtils::getElementType(Ty: Utils::getExpectedType( |
57 | V: Seeds[Seeds.getFirstUnusedElementIdx()])), |
58 | DL); |
59 | |
60 | auto DivideBy2 = [](unsigned Num) { |
61 | auto Floor = VecUtils::getFloorPowerOf2(Num); |
62 | if (Floor == Num) |
63 | return Floor / 2; |
64 | return Floor; |
65 | }; |
66 | // Try to create the largest vector supported by the target. If it fails |
67 | // reduce the vector size by half. |
68 | for (unsigned SliceElms = std::min(a: VecRegBits / ElmBits, |
69 | b: Seeds.getNumUnusedBits() / ElmBits); |
70 | SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) { |
71 | if (Seeds.allUsed()) |
72 | break; |
73 | // Keep trying offsets after FirstUnusedElementIdx, until we vectorize |
74 | // the slice. This could be quite expensive, so we enforce a limit. |
75 | for (unsigned Offset = Seeds.getFirstUnusedElementIdx(), |
76 | OE = Seeds.size(); |
77 | Offset + 1 < OE; Offset += 1) { |
78 | // Seeds are getting used as we vectorize, so skip them. |
79 | if (Seeds.isUsed(Element: Offset)) |
80 | continue; |
81 | if (Seeds.allUsed()) |
82 | break; |
83 | |
84 | auto SeedSlice = |
85 | Seeds.getSlice(StartIdx: Offset, MaxVecRegBits: SliceElms * ElmBits, ForcePowOf2: !AllowNonPow2); |
86 | if (SeedSlice.empty()) |
87 | continue; |
88 | |
89 | assert(SeedSlice.size() >= 2 && "Should have been rejected!" ); |
90 | |
91 | // Create a region containing the seed slice. |
92 | auto &Ctx = F.getContext(); |
93 | Region Rgn(Ctx, A.getTTI()); |
94 | Rgn.setAux(SeedSlice); |
95 | // Run the region pass pipeline. |
96 | Change |= RPM.runOnRegion(R&: Rgn, A); |
97 | Rgn.clearAux(); |
98 | } |
99 | } |
100 | } |
101 | } |
102 | return Change; |
103 | } |
104 | } // namespace sandboxir |
105 | } // namespace llvm |
106 | |