//===-- SPIRVCombinerHelper.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "SPIRVCombinerHelper.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace MIPatternMatch;

SPIRVCombinerHelper::SPIRVCombinerHelper(
    GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize,
    GISelValueTracking *VT, MachineDominatorTree *MDT, const LegalizerInfo *LI,
    const SPIRVSubtarget &STI)
    : CombinerHelper(Observer, B, IsPreLegalize, VT, MDT, LI), STI(STI) {}

/// This match is part of a combine that
/// rewrites length(X - Y) to distance(X, Y):
///   (f32 (g_intrinsic length
///           (g_fsub (vXf32 X) (vXf32 Y))))
///   ->
///   (f32 (g_intrinsic distance
///           (vXf32 X) (vXf32 Y)))
///
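/// A sketch of the rewrite in generic MIR (register names are illustrative,
/// and LLTs print float elements with integer-sized types):
///   %sub:_(<3 x s32>) = G_FSUB %x, %y
///   %len:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.length), %sub
/// becomes
///   %len:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.distance), %x, %y
///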
bool SPIRVCombinerHelper::matchLengthToDistance(MachineInstr &MI) const {
  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
      cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length)
    return false;

  // Operand 0 is the result and operand 1 is the intrinsic ID, so the
  // argument to length() is operand 2.
  Register SubReg = MI.getOperand(2).getReg();
  MachineInstr *SubInstr = MRI.getVRegDef(SubReg);
  if (SubInstr->getOpcode() != TargetOpcode::G_FSUB)
    return false;

  return true;
}

void SPIRVCombinerHelper::applySPIRVDistance(MachineInstr &MI) const {
  // Extract the operands for X and Y from the match criteria.
  Register SubDestReg = MI.getOperand(2).getReg();
  MachineInstr *SubInstr = MRI.getVRegDef(SubDestReg);
  Register SubOperand1 = SubInstr->getOperand(1).getReg();
  Register SubOperand2 = SubInstr->getOperand(2).getReg();
  Register ResultReg = MI.getOperand(0).getReg();

  Builder.setInstrAndDebugLoc(MI);
  Builder.buildIntrinsic(Intrinsic::spv_distance, ResultReg)
      .addUse(SubOperand1)
      .addUse(SubOperand2);

  MI.eraseFromParent();
}

/// This match is part of a combine that
/// rewrites select(fcmp(dot(I, Ng), 0), N, -N) to faceforward(N, I, Ng):
///   (vXf32 (g_select
///            (g_fcmp
///              (g_intrinsic dot (vXf32 I) (vXf32 Ng))
///              0)
///            (vXf32 N)
///            (vXf32 g_fneg (vXf32 N))))
///   ->
///   (vXf32 (g_intrinsic faceforward
///            (vXf32 N) (vXf32 I) (vXf32 Ng)))
///
/// This only works for Vulkan shader targets.
///
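/// In shader source this typically comes from an expression such as
/// `dot(I, Ng) < 0 ? N : -N`, which matches the HLSL/GLSL definition of
/// faceforward(N, I, Ng).
///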
bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
  if (!STI.isShader())
    return false;

  // Match the overall select pattern.
  Register CondReg, TrueReg, FalseReg;
  if (!mi_match(MI.getOperand(0).getReg(), MRI,
                m_GISelect(m_Reg(CondReg), m_Reg(TrueReg), m_Reg(FalseReg))))
    return false;

  // Match the FCMP condition.
  Register DotReg, CondZeroReg;
  CmpInst::Predicate Pred;
  if (!mi_match(CondReg, MRI,
                m_GFCmp(m_Pred(Pred), m_Reg(DotReg), m_Reg(CondZeroReg))))
    return false;
  if (Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) {
    if (Pred != CmpInst::FCMP_OGT && Pred != CmpInst::FCMP_UGT)
      return false;
    // For a greater-than compare the dot product is on the right-hand side.
    std::swap(DotReg, CondZeroReg);
  }

  // Check that the FCMP compares a dot product against 0.
  MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
  if (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC ||
      cast<GIntrinsic>(DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot) {
    // Check for a scalar dot product, which is a plain multiply.
    Register DotOperand1, DotOperand2;
    if (!mi_match(DotReg, MRI,
                  m_GFMul(m_Reg(DotOperand1), m_Reg(DotOperand2))) ||
        !MRI.getType(DotOperand1).isScalar() ||
        !MRI.getType(DotOperand2).isScalar())
      return false;
  }

  const ConstantFP *ZeroVal;
  if (!mi_match(CondZeroReg, MRI, m_GFCst(ZeroVal)) || !ZeroVal->isZero())
    return false;

  // Check whether the select's false operand is the negation of the true
  // operand.
  auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
    std::optional<FPValueAndVReg> TrueVal, FalseVal;
    if (!mi_match(TrueReg, MRI, m_GFCstOrSplat(TrueVal)) ||
        !mi_match(FalseReg, MRI, m_GFCstOrSplat(FalseVal)))
      return false;
    APFloat TrueValNegated = TrueVal->Value;
    TrueValNegated.changeSign();
    return FalseVal->Value.compare(TrueValNegated) == APFloat::cmpEqual;
  };

  if (!mi_match(TrueReg, MRI, m_GFNeg(m_SpecificReg(FalseReg))) &&
      !mi_match(FalseReg, MRI, m_GFNeg(m_SpecificReg(TrueReg)))) {
    std::optional<FPValueAndVReg> MulConstant;
    MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
    MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
    if (TrueInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
        FalseInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
        TrueInstr->getNumOperands() == FalseInstr->getNumOperands()) {
      for (unsigned I = 1; I < TrueInstr->getNumOperands(); ++I)
        if (!AreNegatedConstantsOrSplats(TrueInstr->getOperand(I).getReg(),
                                         FalseInstr->getOperand(I).getReg()))
          return false;
    } else if (mi_match(TrueReg, MRI,
                        m_GFMul(m_SpecificReg(FalseReg),
                                m_GFCstOrSplat(MulConstant))) ||
               mi_match(FalseReg, MRI,
                        m_GFMul(m_SpecificReg(TrueReg),
                                m_GFCstOrSplat(MulConstant))) ||
               mi_match(TrueReg, MRI,
                        m_GFMul(m_GFCstOrSplat(MulConstant),
                                m_SpecificReg(FalseReg))) ||
               mi_match(FalseReg, MRI,
                        m_GFMul(m_GFCstOrSplat(MulConstant),
                                m_SpecificReg(TrueReg)))) {
      if (!MulConstant || !MulConstant->Value.isExactlyValue(-1.0))
        return false;
    } else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
      return false;
  }

  return true;
}

void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
  // Extract the operands for N, I, and Ng from the match criteria.
  Register CondReg = MI.getOperand(1).getReg();
  MachineInstr *CondInstr = MRI.getVRegDef(CondReg);
  Register DotReg = CondInstr->getOperand(2).getReg();
  CmpInst::Predicate Pred = cast<GFCmp>(CondInstr)->getCond();
  if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
    DotReg = CondInstr->getOperand(3).getReg();
  MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
  Register DotOperand1, DotOperand2;
  if (DotInstr->getOpcode() == TargetOpcode::G_FMUL) {
    DotOperand1 = DotInstr->getOperand(1).getReg();
    DotOperand2 = DotInstr->getOperand(2).getReg();
  } else {
    DotOperand1 = DotInstr->getOperand(2).getReg();
    DotOperand2 = DotInstr->getOperand(3).getReg();
  }
  Register TrueReg = MI.getOperand(2).getReg();
  Register FalseReg = MI.getOperand(3).getReg();
  MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
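  // Ensure TrueReg ends up holding N, the non-negated operand; the match
  // accepts the negation in either select arm.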
  if (TrueInstr->getOpcode() == TargetOpcode::G_FNEG ||
      TrueInstr->getOpcode() == TargetOpcode::G_FMUL)
    std::swap(TrueReg, FalseReg);
  MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);

  Register ResultReg = MI.getOperand(0).getReg();
  Builder.setInstrAndDebugLoc(MI);
  Builder.buildIntrinsic(Intrinsic::spv_faceforward, ResultReg)
      .addUse(TrueReg)      // N
      .addUse(DotOperand1)  // I
      .addUse(DotOperand2); // Ng

  SPIRVGlobalRegistry *GR =
      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
  auto RemoveAllUses = [&](Register Reg) {
    SmallVector<MachineInstr *, 4> UsesToErase;
    for (auto &UseMI : MRI.use_instructions(Reg))
      UsesToErase.push_back(&UseMI);

    // Collect the uses first: calling eraseFromParent while still iterating
    // over the use list would invalidate the iterator.
    for (auto *MIToErase : UsesToErase)
      MIToErase->eraseFromParent();
  };

  RemoveAllUses(CondReg); // Remove all uses of the FCMP result.
  GR->invalidateMachineInstr(CondInstr);
  CondInstr->eraseFromParent(); // Remove the FCMP instruction.
  RemoveAllUses(DotReg); // Remove all uses of the spv_fdot/G_FMUL result.
  GR->invalidateMachineInstr(DotInstr);
  DotInstr->eraseFromParent(); // Remove the spv_fdot/G_FMUL instruction.
  RemoveAllUses(FalseReg);
  GR->invalidateMachineInstr(FalseInstr);
  FalseInstr->eraseFromParent();
}

bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_transpose;
}

void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
  Register ResReg = MI.getOperand(0).getReg();
  Register InReg = MI.getOperand(2).getReg();
  uint32_t Rows = MI.getOperand(3).getImm();
  uint32_t Cols = MI.getOperand(4).getImm();

  Builder.setInstrAndDebugLoc(MI);

  if (Rows == 1 && Cols == 1) {
    Builder.buildCopy(ResReg, InReg);
    MI.eraseFromParent();
    return;
  }

  SmallVector<int, 16> Mask;
  for (uint32_t K = 0; K < Rows * Cols; ++K) {
    uint32_t R = K / Cols;
    uint32_t C = K % Cols;
    Mask.push_back(C * Rows + R);
  }
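  // Worked example: for Rows = 2, Cols = 3 the mask is {0, 2, 4, 1, 3, 5},
  // which maps the column-major 2x3 input onto its column-major 3x2
  // transpose with a single shuffle.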

  Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
  MI.eraseFromParent();
}

bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_multiply;
}

SmallVector<Register, 4>
SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
                                    SPIRVType *SpvColType,
                                    SPIRVGlobalRegistry *GR) const {
  // If the matrix is a single column, return that single column.
  if (NumberOfCols == 1)
    return {MatrixReg};

  SmallVector<Register, 4> Cols;
  LLT ColTy = GR->getRegType(SpvColType);
  for (uint32_t J = 0; J < NumberOfCols; ++J)
    Cols.push_back(MRI.createGenericVirtualRegister(ColTy));
  Builder.buildUnmerge(Cols, MatrixReg);
  for (Register R : Cols) {
    setRegClassType(R, SpvColType, GR, &MRI, Builder.getMF());
  }
  return Cols;
}

SmallVector<Register, 4>
SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
                                 uint32_t NumCols, SPIRVType *SpvRowType,
                                 SPIRVGlobalRegistry *GR) const {
  SmallVector<Register, 4> Rows;
  LLT VecTy = GR->getRegType(SpvRowType);

  // If there is only one column, then each row is a scalar that needs
  // to be extracted.
  if (NumCols == 1) {
    assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
    for (uint32_t I = 0; I < NumRows; ++I)
      Rows.push_back(MRI.createGenericVirtualRegister(VecTy));
    Builder.buildUnmerge(Rows, MatrixReg);
    for (Register R : Rows) {
      setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
    }
    return Rows;
  }

  // If the matrix is a single row, return that row.
  if (NumRows == 1) {
    return {MatrixReg};
  }

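  // Row I of the column-major matrix is gathered with the mask
  // {I, NumRows + I, 2 * NumRows + I, ...}; e.g. with NumRows = 2 and
  // NumCols = 3, row 0 uses {0, 2, 4} and row 1 uses {1, 3, 5}.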
  for (uint32_t I = 0; I < NumRows; ++I) {
    SmallVector<int, 4> Mask;
    for (uint32_t K = 0; K < NumCols; ++K)
      Mask.push_back(K * NumRows + I);
    Rows.push_back(Builder.buildShuffleVector(VecTy, MatrixReg, MatrixReg, Mask)
                       .getReg(0));
  }
  for (Register R : Rows) {
    setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
  }
  return Rows;
}

Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
                                                SPIRVType *SpvVecType,
                                                SPIRVGlobalRegistry *GR) const {
  bool IsVectorOp = SpvVecType->getOpcode() == SPIRV::OpTypeVector;
  SPIRVType *SpvScalarType = GR->getScalarOrVectorComponentType(SpvVecType);
  bool IsFloatOp = SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
  LLT VecTy = GR->getRegType(SpvVecType);

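  // Vector operands reduce with a dot-product intrinsic; scalar operands
  // (rows and columns of length 1) degenerate to a plain multiply.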
  Register DotRes;
  if (IsVectorOp) {
    LLT ScalarTy = VecTy.getElementType();
    Intrinsic::SPVIntrinsics DotIntrinsic =
        IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot;
    DotRes = Builder.buildIntrinsic(DotIntrinsic, {ScalarTy})
                 .addUse(RowA)
                 .addUse(ColB)
                 .getReg(0);
  } else {
    if (IsFloatOp)
      DotRes = Builder.buildFMul(VecTy, RowA, ColB).getReg(0);
    else
      DotRes = Builder.buildMul(VecTy, RowA, ColB).getReg(0);
  }
  setRegClassType(DotRes, SpvScalarType, GR, &MRI, Builder.getMF());
  return DotRes;
}

SmallVector<Register, 16>
SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
                                        const SmallVector<Register, 4> &ColsB,
                                        SPIRVType *SpvVecType,
                                        SPIRVGlobalRegistry *GR) const {
  SmallVector<Register, 16> ResultScalars;
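  // Walk columns in the outer loop so the scalars are produced in the
  // column-major order expected by the final G_BUILD_VECTOR.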
  for (uint32_t J = 0; J < ColsB.size(); ++J) {
    for (uint32_t I = 0; I < RowsA.size(); ++I) {
      ResultScalars.push_back(
          computeDotProduct(RowsA[I], ColsB[J], SpvVecType, GR));
    }
  }
  return ResultScalars;
}

SPIRVType *
SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
                                             SPIRVGlobalRegistry *GR) const {
  // Determine the scalar result type from the spv_assign_type use of ResReg.
  Type *ScalarResType = nullptr;
  for (auto &UseMI : MRI.use_instructions(ResReg)) {
    if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
      continue;

    if (!isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
      continue;

    Type *Ty = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
    if (Ty->isVectorTy())
      ScalarResType = cast<VectorType>(Ty)->getElementType();
    else
      ScalarResType = Ty;
    assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
    break;
  }
  if (!ScalarResType)
    llvm_unreachable("Could not determine scalar result type");
  Type *VecType =
      K > 1 ? FixedVectorType::get(ScalarResType, K) : ScalarResType;
  return GR->getOrCreateSPIRVType(VecType, Builder,
                                  SPIRV::AccessQualifier::None, false);
}

void SPIRVCombinerHelper::applyMatrixMultiply(MachineInstr &MI) const {
  Register ResReg = MI.getOperand(0).getReg();
  Register AReg = MI.getOperand(2).getReg();
  Register BReg = MI.getOperand(3).getReg();
  uint32_t NumRowsA = MI.getOperand(4).getImm();
  uint32_t NumColsA = MI.getOperand(5).getImm();
  uint32_t NumColsB = MI.getOperand(6).getImm();

  Builder.setInstrAndDebugLoc(MI);

  SPIRVGlobalRegistry *GR =
      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();

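  // Each dot product pairs a row of A with a column of B; both have
  // NumColsA elements, so that is the operand vector length.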
  SPIRVType *SpvVecType = getDotProductVectorType(ResReg, NumColsA, GR);
  SmallVector<Register, 4> ColsB =
      extractColumns(BReg, NumColsB, SpvVecType, GR);
  SmallVector<Register, 4> RowsA =
      extractRows(AReg, NumRowsA, NumColsA, SpvVecType, GR);
  SmallVector<Register, 16> ResultScalars =
      computeDotProducts(RowsA, ColsB, SpvVecType, GR);

  Builder.buildBuildVector(ResReg, ResultScalars);
  MI.eraseFromParent();
}