1//===-- SPIRVCombinerHelper.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "SPIRVCombinerHelper.h"
10#include "SPIRVGlobalRegistry.h"
11#include "SPIRVUtils.h"
12#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
13#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
14#include "llvm/IR/DerivedTypes.h"
15#include "llvm/IR/IntrinsicsSPIRV.h"
16#include "llvm/IR/LLVMContext.h" // Explicitly include for LLVMContext
17#include "llvm/Target/TargetMachine.h"
18
19using namespace llvm;
20using namespace MIPatternMatch;
21
22SPIRVCombinerHelper::SPIRVCombinerHelper(
23 GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize,
24 GISelValueTracking *VT, MachineDominatorTree *MDT, const LegalizerInfo *LI,
25 const SPIRVSubtarget &STI)
26 : CombinerHelper(Observer, B, IsPreLegalize, VT, MDT, LI), STI(STI) {}
27
28/// This match is part of a combine that
29/// rewrites length(X - Y) to distance(X, Y)
30/// (f32 (g_intrinsic length
31/// (g_fsub (vXf32 X) (vXf32 Y))))
32/// ->
33/// (f32 (g_intrinsic distance
34/// (vXf32 X) (vXf32 Y)))
35///
36bool SPIRVCombinerHelper::matchLengthToDistance(MachineInstr &MI) const {
37 if (MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
38 cast<GIntrinsic>(Val&: MI).getIntrinsicID() != Intrinsic::spv_length)
39 return false;
40
41 // First operand of MI is `G_INTRINSIC` so start at operand 2.
42 Register SubReg = MI.getOperand(i: 2).getReg();
43 MachineInstr *SubInstr = MRI.getVRegDef(Reg: SubReg);
44 if (SubInstr->getOpcode() != TargetOpcode::G_FSUB)
45 return false;
46
47 return true;
48}
49
50void SPIRVCombinerHelper::applySPIRVDistance(MachineInstr &MI) const {
51 // Extract the operands for X and Y from the match criteria.
52 Register SubDestReg = MI.getOperand(i: 2).getReg();
53 MachineInstr *SubInstr = MRI.getVRegDef(Reg: SubDestReg);
54 Register SubOperand1 = SubInstr->getOperand(i: 1).getReg();
55 Register SubOperand2 = SubInstr->getOperand(i: 2).getReg();
56 Register ResultReg = MI.getOperand(i: 0).getReg();
57
58 Builder.setInstrAndDebugLoc(MI);
59 Builder.buildIntrinsic(ID: Intrinsic::spv_distance, Res: ResultReg)
60 .addUse(RegNo: SubOperand1)
61 .addUse(RegNo: SubOperand2);
62
63 MI.eraseFromParent();
64}
65
66/// This match is part of a combine that
67/// rewrites select(fcmp(dot(I, Ng), 0), N, -N) to faceforward(N, I, Ng)
68/// (vXf32 (g_select
69/// (g_fcmp
70/// (g_intrinsic dot(vXf32 I) (vXf32 Ng)
71/// 0)
72/// (vXf32 N)
73/// (vXf32 g_fneg (vXf32 N))))
74/// ->
75/// (vXf32 (g_intrinsic faceforward
76/// (vXf32 N) (vXf32 I) (vXf32 Ng)))
77///
78/// This only works for Vulkan shader targets.
79///
80bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
81 if (!STI.isShader())
82 return false;
83
84 // Match overall select pattern.
85 Register CondReg, TrueReg, FalseReg;
86 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
87 P: m_GISelect(Src0: m_Reg(R&: CondReg), Src1: m_Reg(R&: TrueReg), Src2: m_Reg(R&: FalseReg))))
88 return false;
89
90 // Match the FCMP condition.
91 Register DotReg, CondZeroReg;
92 CmpInst::Predicate Pred;
93 if (!mi_match(R: CondReg, MRI,
94 P: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: DotReg), R: m_Reg(R&: CondZeroReg))))
95 return false;
96 if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
97 std::swap(a&: DotReg, b&: CondZeroReg);
98 else if (!(Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULT))
99 return false;
100
101 // Check if FCMP is a comparison between a dot product and 0.
102 MachineInstr *DotInstr = MRI.getVRegDef(Reg: DotReg);
103 if (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC ||
104 cast<GIntrinsic>(Val: DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot) {
105 Register DotOperand1, DotOperand2;
106 // Check for scalar dot product.
107 if (!mi_match(R: DotReg, MRI,
108 P: m_GFMul(L: m_Reg(R&: DotOperand1), R: m_Reg(R&: DotOperand2))) ||
109 !MRI.getType(Reg: DotOperand1).isScalar() ||
110 !MRI.getType(Reg: DotOperand2).isScalar())
111 return false;
112 }
113
114 const ConstantFP *ZeroVal;
115 if (!mi_match(R: CondZeroReg, MRI, P: m_GFCst(C&: ZeroVal)) || !ZeroVal->isZero())
116 return false;
117
118 // Check if select's false operand is the negation of the true operand.
119 auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
120 std::optional<FPValueAndVReg> TrueVal, FalseVal;
121 if (!mi_match(R: TrueReg, MRI, P: m_GFCstOrSplat(FPValReg&: TrueVal)) ||
122 !mi_match(R: FalseReg, MRI, P: m_GFCstOrSplat(FPValReg&: FalseVal)))
123 return false;
124 APFloat TrueValNegated = TrueVal->Value;
125 TrueValNegated.changeSign();
126 return FalseVal->Value.compare(RHS: TrueValNegated) == APFloat::cmpEqual;
127 };
128
129 if (!mi_match(R: TrueReg, MRI, P: m_GFNeg(Src: m_SpecificReg(RequestedReg: FalseReg))) &&
130 !mi_match(R: FalseReg, MRI, P: m_GFNeg(Src: m_SpecificReg(RequestedReg: TrueReg)))) {
131 std::optional<FPValueAndVReg> MulConstant;
132 MachineInstr *TrueInstr = MRI.getVRegDef(Reg: TrueReg);
133 MachineInstr *FalseInstr = MRI.getVRegDef(Reg: FalseReg);
134 if (TrueInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
135 FalseInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
136 TrueInstr->getNumOperands() == FalseInstr->getNumOperands()) {
137 for (unsigned I = 1; I < TrueInstr->getNumOperands(); ++I)
138 if (!AreNegatedConstantsOrSplats(TrueInstr->getOperand(i: I).getReg(),
139 FalseInstr->getOperand(i: I).getReg()))
140 return false;
141 } else if (mi_match(R: TrueReg, MRI,
142 P: m_GFMul(L: m_SpecificReg(RequestedReg: FalseReg),
143 R: m_GFCstOrSplat(FPValReg&: MulConstant))) ||
144 mi_match(R: FalseReg, MRI,
145 P: m_GFMul(L: m_SpecificReg(RequestedReg: TrueReg),
146 R: m_GFCstOrSplat(FPValReg&: MulConstant))) ||
147 mi_match(R: TrueReg, MRI,
148 P: m_GFMul(L: m_GFCstOrSplat(FPValReg&: MulConstant),
149 R: m_SpecificReg(RequestedReg: FalseReg))) ||
150 mi_match(R: FalseReg, MRI,
151 P: m_GFMul(L: m_GFCstOrSplat(FPValReg&: MulConstant),
152 R: m_SpecificReg(RequestedReg: TrueReg)))) {
153 if (!MulConstant || !MulConstant->Value.isMinusOne())
154 return false;
155 } else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
156 return false;
157 }
158
159 return true;
160}
161
162void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
163 // Extract the operands for N, I, and Ng from the match criteria.
164 Register CondReg = MI.getOperand(i: 1).getReg();
165 MachineInstr *CondInstr = MRI.getVRegDef(Reg: CondReg);
166 Register DotReg = CondInstr->getOperand(i: 2).getReg();
167 CmpInst::Predicate Pred = cast<GFCmp>(Val: CondInstr)->getCond();
168 if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
169 DotReg = CondInstr->getOperand(i: 3).getReg();
170 MachineInstr *DotInstr = MRI.getVRegDef(Reg: DotReg);
171 Register DotOperand1, DotOperand2;
172 if (DotInstr->getOpcode() == TargetOpcode::G_FMUL) {
173 DotOperand1 = DotInstr->getOperand(i: 1).getReg();
174 DotOperand2 = DotInstr->getOperand(i: 2).getReg();
175 } else {
176 DotOperand1 = DotInstr->getOperand(i: 2).getReg();
177 DotOperand2 = DotInstr->getOperand(i: 3).getReg();
178 }
179 Register TrueReg = MI.getOperand(i: 2).getReg();
180 Register FalseReg = MI.getOperand(i: 3).getReg();
181 MachineInstr *TrueInstr = MRI.getVRegDef(Reg: TrueReg);
182 if (TrueInstr->getOpcode() == TargetOpcode::G_FNEG ||
183 TrueInstr->getOpcode() == TargetOpcode::G_FMUL)
184 std::swap(a&: TrueReg, b&: FalseReg);
185 MachineInstr *FalseInstr = MRI.getVRegDef(Reg: FalseReg);
186
187 Register ResultReg = MI.getOperand(i: 0).getReg();
188 Builder.setInstrAndDebugLoc(MI);
189 Builder.buildIntrinsic(ID: Intrinsic::spv_faceforward, Res: ResultReg)
190 .addUse(RegNo: TrueReg) // N
191 .addUse(RegNo: DotOperand1) // I
192 .addUse(RegNo: DotOperand2); // Ng
193
194 SPIRVGlobalRegistry *GR =
195 MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
196 auto RemoveAllUses = [&](Register Reg) {
197 SmallVector<MachineInstr *, 4> UsesToErase;
198 for (auto &UseMI : MRI.use_instructions(Reg))
199 UsesToErase.push_back(Elt: &UseMI);
200
201 // calling eraseFromParent to early invalidates the iterator.
202 for (auto *MIToErase : UsesToErase)
203 MIToErase->eraseFromParent();
204 };
205
206 RemoveAllUses(CondReg); // remove all uses of FCMP Result
207 GR->invalidateMachineInstr(MI: CondInstr);
208 CondInstr->eraseFromParent(); // remove FCMP instruction
209 RemoveAllUses(DotReg); // remove all uses of spv_fdot/G_FMUL Result
210 GR->invalidateMachineInstr(MI: DotInstr);
211 DotInstr->eraseFromParent(); // remove spv_fdot/G_FMUL instruction
212 RemoveAllUses(FalseReg);
213 GR->invalidateMachineInstr(MI: FalseInstr);
214 FalseInstr->eraseFromParent();
215}
216
217bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
218 return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
219 cast<GIntrinsic>(Val&: MI).getIntrinsicID() == Intrinsic::matrix_transpose;
220}
221
222void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
223 Register ResReg = MI.getOperand(i: 0).getReg();
224 Register InReg = MI.getOperand(i: 2).getReg();
225 uint32_t Rows = MI.getOperand(i: 3).getImm();
226 uint32_t Cols = MI.getOperand(i: 4).getImm();
227
228 Builder.setInstrAndDebugLoc(MI);
229
230 // A 1xN or Nx1 transpose is a pure reshape.
231 if (Rows == 1 || Cols == 1) {
232 Builder.buildCopy(Res: ResReg, Op: InReg);
233 MI.eraseFromParent();
234 return;
235 }
236
237 SmallVector<int, 16> Mask;
238 for (uint32_t K = 0; K < Rows * Cols; ++K) {
239 uint32_t R = K / Cols;
240 uint32_t C = K % Cols;
241 Mask.push_back(Elt: C * Rows + R);
242 }
243
244 Builder.buildShuffleVector(Res: ResReg, Src1: InReg, Src2: InReg, Mask);
245 MI.eraseFromParent();
246}
247
248bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
249 return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
250 cast<GIntrinsic>(Val&: MI).getIntrinsicID() == Intrinsic::matrix_multiply;
251}
252
253SmallVector<Register, 4>
254SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
255 SPIRVTypeInst SpvColType,
256 SPIRVGlobalRegistry *GR) const {
257 // If the matrix is a single colunm, return that single column.
258 if (NumberOfCols == 1)
259 return {MatrixReg};
260
261 SmallVector<Register, 4> Cols;
262 LLT ColTy = GR->getRegType(SpvType: SpvColType);
263 for (uint32_t J = 0; J < NumberOfCols; ++J)
264 Cols.push_back(Elt: MRI.createGenericVirtualRegister(Ty: ColTy));
265 Builder.buildUnmerge(Res: Cols, Op: MatrixReg);
266 for (Register R : Cols) {
267 setRegClassType(Reg: R, SpvType: SpvColType, GR, MRI: &MRI, MF: Builder.getMF());
268 }
269 return Cols;
270}
271
272SmallVector<Register, 4>
273SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
274 uint32_t NumCols, SPIRVTypeInst SpvRowType,
275 SPIRVGlobalRegistry *GR) const {
276 SmallVector<Register, 4> Rows;
277 LLT VecTy = GR->getRegType(SpvType: SpvRowType);
278
279 // If there is only one column, then each row is a scalar that needs
280 // to be extracted.
281 if (NumCols == 1) {
282 assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
283 for (uint32_t I = 0; I < NumRows; ++I)
284 Rows.push_back(Elt: MRI.createGenericVirtualRegister(Ty: VecTy));
285 Builder.buildUnmerge(Res: Rows, Op: MatrixReg);
286 for (Register R : Rows) {
287 setRegClassType(Reg: R, SpvType: SpvRowType, GR, MRI: &MRI, MF: Builder.getMF());
288 }
289 return Rows;
290 }
291
292 // If the matrix is a single row return that row.
293 if (NumRows == 1) {
294 return {MatrixReg};
295 }
296
297 for (uint32_t I = 0; I < NumRows; ++I) {
298 SmallVector<int, 4> Mask;
299 for (uint32_t k = 0; k < NumCols; ++k)
300 Mask.push_back(Elt: k * NumRows + I);
301 Rows.push_back(Elt: Builder.buildShuffleVector(Res: VecTy, Src1: MatrixReg, Src2: MatrixReg, Mask)
302 .getReg(Idx: 0));
303 }
304 for (Register R : Rows) {
305 setRegClassType(Reg: R, SpvType: SpvRowType, GR, MRI: &MRI, MF: Builder.getMF());
306 }
307 return Rows;
308}
309
310Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
311 SPIRVTypeInst SpvVecType,
312 SPIRVGlobalRegistry *GR) const {
313 bool IsVectorOp = SpvVecType->getOpcode() == SPIRV::OpTypeVector;
314 SPIRVTypeInst SpvScalarType = GR->getScalarOrVectorComponentType(Type: SpvVecType);
315 bool IsFloatOp = SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
316 LLT VecTy = GR->getRegType(SpvType: SpvVecType);
317
318 Register DotRes;
319 if (IsVectorOp) {
320 LLT ScalarTy = VecTy.getElementType();
321 Intrinsic::SPVIntrinsics DotIntrinsic =
322 (IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
323 DotRes = Builder.buildIntrinsic(ID: DotIntrinsic, Res: {ScalarTy})
324 .addUse(RegNo: RowA)
325 .addUse(RegNo: ColB)
326 .getReg(Idx: 0);
327 } else {
328 if (IsFloatOp)
329 DotRes = Builder.buildFMul(Dst: VecTy, Src0: RowA, Src1: ColB).getReg(Idx: 0);
330 else
331 DotRes = Builder.buildMul(Dst: VecTy, Src0: RowA, Src1: ColB).getReg(Idx: 0);
332 }
333 setRegClassType(Reg: DotRes, SpvType: SpvScalarType, GR, MRI: &MRI, MF: Builder.getMF());
334 return DotRes;
335}
336
337SmallVector<Register, 16> SPIRVCombinerHelper::computeDotProducts(
338 ArrayRef<Register> RowsA, ArrayRef<Register> ColsB,
339 SPIRVTypeInst SpvVecType, SPIRVGlobalRegistry *GR) const {
340 SmallVector<Register, 16> ResultScalars;
341 for (uint32_t J = 0; J < ColsB.size(); ++J) {
342 for (uint32_t I = 0; I < RowsA.size(); ++I) {
343 ResultScalars.push_back(
344 Elt: computeDotProduct(RowA: RowsA[I], ColB: ColsB[J], SpvVecType, GR));
345 }
346 }
347 return ResultScalars;
348}
349
350SPIRVTypeInst
351SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
352 SPIRVGlobalRegistry *GR) const {
353 // Loop over all non debug uses of ResReg
354 Type *ScalarResType = nullptr;
355 for (auto &UseMI : MRI.use_instructions(Reg: ResReg)) {
356 if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
357 continue;
358
359 if (!isSpvIntrinsic(MI: UseMI, IntrinsicID: Intrinsic::spv_assign_type))
360 continue;
361
362 Type *Ty = getMDOperandAsType(N: UseMI.getOperand(i: 2).getMetadata(), I: 0);
363 if (Ty->isVectorTy())
364 ScalarResType = cast<VectorType>(Val: Ty)->getElementType();
365 else
366 ScalarResType = Ty;
367 assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
368 break;
369 }
370 if (!ScalarResType)
371 llvm_unreachable("Could not determine scalar result type");
372 Type *VecType =
373 (K > 1 ? FixedVectorType::get(ElementType: ScalarResType, NumElts: K) : ScalarResType);
374 return GR->getOrCreateSPIRVType(Type: VecType, MIRBuilder&: Builder,
375 AQ: SPIRV::AccessQualifier::None, EmitIR: false);
376}
377
378void SPIRVCombinerHelper::applyMatrixMultiply(MachineInstr &MI) const {
379 Register ResReg = MI.getOperand(i: 0).getReg();
380 Register AReg = MI.getOperand(i: 2).getReg();
381 Register BReg = MI.getOperand(i: 3).getReg();
382 uint32_t NumRowsA = MI.getOperand(i: 4).getImm();
383 uint32_t NumColsA = MI.getOperand(i: 5).getImm();
384 uint32_t NumColsB = MI.getOperand(i: 6).getImm();
385
386 Builder.setInstrAndDebugLoc(MI);
387
388 SPIRVGlobalRegistry *GR =
389 MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
390
391 SPIRVTypeInst SpvVecType = getDotProductVectorType(ResReg, K: NumColsA, GR);
392 SmallVector<Register, 4> ColsB =
393 extractColumns(MatrixReg: BReg, NumberOfCols: NumColsB, SpvColType: SpvVecType, GR);
394 SmallVector<Register, 4> RowsA =
395 extractRows(MatrixReg: AReg, NumRows: NumRowsA, NumCols: NumColsA, SpvRowType: SpvVecType, GR);
396 SmallVector<Register, 16> ResultScalars =
397 computeDotProducts(RowsA, ColsB, SpvVecType, GR);
398
399 if (ResultScalars.size() == 1)
400 Builder.buildCopy(Res: ResReg, Op: ResultScalars[0]);
401 else
402 Builder.buildBuildVector(Res: ResReg, Ops: ResultScalars);
403 MI.eraseFromParent();
404}
405