| 1 | //===- Legality.cpp -------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" |
| 10 | #include "llvm/SandboxIR/Instruction.h" |
| 11 | #include "llvm/SandboxIR/Operator.h" |
| 12 | #include "llvm/SandboxIR/Utils.h" |
| 13 | #include "llvm/SandboxIR/Value.h" |
| 14 | #include "llvm/Support/Debug.h" |
| 15 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" |
| 16 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" |
| 17 | |
| 18 | namespace llvm::sandboxir { |
| 19 | |
| 20 | #ifndef NDEBUG |
| 21 | void ShuffleMask::dump() const { |
| 22 | print(dbgs()); |
| 23 | dbgs() << "\n" ; |
| 24 | } |
| 25 | |
| 26 | void LegalityResult::dump() const { |
| 27 | print(dbgs()); |
| 28 | dbgs() << "\n" ; |
| 29 | } |
| 30 | #endif // NDEBUG |
| 31 | |
| 32 | std::optional<ResultReason> |
| 33 | LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( |
| 34 | ArrayRef<Value *> Bndl) { |
| 35 | auto *I0 = cast<Instruction>(Val: Bndl[0]); |
| 36 | auto Opcode = I0->getOpcode(); |
| 37 | // If they have different opcodes, then we cannot form a vector (for now). |
| 38 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [Opcode](Value *V) { |
| 39 | return cast<Instruction>(Val: V)->getOpcode() != Opcode; |
| 40 | })) |
| 41 | return ResultReason::DiffOpcodes; |
| 42 | |
| 43 | // If not the same scalar type, Pack. This will accept scalars and vectors as |
| 44 | // long as the element type is the same. |
| 45 | Type *ElmTy0 = VecUtils::getElementType(Ty: Utils::getExpectedType(V: I0)); |
| 46 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [ElmTy0](Value *V) { |
| 47 | return VecUtils::getElementType(Ty: Utils::getExpectedType(V)) != ElmTy0; |
| 48 | })) |
| 49 | return ResultReason::DiffTypes; |
| 50 | |
| 51 | // TODO: Allow vectorization of instrs with different flags as long as we |
| 52 | // change them to the least common one. |
| 53 | // For now pack if differnt FastMathFlags. |
| 54 | if (isa<FPMathOperator>(Val: I0)) { |
| 55 | FastMathFlags FMF0 = cast<Instruction>(Val: Bndl[0])->getFastMathFlags(); |
| 56 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [FMF0](auto *V) { |
| 57 | return cast<Instruction>(V)->getFastMathFlags() != FMF0; |
| 58 | })) |
| 59 | return ResultReason::DiffMathFlags; |
| 60 | } |
| 61 | |
| 62 | // TODO: Allow vectorization by using common flags. |
| 63 | // For now Pack if they don't have the same wrap flags. |
| 64 | bool CanHaveWrapFlags = |
| 65 | isa<OverflowingBinaryOperator>(Val: I0) || isa<TruncInst>(Val: I0); |
| 66 | if (CanHaveWrapFlags) { |
| 67 | bool NUW0 = I0->hasNoUnsignedWrap(); |
| 68 | bool NSW0 = I0->hasNoSignedWrap(); |
| 69 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [NUW0, NSW0](auto *V) { |
| 70 | return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 || |
| 71 | cast<Instruction>(V)->hasNoSignedWrap() != NSW0; |
| 72 | })) { |
| 73 | return ResultReason::DiffWrapFlags; |
| 74 | } |
| 75 | } |
| 76 | |
| 77 | // Now we need to do further checks for specific opcodes. |
| 78 | switch (Opcode) { |
| 79 | case Instruction::Opcode::ZExt: |
| 80 | case Instruction::Opcode::SExt: |
| 81 | case Instruction::Opcode::FPToUI: |
| 82 | case Instruction::Opcode::FPToSI: |
| 83 | case Instruction::Opcode::FPExt: |
| 84 | case Instruction::Opcode::PtrToInt: |
| 85 | case Instruction::Opcode::IntToPtr: |
| 86 | case Instruction::Opcode::SIToFP: |
| 87 | case Instruction::Opcode::UIToFP: |
| 88 | case Instruction::Opcode::Trunc: |
| 89 | case Instruction::Opcode::FPTrunc: |
| 90 | case Instruction::Opcode::BitCast: { |
| 91 | // We have already checked that they are of the same opcode. |
| 92 | assert(all_of(Bndl, |
| 93 | [Opcode](Value *V) { |
| 94 | return cast<Instruction>(V)->getOpcode() == Opcode; |
| 95 | }) && |
| 96 | "Different opcodes, should have early returned!" ); |
| 97 | // But for these opcodes we should also check the operand type. |
| 98 | Type *FromTy0 = Utils::getExpectedType(V: I0->getOperand(OpIdx: 0)); |
| 99 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [FromTy0](Value *V) { |
| 100 | return Utils::getExpectedType(V: cast<User>(Val: V)->getOperand(OpIdx: 0)) != |
| 101 | FromTy0; |
| 102 | })) |
| 103 | return ResultReason::DiffTypes; |
| 104 | return std::nullopt; |
| 105 | } |
| 106 | case Instruction::Opcode::FCmp: |
| 107 | case Instruction::Opcode::ICmp: { |
| 108 | // We need the same predicate.. |
| 109 | auto Pred0 = cast<CmpInst>(Val: I0)->getPredicate(); |
| 110 | bool Same = all_of(Range&: Bndl, P: [Pred0](Value *V) { |
| 111 | return cast<CmpInst>(Val: V)->getPredicate() == Pred0; |
| 112 | }); |
| 113 | if (Same) |
| 114 | return std::nullopt; |
| 115 | return ResultReason::DiffOpcodes; |
| 116 | } |
| 117 | case Instruction::Opcode::Select: { |
| 118 | auto *Sel0 = cast<SelectInst>(Val: Bndl[0]); |
| 119 | auto *Cond0 = Sel0->getCondition(); |
| 120 | if (VecUtils::getNumLanes(V: Cond0) != VecUtils::getNumLanes(V: Sel0)) |
| 121 | // TODO: For now we don't vectorize if the lanes in the condition don't |
| 122 | // match those of the select instruction. |
| 123 | return ResultReason::Unimplemented; |
| 124 | return std::nullopt; |
| 125 | } |
| 126 | case Instruction::Opcode::FNeg: |
| 127 | case Instruction::Opcode::Add: |
| 128 | case Instruction::Opcode::FAdd: |
| 129 | case Instruction::Opcode::Sub: |
| 130 | case Instruction::Opcode::FSub: |
| 131 | case Instruction::Opcode::Mul: |
| 132 | case Instruction::Opcode::FMul: |
| 133 | case Instruction::Opcode::FRem: |
| 134 | case Instruction::Opcode::UDiv: |
| 135 | case Instruction::Opcode::SDiv: |
| 136 | case Instruction::Opcode::FDiv: |
| 137 | case Instruction::Opcode::URem: |
| 138 | case Instruction::Opcode::SRem: |
| 139 | case Instruction::Opcode::Shl: |
| 140 | case Instruction::Opcode::LShr: |
| 141 | case Instruction::Opcode::AShr: |
| 142 | case Instruction::Opcode::And: |
| 143 | case Instruction::Opcode::Or: |
| 144 | case Instruction::Opcode::Xor: |
| 145 | return std::nullopt; |
| 146 | case Instruction::Opcode::Load: |
| 147 | if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL)) |
| 148 | return std::nullopt; |
| 149 | return ResultReason::NotConsecutive; |
| 150 | case Instruction::Opcode::Store: |
| 151 | if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL)) |
| 152 | return std::nullopt; |
| 153 | return ResultReason::NotConsecutive; |
| 154 | case Instruction::Opcode::PHI: |
| 155 | return ResultReason::Unimplemented; |
| 156 | case Instruction::Opcode::Opaque: |
| 157 | return ResultReason::Unimplemented; |
| 158 | case Instruction::Opcode::Br: |
| 159 | case Instruction::Opcode::Ret: |
| 160 | case Instruction::Opcode::AddrSpaceCast: |
| 161 | case Instruction::Opcode::InsertElement: |
| 162 | case Instruction::Opcode::InsertValue: |
| 163 | case Instruction::Opcode::ExtractElement: |
| 164 | case Instruction::Opcode::ExtractValue: |
| 165 | case Instruction::Opcode::ShuffleVector: |
| 166 | case Instruction::Opcode::Call: |
| 167 | case Instruction::Opcode::GetElementPtr: |
| 168 | case Instruction::Opcode::Switch: |
| 169 | return ResultReason::Unimplemented; |
| 170 | case Instruction::Opcode::VAArg: |
| 171 | case Instruction::Opcode::Freeze: |
| 172 | case Instruction::Opcode::Fence: |
| 173 | case Instruction::Opcode::Invoke: |
| 174 | case Instruction::Opcode::CallBr: |
| 175 | case Instruction::Opcode::LandingPad: |
| 176 | case Instruction::Opcode::CatchPad: |
| 177 | case Instruction::Opcode::CleanupPad: |
| 178 | case Instruction::Opcode::CatchRet: |
| 179 | case Instruction::Opcode::CleanupRet: |
| 180 | case Instruction::Opcode::Resume: |
| 181 | case Instruction::Opcode::CatchSwitch: |
| 182 | case Instruction::Opcode::AtomicRMW: |
| 183 | case Instruction::Opcode::AtomicCmpXchg: |
| 184 | case Instruction::Opcode::Alloca: |
| 185 | case Instruction::Opcode::Unreachable: |
| 186 | return ResultReason::Infeasible; |
| 187 | } |
| 188 | |
| 189 | return std::nullopt; |
| 190 | } |
| 191 | |
| 192 | CollectDescr |
| 193 | LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const { |
| 194 | SmallVector<CollectDescr::ExtractElementDescr, 4> Vec; |
| 195 | Vec.reserve(N: Bndl.size()); |
| 196 | for (auto [Elm, V] : enumerate(First&: Bndl)) { |
| 197 | if (auto *VecOp = IMaps.getVectorForOrig(Orig: V)) { |
| 198 | // If there is a vector containing `V`, then get the lane it came from. |
| 199 | std::optional<int> = IMaps.getOrigLane(Vec: VecOp, Orig: V); |
| 200 | // This could be a vector, like <2 x float> in which case the mask needs |
| 201 | // to enumerate all lanes. |
| 202 | for (unsigned Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln) |
| 203 | Vec.emplace_back(Args&: VecOp, Args: ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1); |
| 204 | } else { |
| 205 | Vec.emplace_back(Args: V); |
| 206 | } |
| 207 | } |
| 208 | return CollectDescr(std::move(Vec)); |
| 209 | } |
| 210 | |
| 211 | const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl, |
| 212 | bool SkipScheduling) { |
| 213 | // If Bndl contains values other than instructions, we need to Pack. |
| 214 | if (any_of(Range&: Bndl, P: [](auto *V) { return !isa<Instruction>(V); })) |
| 215 | return createLegalityResult<Pack>(Args: ResultReason::NotInstructions); |
| 216 | // Pack if not in the same BB. |
| 217 | auto *BB = cast<Instruction>(Val: Bndl[0])->getParent(); |
| 218 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), |
| 219 | P: [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; })) |
| 220 | return createLegalityResult<Pack>(Args: ResultReason::DiffBBs); |
| 221 | // Pack if instructions repeat, i.e., require some sort of broadcast. |
| 222 | SmallPtrSet<Value *, 8> Unique(llvm::from_range, Bndl); |
| 223 | if (Unique.size() != Bndl.size()) |
| 224 | return createLegalityResult<Pack>(Args: ResultReason::RepeatedInstrs); |
| 225 | |
| 226 | auto CollectDescrs = getHowToCollectValues(Bndl); |
| 227 | if (CollectDescrs.hasVectorInputs()) { |
| 228 | if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) { |
| 229 | auto [Vec, Mask] = *ValueShuffleOpt; |
| 230 | if (Mask.isIdentity()) |
| 231 | return createLegalityResult<DiamondReuse>(Args&: Vec); |
| 232 | return createLegalityResult<DiamondReuseWithShuffle>(Args&: Vec, Args&: Mask); |
| 233 | } |
| 234 | return createLegalityResult<DiamondReuseMultiInput>( |
| 235 | Args: std::move(CollectDescrs)); |
| 236 | } |
| 237 | |
| 238 | if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) |
| 239 | return createLegalityResult<Pack>(Args&: *ReasonOpt); |
| 240 | |
| 241 | if (!SkipScheduling) { |
| 242 | // TODO: Try to remove the IBndl vector. |
| 243 | SmallVector<Instruction *, 8> IBndl; |
| 244 | IBndl.reserve(N: Bndl.size()); |
| 245 | for (auto *V : Bndl) |
| 246 | IBndl.push_back(Elt: cast<Instruction>(Val: V)); |
| 247 | if (!Sched.trySchedule(Instrs: IBndl)) |
| 248 | return createLegalityResult<Pack>(Args: ResultReason::CantSchedule); |
| 249 | } |
| 250 | |
| 251 | return createLegalityResult<Widen>(); |
| 252 | } |
| 253 | |
| 254 | void LegalityAnalysis::clear() { |
| 255 | Sched.clear(); |
| 256 | IMaps.clear(); |
| 257 | } |
| 258 | } // namespace llvm::sandboxir |
| 259 | |