1 | //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the MatrixBuilder class, which is used as a convenient way |
10 | // to lower matrix operations to LLVM IR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef LLVM_IR_MATRIXBUILDER_H |
15 | #define LLVM_IR_MATRIXBUILDER_H |
16 | |
17 | #include "llvm/IR/Constant.h" |
18 | #include "llvm/IR/Constants.h" |
19 | #include "llvm/IR/IRBuilder.h" |
20 | #include "llvm/IR/InstrTypes.h" |
21 | #include "llvm/IR/Instruction.h" |
22 | #include "llvm/IR/IntrinsicInst.h" |
23 | #include "llvm/IR/Type.h" |
24 | #include "llvm/IR/Value.h" |
25 | #include "llvm/Support/Alignment.h" |
26 | |
27 | namespace llvm { |
28 | |
29 | class Function; |
30 | class Twine; |
31 | class Module; |
32 | |
33 | class MatrixBuilder { |
34 | IRBuilderBase &B; |
35 | Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } |
36 | |
37 | std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, |
38 | Value *RHS) { |
39 | assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && |
40 | "One of the operands must be a matrix (embedded in a vector)" ); |
41 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
42 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
43 | "LHS Assumed to be fixed width" ); |
44 | RHS = B.CreateVectorSplat( |
45 | EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), V: RHS, |
46 | Name: "scalar.splat" ); |
47 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
48 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
49 | "RHS Assumed to be fixed width" ); |
50 | LHS = B.CreateVectorSplat( |
51 | EC: cast<VectorType>(Val: RHS->getType())->getElementCount(), V: LHS, |
52 | Name: "scalar.splat" ); |
53 | } |
54 | return {LHS, RHS}; |
55 | } |
56 | |
57 | public: |
58 | MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {} |
59 | |
60 | /// Create a column major, strided matrix load. |
61 | /// \p EltTy - Matrix element type |
62 | /// \p DataPtr - Start address of the matrix read |
63 | /// \p Rows - Number of rows in matrix (must be a constant) |
64 | /// \p Columns - Number of columns in matrix (must be a constant) |
65 | /// \p Stride - Space between columns |
66 | CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, |
67 | Value *Stride, bool IsVolatile, unsigned Rows, |
68 | unsigned Columns, const Twine &Name = "" ) { |
69 | auto *RetType = FixedVectorType::get(ElementType: EltTy, NumElts: Rows * Columns); |
70 | |
71 | Value *Ops[] = {DataPtr, Stride, B.getInt1(V: IsVolatile), B.getInt32(C: Rows), |
72 | B.getInt32(C: Columns)}; |
73 | Type *OverloadedTypes[] = {RetType, Stride->getType()}; |
74 | |
75 | Function *TheFn = Intrinsic::getDeclaration( |
76 | M: getModule(), id: Intrinsic::matrix_column_major_load, Tys: OverloadedTypes); |
77 | |
78 | CallInst *Call = B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name); |
79 | Attribute AlignAttr = |
80 | Attribute::getWithAlignment(Context&: Call->getContext(), Alignment); |
81 | Call->addParamAttr(ArgNo: 0, Attr: AlignAttr); |
82 | return Call; |
83 | } |
84 | |
85 | /// Create a column major, strided matrix store. |
86 | /// \p Matrix - Matrix to store |
87 | /// \p Ptr - Pointer to write back to |
88 | /// \p Stride - Space between columns |
89 | CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, |
90 | Value *Stride, bool IsVolatile, |
91 | unsigned Rows, unsigned Columns, |
92 | const Twine &Name = "" ) { |
93 | Value *Ops[] = {Matrix, Ptr, |
94 | Stride, B.getInt1(V: IsVolatile), |
95 | B.getInt32(C: Rows), B.getInt32(C: Columns)}; |
96 | Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()}; |
97 | |
98 | Function *TheFn = Intrinsic::getDeclaration( |
99 | M: getModule(), id: Intrinsic::matrix_column_major_store, Tys: OverloadedTypes); |
100 | |
101 | CallInst *Call = B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name); |
102 | Attribute AlignAttr = |
103 | Attribute::getWithAlignment(Context&: Call->getContext(), Alignment); |
104 | Call->addParamAttr(ArgNo: 1, Attr: AlignAttr); |
105 | return Call; |
106 | } |
107 | |
108 | /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows |
109 | /// rows and \p Columns columns. |
110 | CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, |
111 | unsigned Columns, const Twine &Name = "" ) { |
112 | auto *OpType = cast<VectorType>(Val: Matrix->getType()); |
113 | auto *ReturnType = |
114 | FixedVectorType::get(ElementType: OpType->getElementType(), NumElts: Rows * Columns); |
115 | |
116 | Type *OverloadedTypes[] = {ReturnType}; |
117 | Value *Ops[] = {Matrix, B.getInt32(C: Rows), B.getInt32(C: Columns)}; |
118 | Function *TheFn = Intrinsic::getDeclaration( |
119 | M: getModule(), id: Intrinsic::matrix_transpose, Tys: OverloadedTypes); |
120 | |
121 | return B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name); |
122 | } |
123 | |
124 | /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p |
125 | /// RHS. |
126 | CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, |
127 | unsigned LHSColumns, unsigned RHSColumns, |
128 | const Twine &Name = "" ) { |
129 | auto *LHSType = cast<VectorType>(Val: LHS->getType()); |
130 | auto *RHSType = cast<VectorType>(Val: RHS->getType()); |
131 | |
132 | auto *ReturnType = |
133 | FixedVectorType::get(ElementType: LHSType->getElementType(), NumElts: LHSRows * RHSColumns); |
134 | |
135 | Value *Ops[] = {LHS, RHS, B.getInt32(C: LHSRows), B.getInt32(C: LHSColumns), |
136 | B.getInt32(C: RHSColumns)}; |
137 | Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; |
138 | |
139 | Function *TheFn = Intrinsic::getDeclaration( |
140 | M: getModule(), id: Intrinsic::matrix_multiply, Tys: OverloadedTypes); |
141 | return B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name); |
142 | } |
143 | |
144 | /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p |
145 | /// ColumnIdx). |
146 | Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, |
147 | Value *ColumnIdx, unsigned NumRows) { |
148 | return B.CreateInsertElement( |
149 | Vec: Matrix, NewElt: NewVal, |
150 | Idx: B.CreateAdd(LHS: B.CreateMul(LHS: ColumnIdx, RHS: ConstantInt::get( |
151 | Ty: ColumnIdx->getType(), V: NumRows)), |
152 | RHS: RowIdx)); |
153 | } |
154 | |
155 | /// Add matrixes \p LHS and \p RHS. Support both integer and floating point |
156 | /// matrixes. |
157 | Value *CreateAdd(Value *LHS, Value *RHS) { |
158 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
159 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
160 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
161 | "LHS Assumed to be fixed width" ); |
162 | RHS = B.CreateVectorSplat( |
163 | EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), V: RHS, |
164 | Name: "scalar.splat" ); |
165 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
166 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
167 | "RHS Assumed to be fixed width" ); |
168 | LHS = B.CreateVectorSplat( |
169 | EC: cast<VectorType>(Val: RHS->getType())->getElementCount(), V: LHS, |
170 | Name: "scalar.splat" ); |
171 | } |
172 | |
173 | return cast<VectorType>(Val: LHS->getType()) |
174 | ->getElementType() |
175 | ->isFloatingPointTy() |
176 | ? B.CreateFAdd(L: LHS, R: RHS) |
177 | : B.CreateAdd(LHS, RHS); |
178 | } |
179 | |
180 | /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating |
181 | /// point matrixes. |
182 | Value *CreateSub(Value *LHS, Value *RHS) { |
183 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
184 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
185 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
186 | "LHS Assumed to be fixed width" ); |
187 | RHS = B.CreateVectorSplat( |
188 | EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), V: RHS, |
189 | Name: "scalar.splat" ); |
190 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
191 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
192 | "RHS Assumed to be fixed width" ); |
193 | LHS = B.CreateVectorSplat( |
194 | EC: cast<VectorType>(Val: RHS->getType())->getElementCount(), V: LHS, |
195 | Name: "scalar.splat" ); |
196 | } |
197 | |
198 | return cast<VectorType>(Val: LHS->getType()) |
199 | ->getElementType() |
200 | ->isFloatingPointTy() |
201 | ? B.CreateFSub(L: LHS, R: RHS) |
202 | : B.CreateSub(LHS, RHS); |
203 | } |
204 | |
205 | /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p |
206 | /// RHS. |
207 | Value *CreateScalarMultiply(Value *LHS, Value *RHS) { |
208 | std::tie(args&: LHS, args&: RHS) = splatScalarOperandIfNeeded(LHS, RHS); |
209 | if (LHS->getType()->getScalarType()->isFloatingPointTy()) |
210 | return B.CreateFMul(L: LHS, R: RHS); |
211 | return B.CreateMul(LHS, RHS); |
212 | } |
213 | |
214 | /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p |
215 | /// IsUnsigned indicates whether UDiv or SDiv should be used. |
216 | Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { |
217 | assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); |
218 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
219 | "LHS Assumed to be fixed width" ); |
220 | RHS = |
221 | B.CreateVectorSplat(EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), |
222 | V: RHS, Name: "scalar.splat" ); |
223 | return cast<VectorType>(Val: LHS->getType()) |
224 | ->getElementType() |
225 | ->isFloatingPointTy() |
226 | ? B.CreateFDiv(L: LHS, R: RHS) |
227 | : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); |
228 | } |
229 | |
230 | /// Create an assumption that \p Idx is less than \p NumElements. |
231 | void CreateIndexAssumption(Value *Idx, unsigned NumElements, |
232 | Twine const &Name = "" ) { |
233 | Value *NumElts = |
234 | B.getIntN(N: Idx->getType()->getScalarSizeInBits(), C: NumElements); |
235 | auto *Cmp = B.CreateICmpULT(LHS: Idx, RHS: NumElts); |
236 | if (isa<ConstantInt>(Val: Cmp)) |
237 | assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!" ); |
238 | else |
239 | B.CreateAssumption(Cond: Cmp); |
240 | } |
241 | |
242 | /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from |
243 | /// a matrix with \p NumRows embedded in a vector. |
244 | Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, |
245 | Twine const &Name = "" ) { |
246 | unsigned MaxWidth = std::max(a: RowIdx->getType()->getScalarSizeInBits(), |
247 | b: ColumnIdx->getType()->getScalarSizeInBits()); |
248 | Type *IntTy = IntegerType::get(C&: RowIdx->getType()->getContext(), NumBits: MaxWidth); |
249 | RowIdx = B.CreateZExt(V: RowIdx, DestTy: IntTy); |
250 | ColumnIdx = B.CreateZExt(V: ColumnIdx, DestTy: IntTy); |
251 | Value *NumRowsV = B.getIntN(N: MaxWidth, C: NumRows); |
252 | return B.CreateAdd(LHS: B.CreateMul(LHS: ColumnIdx, RHS: NumRowsV), RHS: RowIdx); |
253 | } |
254 | }; |
255 | |
256 | } // end namespace llvm |
257 | |
258 | #endif // LLVM_IR_MATRIXBUILDER_H |
259 | |