//===-- SPIRVCombinerHelper.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "SPIRVCombinerHelper.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/LLVMContext.h" // Explicitly include for LLVMContext
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace MIPatternMatch;

SPIRVCombinerHelper::SPIRVCombinerHelper(
    GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize,
    GISelValueTracking *VT, MachineDominatorTree *MDT, const LegalizerInfo *LI,
    const SPIRVSubtarget &STI)
    : CombinerHelper(Observer, B, IsPreLegalize, VT, MDT, LI), STI(STI) {}

/// This match is part of a combine that
/// rewrites length(X - Y) to distance(X, Y)
///   (f32 (g_intrinsic length
///       (g_fsub (vXf32 X) (vXf32 Y))))
///   ->
///   (f32 (g_intrinsic distance
///       (vXf32 X) (vXf32 Y)))
///
bool SPIRVCombinerHelper::matchLengthToDistance(MachineInstr &MI) const {
  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
      cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length)
    return false;

  // Operand 0 of MI is the result and operand 1 is the intrinsic ID, so the
  // first intrinsic argument is operand 2.
  Register SubReg = MI.getOperand(2).getReg();
  MachineInstr *SubInstr = MRI.getVRegDef(SubReg);
  if (SubInstr->getOpcode() != TargetOpcode::G_FSUB)
    return false;

  return true;
}

void SPIRVCombinerHelper::applySPIRVDistance(MachineInstr &MI) const {
  // Extract the operands for X and Y from the match criteria.
  Register SubDestReg = MI.getOperand(2).getReg();
  MachineInstr *SubInstr = MRI.getVRegDef(SubDestReg);
  Register SubOperand1 = SubInstr->getOperand(1).getReg();
  Register SubOperand2 = SubInstr->getOperand(2).getReg();
  Register ResultReg = MI.getOperand(0).getReg();

  Builder.setInstrAndDebugLoc(MI);
  Builder.buildIntrinsic(Intrinsic::spv_distance, ResultReg)
      .addUse(SubOperand1)
      .addUse(SubOperand2);

  MI.eraseFromParent();
}

/// This match is part of a combine that
/// rewrites select(fcmp(dot(I, Ng) < 0), N, -N) to faceforward(N, I, Ng)
///   (vXf32 (g_select
///       (g_fcmp olt
///           (g_intrinsic dot (vXf32 I) (vXf32 Ng))
///           0)
///       (vXf32 N)
///       (vXf32 (g_fneg (vXf32 N)))))
///   ->
///   (vXf32 (g_intrinsic faceforward
///       (vXf32 N) (vXf32 I) (vXf32 Ng)))
///
/// This only works for Vulkan shader targets.
///
bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
  if (!STI.isShader())
    return false;

  // Match the overall select pattern.
  Register CondReg, TrueReg, FalseReg;
  if (!mi_match(MI.getOperand(0).getReg(), MRI,
                m_GISelect(m_Reg(CondReg), m_Reg(TrueReg), m_Reg(FalseReg))))
    return false;

  // Match the FCMP condition and canonicalize it to "dot < 0". A greater-than
  // predicate means the comparison was written as "0 > dot", so the operands
  // must be swapped.
  Register DotReg, CondZeroReg;
  CmpInst::Predicate Pred;
  if (!mi_match(CondReg, MRI,
                m_GFCmp(m_Pred(Pred), m_Reg(DotReg), m_Reg(CondZeroReg))))
    return false;
  if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
    std::swap(DotReg, CondZeroReg);
  else if (Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT)
    return false;

  // Check if FCMP is a comparison between a dot product and 0.
  MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
  if (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC ||
      cast<GIntrinsic>(DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot) {
    Register DotOperand1, DotOperand2;
    // Check for scalar dot product.
    if (!mi_match(DotReg, MRI,
                  m_GFMul(m_Reg(DotOperand1), m_Reg(DotOperand2))) ||
        !MRI.getType(DotOperand1).isScalar() ||
        !MRI.getType(DotOperand2).isScalar())
      return false;
  }

  const ConstantFP *ZeroVal;
  if (!mi_match(CondZeroReg, MRI, m_GFCst(ZeroVal)) || !ZeroVal->isZero())
    return false;

  // Check if the select's false operand is the negation of the true operand.
  auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
    std::optional<FPValueAndVReg> TrueVal, FalseVal;
    if (!mi_match(TrueReg, MRI, m_GFCstOrSplat(TrueVal)) ||
        !mi_match(FalseReg, MRI, m_GFCstOrSplat(FalseVal)))
      return false;
    APFloat TrueValNegated = TrueVal->Value;
    TrueValNegated.changeSign();
    return FalseVal->Value.compare(TrueValNegated) == APFloat::cmpEqual;
  };

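  // The negation can appear in several forms: a direct G_FNEG, matching
  // G_BUILD_VECTORs whose elements are pairwise-negated constants, or a
  // G_FMUL by a -1.0 constant or splat.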
  if (!mi_match(TrueReg, MRI, m_GFNeg(m_SpecificReg(FalseReg))) &&
      !mi_match(FalseReg, MRI, m_GFNeg(m_SpecificReg(TrueReg)))) {
    std::optional<FPValueAndVReg> MulConstant;
    MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
    MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
    if (TrueInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
        FalseInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
        TrueInstr->getNumOperands() == FalseInstr->getNumOperands()) {
      for (unsigned I = 1; I < TrueInstr->getNumOperands(); ++I)
        if (!AreNegatedConstantsOrSplats(TrueInstr->getOperand(I).getReg(),
                                         FalseInstr->getOperand(I).getReg()))
          return false;
    } else if (mi_match(TrueReg, MRI,
                        m_GFMul(m_SpecificReg(FalseReg),
                                m_GFCstOrSplat(MulConstant))) ||
               mi_match(FalseReg, MRI,
                        m_GFMul(m_SpecificReg(TrueReg),
                                m_GFCstOrSplat(MulConstant))) ||
               mi_match(TrueReg, MRI,
                        m_GFMul(m_GFCstOrSplat(MulConstant),
                                m_SpecificReg(FalseReg))) ||
               mi_match(FalseReg, MRI,
                        m_GFMul(m_GFCstOrSplat(MulConstant),
                                m_SpecificReg(TrueReg)))) {
      if (!MulConstant || !MulConstant->Value.isExactlyValue(-1.0))
        return false;
    } else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
      return false;
  }

  return true;
}

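/// Replace the matched select with a faceforward intrinsic and erase the now
/// dead comparison, dot product, and negation instructions.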
void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
  // Extract the operands for N, I, and Ng from the match criteria.
  Register CondReg = MI.getOperand(1).getReg();
  MachineInstr *CondInstr = MRI.getVRegDef(CondReg);
  Register DotReg = CondInstr->getOperand(2).getReg();
  CmpInst::Predicate Pred = cast<GFCmp>(CondInstr)->getCond();
  // For a greater-than predicate the comparison was "0 > dot", so the dot
  // product is the FCMP's second source operand.
  if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
    DotReg = CondInstr->getOperand(3).getReg();
  MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
  Register DotOperand1, DotOperand2;
  if (DotInstr->getOpcode() == TargetOpcode::G_FMUL) {
    DotOperand1 = DotInstr->getOperand(1).getReg();
    DotOperand2 = DotInstr->getOperand(2).getReg();
  } else {
    DotOperand1 = DotInstr->getOperand(2).getReg();
    DotOperand2 = DotInstr->getOperand(3).getReg();
  }
  Register TrueReg = MI.getOperand(2).getReg();
  Register FalseReg = MI.getOperand(3).getReg();
  MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
  // Make sure TrueReg holds N, not its negation, before emitting the
  // intrinsic.
  if (TrueInstr->getOpcode() == TargetOpcode::G_FNEG ||
      TrueInstr->getOpcode() == TargetOpcode::G_FMUL)
    std::swap(TrueReg, FalseReg);
  MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);

  Register ResultReg = MI.getOperand(0).getReg();
  Builder.setInstrAndDebugLoc(MI);
  Builder.buildIntrinsic(Intrinsic::spv_faceforward, ResultReg)
      .addUse(TrueReg)      // N
      .addUse(DotOperand1)  // I
      .addUse(DotOperand2); // Ng

  SPIRVGlobalRegistry *GR =
      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
  auto RemoveAllUses = [&](Register Reg) {
    SmallVector<MachineInstr *, 4> UsesToErase;
    for (auto &UseMI : MRI.use_instructions(Reg))
      UsesToErase.push_back(&UseMI);

    // Calling eraseFromParent while iterating over the uses would invalidate
    // the iterator, so collect the users first.
    for (auto *MIToErase : UsesToErase)
      MIToErase->eraseFromParent();
  };

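  // The new faceforward intrinsic does not use CondReg, so erasing all of
  // CondReg's users also erases the original G_SELECT (MI) itself.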
  RemoveAllUses(CondReg); // Remove all uses of the FCMP result.
  GR->invalidateMachineInstr(CondInstr);
  CondInstr->eraseFromParent(); // Remove the FCMP instruction.
  RemoveAllUses(DotReg); // Remove all uses of the spv_fdot/G_FMUL result.
  GR->invalidateMachineInstr(DotInstr);
  DotInstr->eraseFromParent(); // Remove the spv_fdot/G_FMUL instruction.
  RemoveAllUses(FalseReg);
  GR->invalidateMachineInstr(FalseInstr);
  FalseInstr->eraseFromParent();
}

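/// Match a call to the llvm.matrix.transpose intrinsic so it can be lowered
/// to a single vector shuffle.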
bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_transpose;
}

void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
  Register ResReg = MI.getOperand(0).getReg();
  Register InReg = MI.getOperand(2).getReg();
  uint32_t Rows = MI.getOperand(3).getImm();
  uint32_t Cols = MI.getOperand(4).getImm();

  Builder.setInstrAndDebugLoc(MI);

  if (Rows == 1 && Cols == 1) {
    Builder.buildCopy(ResReg, InReg);
    MI.eraseFromParent();
    return;
  }

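  // The matrix is stored in column-major order: input element (r, c) lives at
  // flat index c * Rows + r. Result element K of the transposed Cols x Rows
  // matrix sits at row K % Cols and column K / Cols, and its value is input
  // element (K / Cols, K % Cols). For a 2x3 input the mask is {0,2,4,1,3,5}.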
  SmallVector<int, 16> Mask;
  for (uint32_t K = 0; K < Rows * Cols; ++K) {
    uint32_t R = K / Cols;
    uint32_t C = K % Cols;
    Mask.push_back(C * Rows + R);
  }

  Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
  MI.eraseFromParent();
}

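/// Match a call to the llvm.matrix.multiply intrinsic so it can be lowered
/// to dot products of A's rows with B's columns.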
bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_multiply;
}

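/// Split a flattened column-major matrix into its NumberOfCols column
/// vectors. Because the storage is column-major, each column is a contiguous
/// slice and a single G_UNMERGE_VALUES suffices.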
SmallVector<Register, 4>
SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
                                    SPIRVType *SpvColType,
                                    SPIRVGlobalRegistry *GR) const {
  // If the matrix is a single column, return that column.
  if (NumberOfCols == 1)
    return {MatrixReg};

  SmallVector<Register, 4> Cols;
  LLT ColTy = GR->getRegType(SpvColType);
  for (uint32_t J = 0; J < NumberOfCols; ++J)
    Cols.push_back(MRI.createGenericVirtualRegister(ColTy));
  Builder.buildUnmerge(Cols, MatrixReg);
  for (Register R : Cols) {
    setRegClassType(R, SpvColType, GR, &MRI, Builder.getMF());
  }
  return Cols;
}

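/// Collect the NumRows row vectors of a flattened column-major matrix. Rows
/// are not contiguous in column-major storage, so each one is gathered with a
/// strided shuffle; the single-column and single-row cases take cheaper paths.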
SmallVector<Register, 4>
SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
                                 uint32_t NumCols, SPIRVType *SpvRowType,
                                 SPIRVGlobalRegistry *GR) const {
  SmallVector<Register, 4> Rows;
  LLT VecTy = GR->getRegType(SpvRowType);

  // If there is only one column, then each row is a scalar that needs to be
  // extracted.
  if (NumCols == 1) {
    assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
    for (uint32_t I = 0; I < NumRows; ++I)
      Rows.push_back(MRI.createGenericVirtualRegister(VecTy));
    Builder.buildUnmerge(Rows, MatrixReg);
    for (Register R : Rows) {
      setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
    }
    return Rows;
  }

  // If the matrix is a single row, return that row.
  if (NumRows == 1) {
    return {MatrixReg};
  }

  for (uint32_t I = 0; I < NumRows; ++I) {
    SmallVector<int, 4> Mask;
    for (uint32_t K = 0; K < NumCols; ++K)
      Mask.push_back(K * NumRows + I);
    Rows.push_back(Builder.buildShuffleVector(VecTy, MatrixReg, MatrixReg, Mask)
                       .getReg(0));
  }
  for (Register R : Rows) {
    setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
  }
  return Rows;
}

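/// Emit the dot product of one row of A and one column of B. Vector operands
/// use the spv_fdot/spv_udot intrinsics; scalar operands (length-1 vectors)
/// fold to a plain G_FMUL/G_MUL.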
Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
                                                SPIRVType *SpvVecType,
                                                SPIRVGlobalRegistry *GR) const {
  bool IsVectorOp = SpvVecType->getOpcode() == SPIRV::OpTypeVector;
  SPIRVType *SpvScalarType = GR->getScalarOrVectorComponentType(SpvVecType);
  bool IsFloatOp = SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
  LLT VecTy = GR->getRegType(SpvVecType);

  Register DotRes;
  if (IsVectorOp) {
    LLT ScalarTy = VecTy.getElementType();
    Intrinsic::SPVIntrinsics DotIntrinsic =
        (IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
    DotRes = Builder.buildIntrinsic(DotIntrinsic, {ScalarTy})
                 .addUse(RowA)
                 .addUse(ColB)
                 .getReg(0);
  } else {
    if (IsFloatOp)
      DotRes = Builder.buildFMul(VecTy, RowA, ColB).getReg(0);
    else
      DotRes = Builder.buildMul(VecTy, RowA, ColB).getReg(0);
  }
  setRegClassType(DotRes, SpvScalarType, GR, &MRI, Builder.getMF());
  return DotRes;
}

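/// Compute the dot product of every (row of A, column of B) pair. The outer
/// loop walks B's columns, so the scalars come out in column-major order,
/// ready to be packed into the result matrix.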
SmallVector<Register, 16>
SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
                                        const SmallVector<Register, 4> &ColsB,
                                        SPIRVType *SpvVecType,
                                        SPIRVGlobalRegistry *GR) const {
  SmallVector<Register, 16> ResultScalars;
  for (uint32_t J = 0; J < ColsB.size(); ++J) {
    for (uint32_t I = 0; I < RowsA.size(); ++I) {
      ResultScalars.push_back(
          computeDotProduct(RowsA[I], ColsB[J], SpvVecType, GR));
    }
  }
  return ResultScalars;
}

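/// Build the SPIR-V type of the length-K vectors fed to the dot products.
/// The scalar element type is not carried by the G_INTRINSIC itself, so it is
/// recovered from the spv_assign_type annotation on the multiply's result
/// register.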
SPIRVType *
SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
                                             SPIRVGlobalRegistry *GR) const {
  // Scan the uses of ResReg for its spv_assign_type annotation.
  Type *ScalarResType = nullptr;
  for (auto &UseMI : MRI.use_instructions(ResReg)) {
    if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
      continue;

    if (!isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
      continue;

    Type *Ty = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
    if (Ty->isVectorTy())
      ScalarResType = cast<VectorType>(Ty)->getElementType();
    else
      ScalarResType = Ty;
    assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
    break;
  }
  if (!ScalarResType)
    llvm_unreachable("Could not determine scalar result type");
  Type *VecType =
      (K > 1 ? FixedVectorType::get(ScalarResType, K) : ScalarResType);
  return GR->getOrCreateSPIRVType(VecType, Builder,
                                  SPIRV::AccessQualifier::None, false);
}

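/// Lower llvm.matrix.multiply: split B (NumColsA x NumColsB) into columns and
/// A (NumRowsA x NumColsA) into rows, take the dot product of each row/column
/// pair, and rebuild the NumRowsA x NumColsB result in column-major order.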
void SPIRVCombinerHelper::applyMatrixMultiply(MachineInstr &MI) const {
  Register ResReg = MI.getOperand(0).getReg();
  Register AReg = MI.getOperand(2).getReg();
  Register BReg = MI.getOperand(3).getReg();
  uint32_t NumRowsA = MI.getOperand(4).getImm();
  uint32_t NumColsA = MI.getOperand(5).getImm();
  uint32_t NumColsB = MI.getOperand(6).getImm();

  Builder.setInstrAndDebugLoc(MI);

  SPIRVGlobalRegistry *GR =
      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();

  // Each dot product is over vectors of length NumColsA (A's row length,
  // which equals B's column length).
  SPIRVType *SpvVecType = getDotProductVectorType(ResReg, NumColsA, GR);
  SmallVector<Register, 4> ColsB =
      extractColumns(BReg, NumColsB, SpvVecType, GR);
  SmallVector<Register, 4> RowsA =
      extractRows(AReg, NumRowsA, NumColsA, SpvVecType, GR);
  SmallVector<Register, 16> ResultScalars =
      computeDotProducts(RowsA, ColsB, SpvVecType, GR);

  Builder.buildBuildVector(ResReg, ResultScalars);
  MI.eraseFromParent();
}