1//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MatrixBuilder class, which is used as a convenient way
10// to lower matrix operations to LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_IR_MATRIXBUILDER_H
15#define LLVM_IR_MATRIXBUILDER_H
16
17#include "llvm/IR/Constant.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/InstrTypes.h"
21#include "llvm/IR/Instruction.h"
22#include "llvm/IR/IntrinsicInst.h"
23#include "llvm/IR/Type.h"
24#include "llvm/IR/Value.h"
25#include "llvm/Support/Alignment.h"
26
27namespace llvm {
28
29class Function;
30class Twine;
31class Module;
32
33class MatrixBuilder {
34 IRBuilderBase &B;
35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38 Value *RHS) {
39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40 "One of the operands must be a matrix (embedded in a vector)");
41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42 assert(!isa<ScalableVectorType>(LHS->getType()) &&
43 "LHS Assumed to be fixed width");
44 RHS = B.CreateVectorSplat(
45 EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), V: RHS,
46 Name: "scalar.splat");
47 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48 assert(!isa<ScalableVectorType>(RHS->getType()) &&
49 "RHS Assumed to be fixed width");
50 LHS = B.CreateVectorSplat(
51 EC: cast<VectorType>(Val: RHS->getType())->getElementCount(), V: LHS,
52 Name: "scalar.splat");
53 }
54 return {LHS, RHS};
55 }
56
57public:
58 MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
59
60 /// Create a column major, strided matrix load.
61 /// \p EltTy - Matrix element type
62 /// \p DataPtr - Start address of the matrix read
63 /// \p Rows - Number of rows in matrix (must be a constant)
64 /// \p Columns - Number of columns in matrix (must be a constant)
65 /// \p Stride - Space between columns
66 CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
67 Value *Stride, bool IsVolatile, unsigned Rows,
68 unsigned Columns, const Twine &Name = "") {
69 auto *RetType = FixedVectorType::get(ElementType: EltTy, NumElts: Rows * Columns);
70
71 Value *Ops[] = {DataPtr, Stride, B.getInt1(V: IsVolatile), B.getInt32(C: Rows),
72 B.getInt32(C: Columns)};
73 Type *OverloadedTypes[] = {RetType, Stride->getType()};
74
75 Function *TheFn = Intrinsic::getDeclaration(
76 M: getModule(), id: Intrinsic::matrix_column_major_load, Tys: OverloadedTypes);
77
78 CallInst *Call = B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name);
79 Attribute AlignAttr =
80 Attribute::getWithAlignment(Context&: Call->getContext(), Alignment);
81 Call->addParamAttr(ArgNo: 0, Attr: AlignAttr);
82 return Call;
83 }
84
85 /// Create a column major, strided matrix store.
86 /// \p Matrix - Matrix to store
87 /// \p Ptr - Pointer to write back to
88 /// \p Stride - Space between columns
89 CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
90 Value *Stride, bool IsVolatile,
91 unsigned Rows, unsigned Columns,
92 const Twine &Name = "") {
93 Value *Ops[] = {Matrix, Ptr,
94 Stride, B.getInt1(V: IsVolatile),
95 B.getInt32(C: Rows), B.getInt32(C: Columns)};
96 Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
97
98 Function *TheFn = Intrinsic::getDeclaration(
99 M: getModule(), id: Intrinsic::matrix_column_major_store, Tys: OverloadedTypes);
100
101 CallInst *Call = B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name);
102 Attribute AlignAttr =
103 Attribute::getWithAlignment(Context&: Call->getContext(), Alignment);
104 Call->addParamAttr(ArgNo: 1, Attr: AlignAttr);
105 return Call;
106 }
107
108 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
109 /// rows and \p Columns columns.
110 CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
111 unsigned Columns, const Twine &Name = "") {
112 auto *OpType = cast<VectorType>(Val: Matrix->getType());
113 auto *ReturnType =
114 FixedVectorType::get(ElementType: OpType->getElementType(), NumElts: Rows * Columns);
115
116 Type *OverloadedTypes[] = {ReturnType};
117 Value *Ops[] = {Matrix, B.getInt32(C: Rows), B.getInt32(C: Columns)};
118 Function *TheFn = Intrinsic::getDeclaration(
119 M: getModule(), id: Intrinsic::matrix_transpose, Tys: OverloadedTypes);
120
121 return B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name);
122 }
123
124 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
125 /// RHS.
126 CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
127 unsigned LHSColumns, unsigned RHSColumns,
128 const Twine &Name = "") {
129 auto *LHSType = cast<VectorType>(Val: LHS->getType());
130 auto *RHSType = cast<VectorType>(Val: RHS->getType());
131
132 auto *ReturnType =
133 FixedVectorType::get(ElementType: LHSType->getElementType(), NumElts: LHSRows * RHSColumns);
134
135 Value *Ops[] = {LHS, RHS, B.getInt32(C: LHSRows), B.getInt32(C: LHSColumns),
136 B.getInt32(C: RHSColumns)};
137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138
139 Function *TheFn = Intrinsic::getDeclaration(
140 M: getModule(), id: Intrinsic::matrix_multiply, Tys: OverloadedTypes);
141 return B.CreateCall(FTy: TheFn->getFunctionType(), Callee: TheFn, Args: Ops, Name);
142 }
143
144 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
145 /// ColumnIdx).
146 Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
147 Value *ColumnIdx, unsigned NumRows) {
148 return B.CreateInsertElement(
149 Vec: Matrix, NewElt: NewVal,
150 Idx: B.CreateAdd(LHS: B.CreateMul(LHS: ColumnIdx, RHS: ConstantInt::get(
151 Ty: ColumnIdx->getType(), V: NumRows)),
152 RHS: RowIdx));
153 }
154
155 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
156 /// matrixes.
157 Value *CreateAdd(Value *LHS, Value *RHS) {
158 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
159 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
160 assert(!isa<ScalableVectorType>(LHS->getType()) &&
161 "LHS Assumed to be fixed width");
162 RHS = B.CreateVectorSplat(
163 EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), V: RHS,
164 Name: "scalar.splat");
165 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
166 assert(!isa<ScalableVectorType>(RHS->getType()) &&
167 "RHS Assumed to be fixed width");
168 LHS = B.CreateVectorSplat(
169 EC: cast<VectorType>(Val: RHS->getType())->getElementCount(), V: LHS,
170 Name: "scalar.splat");
171 }
172
173 return cast<VectorType>(Val: LHS->getType())
174 ->getElementType()
175 ->isFloatingPointTy()
176 ? B.CreateFAdd(L: LHS, R: RHS)
177 : B.CreateAdd(LHS, RHS);
178 }
179
180 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
181 /// point matrixes.
182 Value *CreateSub(Value *LHS, Value *RHS) {
183 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
184 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
185 assert(!isa<ScalableVectorType>(LHS->getType()) &&
186 "LHS Assumed to be fixed width");
187 RHS = B.CreateVectorSplat(
188 EC: cast<VectorType>(Val: LHS->getType())->getElementCount(), V: RHS,
189 Name: "scalar.splat");
190 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
191 assert(!isa<ScalableVectorType>(RHS->getType()) &&
192 "RHS Assumed to be fixed width");
193 LHS = B.CreateVectorSplat(
194 EC: cast<VectorType>(Val: RHS->getType())->getElementCount(), V: LHS,
195 Name: "scalar.splat");
196 }
197
198 return cast<VectorType>(Val: LHS->getType())
199 ->getElementType()
200 ->isFloatingPointTy()
201 ? B.CreateFSub(L: LHS, R: RHS)
202 : B.CreateSub(LHS, RHS);
203 }
204
205 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
206 /// RHS.
207 Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
208 std::tie(args&: LHS, args&: RHS) = splatScalarOperandIfNeeded(LHS, RHS);
209 if (LHS->getType()->getScalarType()->isFloatingPointTy())
210 return B.CreateFMul(L: LHS, R: RHS);
211 return B.CreateMul(LHS, RHS);
212 }
213
214 /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
215 /// IsUnsigned indicates whether UDiv or SDiv should be used.
216 Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
217 assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
218 assert(!isa<ScalableVectorType>(LHS->getType()) &&
219 "LHS Assumed to be fixed width");
220 RHS =
221 B.CreateVectorSplat(EC: cast<VectorType>(Val: LHS->getType())->getElementCount(),
222 V: RHS, Name: "scalar.splat");
223 return cast<VectorType>(Val: LHS->getType())
224 ->getElementType()
225 ->isFloatingPointTy()
226 ? B.CreateFDiv(L: LHS, R: RHS)
227 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
228 }
229
230 /// Create an assumption that \p Idx is less than \p NumElements.
231 void CreateIndexAssumption(Value *Idx, unsigned NumElements,
232 Twine const &Name = "") {
233 Value *NumElts =
234 B.getIntN(N: Idx->getType()->getScalarSizeInBits(), C: NumElements);
235 auto *Cmp = B.CreateICmpULT(LHS: Idx, RHS: NumElts);
236 if (isa<ConstantInt>(Val: Cmp))
237 assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
238 else
239 B.CreateAssumption(Cond: Cmp);
240 }
241
242 /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
243 /// a matrix with \p NumRows embedded in a vector.
244 Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
245 Twine const &Name = "") {
246 unsigned MaxWidth = std::max(a: RowIdx->getType()->getScalarSizeInBits(),
247 b: ColumnIdx->getType()->getScalarSizeInBits());
248 Type *IntTy = IntegerType::get(C&: RowIdx->getType()->getContext(), NumBits: MaxWidth);
249 RowIdx = B.CreateZExt(V: RowIdx, DestTy: IntTy);
250 ColumnIdx = B.CreateZExt(V: ColumnIdx, DestTy: IntTy);
251 Value *NumRowsV = B.getIntN(N: MaxWidth, C: NumRows);
252 return B.CreateAdd(LHS: B.CreateMul(LHS: ColumnIdx, RHS: NumRowsV), RHS: RowIdx);
253 }
254};
255
256} // end namespace llvm
257
258#endif // LLVM_IR_MATRIXBUILDER_H
259