1//===-- SPIRVCombinerHelper.h -----------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// This contains common combine transformations that may be used in a combine
10/// pass.
11///
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCOMBINERHELPER_H
15#define LLVM_LIB_TARGET_SPIRV_SPIRVCOMBINERHELPER_H
16
17#include "SPIRVSubtarget.h"
18#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
19
20namespace llvm {
21class SPIRVCombinerHelper : public CombinerHelper {
22protected:
23 const SPIRVSubtarget &STI;
24
25public:
26 using CombinerHelper::CombinerHelper;
27 SPIRVCombinerHelper(GISelChangeObserver &Observer, MachineIRBuilder &B,
28 bool IsPreLegalize, GISelValueTracking *VT,
29 MachineDominatorTree *MDT, const LegalizerInfo *LI,
30 const SPIRVSubtarget &STI);
31
32 bool matchLengthToDistance(MachineInstr &MI) const;
33 void applySPIRVDistance(MachineInstr &MI) const;
34 bool matchSelectToFaceForward(MachineInstr &MI) const;
35 void applySPIRVFaceForward(MachineInstr &MI) const;
36 bool matchMatrixTranspose(MachineInstr &MI) const;
37 void applyMatrixTranspose(MachineInstr &MI) const;
38 bool matchMatrixMultiply(MachineInstr &MI) const;
39 void applyMatrixMultiply(MachineInstr &MI) const;
40
41private:
42 SPIRVType *getDotProductVectorType(Register ResReg, uint32_t K,
43 SPIRVGlobalRegistry *GR) const;
44 SmallVector<Register, 4> extractColumns(Register BReg, uint32_t N,
45 SPIRVType *SpvVecType,
46 SPIRVGlobalRegistry *GR) const;
47 SmallVector<Register, 4> extractRows(Register AReg, uint32_t NumRows,
48 uint32_t NumCols, SPIRVType *SpvRowType,
49 SPIRVGlobalRegistry *GR) const;
50 SmallVector<Register, 16>
51 computeDotProducts(const SmallVector<Register, 4> &RowsA,
52 const SmallVector<Register, 4> &ColsB,
53 SPIRVType *SpvVecType, SPIRVGlobalRegistry *GR) const;
54 Register computeDotProduct(Register RowA, Register ColB,
55 SPIRVType *SpvVecType,
56 SPIRVGlobalRegistry *GR) const;
57};
58
59} // end namespace llvm
60
61#endif // LLVM_LIB_TARGET_SPIRV_SPIRVCOMBINERHELPER_H
62