1//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lower matrix intrinsics to vector operations.
10//
11// TODO:
12// * Improve fusion:
13// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14// transposed.
//  * Improve cost-modeling, e.g. choose different numbers of rows/columns
//    for tiles, consider cost of copies on alias.
17//
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21#include "llvm/ADT/PostOrderIterator.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/ScopeExit.h"
24#include "llvm/ADT/SmallSet.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/Statistic.h"
27#include "llvm/Analysis/AliasAnalysis.h"
28#include "llvm/Analysis/DomTreeUpdater.h"
29#include "llvm/Analysis/LoopInfo.h"
30#include "llvm/Analysis/OptimizationRemarkEmitter.h"
31#include "llvm/Analysis/TargetTransformInfo.h"
32#include "llvm/Analysis/ValueTracking.h"
33#include "llvm/Analysis/VectorUtils.h"
34#include "llvm/IR/CFG.h"
35#include "llvm/IR/DataLayout.h"
36#include "llvm/IR/DebugInfoMetadata.h"
37#include "llvm/IR/DerivedTypes.h"
38#include "llvm/IR/Function.h"
39#include "llvm/IR/IRBuilder.h"
40#include "llvm/IR/InstrTypes.h"
41#include "llvm/IR/Instructions.h"
42#include "llvm/IR/IntrinsicInst.h"
43#include "llvm/IR/MatrixBuilder.h"
44#include "llvm/IR/PatternMatch.h"
45#include "llvm/Support/Alignment.h"
46#include "llvm/Support/CommandLine.h"
47#include "llvm/Support/Compiler.h"
48#include "llvm/Support/Debug.h"
49#include "llvm/Transforms/Utils/BasicBlockUtils.h"
50#include "llvm/Transforms/Utils/LoopUtils.h"
51#include "llvm/Transforms/Utils/MatrixUtils.h"
52
53#include <cmath>
54
55using namespace llvm;
56using namespace PatternMatch;
57
58#define DEBUG_TYPE "lower-matrix-intrinsics"
59
60STATISTIC(FlattenedMatrices, "Number of matrix flattenings");
61STATISTIC(ReshapedMatrices, "Number of matrix reshapes");
62STATISTIC(SplitMatrices, "Number of matrix splits");
63
64static cl::opt<bool>
65 FuseMatrix("fuse-matrix", cl::init(Val: true), cl::Hidden,
66 cl::desc("Enable/disable fusing matrix instructions."));
67// TODO: Allow and use non-square tiles.
68static cl::opt<unsigned> TileSize(
69 "fuse-matrix-tile-size", cl::init(Val: 4), cl::Hidden,
70 cl::desc(
71 "Tile size for matrix instruction fusion using square-shaped tiles."));
72static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(Val: false),
73 cl::Hidden,
74 cl::desc("Generate loop nest for tiling."));
75static cl::opt<bool> ForceFusion(
76 "force-fuse-matrix", cl::init(Val: false), cl::Hidden,
77 cl::desc("Force matrix instruction fusion even if not profitable."));
78static cl::opt<bool> AllowContractEnabled(
79 "matrix-allow-contract", cl::init(Val: false), cl::Hidden,
80 cl::desc("Allow the use of FMAs if available and profitable. This may "
81 "result in different results, due to less rounding error."));
82
83static cl::opt<bool>
84 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
85 cl::desc("Enable/disable matrix shape verification."),
86 cl::init(Val: false));
87
88enum class MatrixLayoutTy { ColumnMajor, RowMajor };
89
90static cl::opt<MatrixLayoutTy> MatrixLayout(
91 "matrix-default-layout", cl::init(Val: MatrixLayoutTy::ColumnMajor),
92 cl::desc("Sets the default matrix layout"),
93 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
94 "Use column-major layout"),
95 clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
96 "Use row-major layout")));
97
98static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
99 cl::init(Val: false));
100
/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
103static DISubprogram *getSubprogram(DIScope *Scope) {
104 if (auto *Subprogram = dyn_cast<DISubprogram>(Val: Scope))
105 return Subprogram;
106 return cast<DILocalScope>(Val: Scope)->getSubprogram();
107}
108
109/// Return true if V is a splat of a value (which is used when multiplying a
110/// matrix with a scalar).
111static bool isSplat(Value *V) {
112 if (auto *SV = dyn_cast<ShuffleVectorInst>(Val: V))
113 return SV->isZeroEltSplat();
114 return false;
115}
116
117/// Match any mul operation (fp or integer).
118template <typename LTy, typename RTy>
119auto m_AnyMul(const LTy &L, const RTy &R) {
120 return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
121}
122
123/// Match any add operation (fp or integer).
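/// For example, match(V, m_AnyAdd(m_Value(A), m_Value(B))) matches both an
/// integer 'add' and a floating-point 'fadd' of two values.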
124template <typename LTy, typename RTy>
125auto m_AnyAdd(const LTy &L, const RTy &R) {
126 return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
127}
128
129namespace {
130
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//          v_0_0  v_0_1 {v_0_2 {v_0_3
//                 Base   Col 1  Col 2
//                   |      |      |
//          v_1_0  |v_1_1 |v_1_2 |v_1_3
//          v_2_0  |v_2_1 |v_2_2 |v_2_3
//          v_3_0  {v_3_1 {v_3_2  v_3_3
//
171Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
172 unsigned NumElements, Type *EltType,
173 IRBuilder<> &Builder) {
174
175 assert((!isa<ConstantInt>(Stride) ||
176 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
177 "Stride must be >= the number of elements in the result vector.");
178
179 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
180 Value *VecStart = Builder.CreateMul(LHS: VecIdx, RHS: Stride, Name: "vec.start");
181
182 // Get pointer to the start of the selected vector. Skip GEP creation,
183 // if we select vector 0.
184 if (isa<ConstantInt>(Val: VecStart) && cast<ConstantInt>(Val: VecStart)->isZero())
185 VecStart = BasePtr;
186 else
187 VecStart = Builder.CreateGEP(Ty: EltType, Ptr: BasePtr, IdxList: VecStart, Name: "vec.gep");
188
189 return VecStart;
190}
191
192namespace {
193struct ShapeInfo {
194 unsigned NumRows;
195 unsigned NumColumns;
196
197 bool IsColumnMajor;
198
199 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
200 : NumRows(NumRows), NumColumns(NumColumns),
201 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
202
203 ShapeInfo(Value *NumRows, Value *NumColumns)
204 : ShapeInfo(cast<ConstantInt>(Val: NumRows)->getZExtValue(),
205 cast<ConstantInt>(Val: NumColumns)->getZExtValue()) {}
206
207 bool operator==(const ShapeInfo &other) {
208 return NumRows == other.NumRows && NumColumns == other.NumColumns;
209 }
210 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
211
212 /// Returns true if shape-information is defined, meaning both dimensions
213 /// are != 0.
214 operator bool() const {
215 assert(NumRows == 0 || NumColumns != 0);
216 return NumRows != 0;
217 }
218
219 unsigned getStride() const {
220 if (IsColumnMajor)
221 return NumRows;
222 return NumColumns;
223 }
224
225 unsigned getNumVectors() const {
226 if (IsColumnMajor)
227 return NumColumns;
228 return NumRows;
229 }
230
231 /// Returns the transposed shape.
232 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
233
234 friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI);
235
236 LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; }
237};
238
239raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
240 return OS << SI.NumRows << 'x' << SI.NumColumns;
241}
242
243} // namespace
244
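/// Return true if shape information can be propagated through \p V unchanged,
/// i.e. \p V computes its result element-wise so the result has the same shape
/// as its (matrix) operands. Values that are not instructions are treated as
/// shape-preserving as well.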
245static bool isUniformShape(Value *V) {
246 Instruction *I = dyn_cast<Instruction>(Val: V);
247 if (!I)
248 return true;
249
250 if (I->isBinaryOp())
251 return true;
252
253 if (auto *Cast = dyn_cast<CastInst>(Val: V)) {
254 switch (Cast->getOpcode()) {
255 case llvm::Instruction::Trunc:
256 case llvm::Instruction::ZExt:
257 case llvm::Instruction::SExt:
258 case llvm::Instruction::FPToUI:
259 case llvm::Instruction::FPToSI:
260 case llvm::Instruction::UIToFP:
261 case llvm::Instruction::SIToFP:
262 case llvm::Instruction::FPTrunc:
263 case llvm::Instruction::FPExt:
264 return true;
265 case llvm::Instruction::AddrSpaceCast:
266 case CastInst::PtrToInt:
267 case CastInst::IntToPtr:
268 return false;
269 case CastInst::BitCast: {
270 if (auto *SrcVTy = dyn_cast<FixedVectorType>(Val: Cast->getSrcTy()))
271 if (auto *DestVTy = dyn_cast<FixedVectorType>(Val: Cast->getDestTy()))
272 return SrcVTy->getNumElements() == DestVTy->getNumElements();
273 return false;
274 }
275 case llvm::Instruction::CastOpsEnd:
276 llvm_unreachable("not an actual cast op");
277 }
278 llvm_unreachable("unhandled cast opcode");
279 }
280
281 if (auto *II = dyn_cast<IntrinsicInst>(Val: V))
282 switch (II->getIntrinsicID()) {
283 case Intrinsic::abs:
284 case Intrinsic::fabs:
285 return true;
286 default:
287 return false;
288 }
289
290 switch (I->getOpcode()) {
291 case Instruction::PHI:
292 case Instruction::FNeg:
293 return true;
294 default:
295 return false;
296 }
297}
298
/// Return the ShapeInfo for the result of \p I, if it can be determined.
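/// For example, a matrix_multiply call with dimensions M = 4, N = 3, K = 2
/// yields shape 4 x 2, and a matrix_transpose of a 4 x 3 matrix yields 3 x 4.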
300static std::optional<ShapeInfo>
301computeShapeInfoForInst(Instruction *I,
302 const DenseMap<Value *, ShapeInfo> &ShapeMap) {
303 Value *M;
304 Value *N;
305 Value *K;
306 if (match(V: I, P: m_Intrinsic<Intrinsic::matrix_multiply>(
307 Op0: m_Value(), Op1: m_Value(), Op2: m_Value(V&: M), Op3: m_Value(V&: N), Op4: m_Value(V&: K))))
308 return ShapeInfo(M, K);
309 if (match(V: I, P: m_Intrinsic<Intrinsic::matrix_transpose>(Op0: m_Value(), Op1: m_Value(V&: M),
310 Op2: m_Value(V&: N)))) {
311 // Flip dimensions.
312 return ShapeInfo(N, M);
313 }
314 if (match(V: I, P: m_Intrinsic<Intrinsic::matrix_column_major_store>(
315 Op0: m_Value(), Op1: m_Value(), Op2: m_Value(), Op3: m_Value(), Op4: m_Value(V&: M),
316 Op5: m_Value(V&: N))))
317 return ShapeInfo(N, M);
318 if (match(V: I, P: m_Intrinsic<Intrinsic::matrix_column_major_load>(
319 Op0: m_Value(), Op1: m_Value(), Op2: m_Value(), Op3: m_Value(V&: M), Op4: m_Value(V&: N))))
320 return ShapeInfo(M, N);
321 Value *MatrixA;
322 if (match(V: I, P: m_Store(ValueOp: m_Value(V&: MatrixA), PointerOp: m_Value()))) {
323 auto OpShape = ShapeMap.find(Val: MatrixA);
324 if (OpShape != ShapeMap.end())
325 return OpShape->second;
326 }
327
328 if (isUniformShape(V: I) || isa<SelectInst>(Val: I)) {
329 auto Ops = I->operands();
330 auto ShapedOps = isa<SelectInst>(Val: I) ? drop_begin(RangeOrContainer&: Ops) : Ops;
331 // Find the first operand that has a known shape and use that.
332 for (auto &Op : ShapedOps) {
333 auto OpShape = ShapeMap.find(Val: Op.get());
334 if (OpShape != ShapeMap.end())
335 return OpShape->second;
336 }
337 }
338 return std::nullopt;
339}
340
341/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
342///
343/// Currently, the lowering for each matrix intrinsic is done as follows:
344/// 1. Propagate the shape information from intrinsics to connected
345/// instructions.
346/// 2. Lower instructions with shape information (assuming column-major layout).
347/// The lowering works similarly using row-major layout.
348/// 2.1. Get column vectors for each argument. If we already lowered the
349/// definition of an argument, use the produced column vectors directly.
350/// If not, split the operand vector containing an embedded matrix into
/// a set of column vectors.
/// 2.2. Lower the instruction in terms of column major operations, which
/// yields a set of column vectors containing the result matrix. Note that we
354/// lower all instructions that have shape information. Besides the
355/// intrinsics, this includes stores for example.
356/// 2.3. Update uses of the lowered instruction. If we have shape information
357/// for a user, there is nothing to do, as we will look up the result
358/// column matrix when lowering the user. For other uses, we embed the
359/// result matrix in a flat vector and update the use.
360/// 2.4. Cache the result column matrix for the instruction we lowered
361/// 3. After we lowered all instructions in a function, remove the now
362/// obsolete instructions.
363///
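/// As an illustrative sketch (the IR below is hand-written, not taken from a
/// specific test): an element-wise add of two 4x2 column-major matrices
///
///   %c = fadd <8 x double> %a, %b   ; shape info: 4 x 2
///
/// is lowered by splitting %a and %b into two <4 x double> column vectors each
/// and emitting two <4 x double> fadds; the result is only re-embedded into a
/// flat <8 x double> vector for users without shape information.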
364class LowerMatrixIntrinsics {
365 Function &Func;
366 const DataLayout &DL;
367 const TargetTransformInfo &TTI;
368 FunctionAnalysisManager *AM;
369 AliasAnalysis *AA = nullptr;
370 DominatorTree *DT = nullptr;
371 LoopInfo *LI = nullptr;
372 OptimizationRemarkEmitter *ORE = nullptr;
373
  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
375 struct OpInfoTy {
376 /// Number of stores emitted to generate this matrix.
377 unsigned NumStores = 0;
378 /// Number of loads emitted to generate this matrix.
379 unsigned NumLoads = 0;
380 /// Number of compute operations emitted to generate this matrix.
381 unsigned NumComputeOps = 0;
382 /// Most of the time transposes can be fused with matrix multiplies or can
383 /// be folded away via algebraic simplifications. This is the number of
384 /// transposes that we failed to make "free" via such optimizations.
385 unsigned NumExposedTransposes = 0;
386
387 OpInfoTy &operator+=(const OpInfoTy &RHS) {
388 NumStores += RHS.NumStores;
389 NumLoads += RHS.NumLoads;
390 NumComputeOps += RHS.NumComputeOps;
391 NumExposedTransposes += RHS.NumExposedTransposes;
392 return *this;
393 }
394 };
395
396 /// Wrapper class representing a matrix as a set of vectors, either in row or
397 /// column major layout. All vectors must have the same vector type.
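  /// For example, a 4x2 column-major matrix of doubles is held as two
  /// <4 x double> vectors (one per column); in row-major layout it would be
  /// held as four <2 x double> vectors (one per row).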
398 class MatrixTy {
399 SmallVector<Value *, 16> Vectors;
400
401 OpInfoTy OpInfo;
402
403 bool IsColumnMajor = true;
404
405 public:
406 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
407 MatrixTy(ArrayRef<Value *> Vectors)
408 : Vectors(Vectors),
409 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
410 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
411 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
412
413 unsigned D = isColumnMajor() ? NumColumns : NumRows;
414 for (unsigned J = 0; J < D; ++J)
415 addVector(V: PoisonValue::get(T: FixedVectorType::get(
416 ElementType: EltTy, NumElts: isColumnMajor() ? NumRows : NumColumns)));
417 }
418
419 Value *getVector(unsigned i) const { return Vectors[i]; }
420 Value *getColumn(unsigned i) const {
421 assert(isColumnMajor() && "only supported for column-major matrixes");
422 return Vectors[i];
423 }
424 Value *getRow(unsigned i) const {
425 assert(!isColumnMajor() && "only supported for row-major matrixes");
426 return Vectors[i];
427 }
428
429 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
430
431 Type *getElementType() const { return getVectorTy()->getElementType(); }
432
433 unsigned getNumVectors() const {
434 if (isColumnMajor())
435 return getNumColumns();
436 return getNumRows();
437 }
438
439 unsigned getNumColumns() const {
440 if (isColumnMajor())
441 return Vectors.size();
442 else {
        assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
444 return getVectorTy()->getNumElements();
445 }
446 }
447 unsigned getNumRows() const {
448 if (isColumnMajor()) {
449 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
450 return getVectorTy()->getNumElements();
451 } else
452 return Vectors.size();
453 }
454
455 void addVector(Value *V) { Vectors.push_back(Elt: V); }
456 FixedVectorType *getColumnTy() {
457 assert(isColumnMajor() && "only supported for column-major matrixes");
458 return getVectorTy();
459 }
460
461 FixedVectorType *getVectorTy() const {
462 return cast<FixedVectorType>(Val: Vectors[0]->getType());
463 }
464
465 iterator_range<SmallVector<Value *, 8>::iterator> columns() {
466 assert(isColumnMajor() &&
467 "columns() only supported for column-major matrixes");
468 return make_range(x: Vectors.begin(), y: Vectors.end());
469 }
470
471 iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
472 return make_range(x: Vectors.begin(), y: Vectors.end());
473 }
474
475 /// Embed the vectors of the matrix into a flat vector by concatenating
476 /// them.
477 Value *embedInVector(IRBuilder<> &Builder) const {
478 return Vectors.size() == 1 ? Vectors[0]
479 : concatenateVectors(Builder, Vecs: Vectors);
480 }
481
482 MatrixTy &addNumLoads(unsigned N) {
483 OpInfo.NumLoads += N;
484 return *this;
485 }
486
487 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
488
489 MatrixTy &addNumStores(unsigned N) {
490 OpInfo.NumStores += N;
491 return *this;
492 }
493
494 MatrixTy &addNumExposedTransposes(unsigned N) {
495 OpInfo.NumExposedTransposes += N;
496 return *this;
497 }
498
499 MatrixTy &addNumComputeOps(unsigned N) {
500 OpInfo.NumComputeOps += N;
501 return *this;
502 }
503
504 unsigned getNumStores() const { return OpInfo.NumStores; }
505 unsigned getNumLoads() const { return OpInfo.NumLoads; }
506 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
507
508 const OpInfoTy &getOpInfo() const { return OpInfo; }
509
510 bool isColumnMajor() const { return IsColumnMajor; }
511
512 unsigned getStride() const {
513 if (isColumnMajor())
514 return getNumRows();
515 return getNumColumns();
516 }
517
518 ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; }
519
520 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
521 /// matrix is column-major, the result vector is extracted from a column
522 /// vector, otherwise from a row vector.
523 Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
524 IRBuilder<> &Builder) const {
525 Value *Vec = isColumnMajor() ? getColumn(i: J) : getRow(i: I);
526 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
527 NumElts &&
528 "Extracted vector will contain poison values");
529 return Builder.CreateShuffleVector(
530 V: Vec, Mask: createSequentialMask(Start: isColumnMajor() ? I : J, NumInts: NumElts, NumUndefs: 0),
531 Name: "block");
532 }
533 };
534
535 /// Maps instructions to their shape information. The shape information
536 /// describes the shape to be used while lowering. This matches the shape of
537 /// the result value of the instruction, with the only exceptions being store
538 /// instructions and the matrix_column_major_store intrinsics. For those, the
539 /// shape information indicates that those instructions should be lowered
540 /// using shape information as well. Note that extra care is needed when
541 /// erasing or RAUW'ing a value that is present in ShapeMap. If the
542 /// replacement is also a matrix operation, use
543 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
544 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
545 /// want to add shape information for a replacement instruction. When directly
546 /// erasing a value with an entry in ShapeMap, use
547 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
548 /// accordingly.
549 DenseMap<Value *, ShapeInfo> ShapeMap;
550
  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// users need to be removed after we have finished lowering.
554 SmallVector<Instruction *, 16> ToRemove;
555
556 /// Map from instructions to their produced column matrix.
557 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
558
559private:
560 static FastMathFlags getFastMathFlags(Instruction *Inst) {
561 FastMathFlags FMF;
562
563 if (isa<FPMathOperator>(Val: *Inst))
564 FMF = Inst->getFastMathFlags();
565
566 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
567
568 return FMF;
569 }
570
571public:
572 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
573 FunctionAnalysisManager *AM)
574 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
575
576 unsigned getNumOps(Type *VT) {
577 assert(isa<FixedVectorType>(VT) && "Expected vector type");
578 return getNumOps(ST: VT->getScalarType(),
579 N: cast<FixedVectorType>(Val: VT)->getNumElements());
580 }
581
582 /// Is this the minimal version executed in the backend pipelines.
583 bool isMinimal() const {
584 return !DT;
585 }
586
587 /// Return the estimated number of vector ops required for an operation on
588 /// \p VT * N.
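  /// For example (assuming a target with 256-bit vector registers), an
  /// operation on a <8 x float> counts as 1 vector op and an operation on a
  /// <16 x float> counts as 2.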
589 unsigned getNumOps(Type *ST, unsigned N) {
590 return std::ceil(x: (ST->getPrimitiveSizeInBits() * N).getFixedValue() /
591 double(TTI.getRegisterBitWidth(
592 K: TargetTransformInfo::RGK_FixedWidthVector)
593 .getFixedValue()));
594 }
595
596 /// Return the set of vectors that a matrix value is lowered to.
597 ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
601 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
602 IRBuilder<> &Builder) {
603 FixedVectorType *VType = cast<FixedVectorType>(Val: MatrixVal->getType());
604 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
605 "The vector size must match the number of matrix elements");
606
607 // Check if we lowered MatrixVal using shape information. In that case,
608 // return the existing matrix, if it matches the requested shape
609 // information. If there is a mis-match, embed the result in a flat
610 // vector and split it later.
611 auto Found = Inst2ColumnMatrix.find(Key: MatrixVal);
612 if (Found != Inst2ColumnMatrix.end()) {
613 MatrixTy &M = Found->second;
614 // Return the found matrix, if its shape matches the requested shape
615 // information
616 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
617 return M;
618
619 MatrixVal = M.embedInVector(Builder);
620 }
621
622 // Otherwise split MatrixVal.
623 SmallVector<Value *, 16> SplitVecs;
624 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
625 MaskStart += SI.getStride()) {
626 Value *V = Builder.CreateShuffleVector(
627 V: MatrixVal, Mask: createSequentialMask(Start: MaskStart, NumInts: SI.getStride(), NumUndefs: 0),
628 Name: "split");
629 SplitVecs.push_back(Elt: V);
630 }
631
632 if (Instruction *Inst = dyn_cast<Instruction>(Val: MatrixVal)) {
633 if (Found != Inst2ColumnMatrix.end()) {
634 // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
635 // that embedInVector created.
636 LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
637 << " to " << SI << " using at least "
638 << SplitVecs.size() << " shuffles on behalf of:\n"
639 << *Inst << '\n');
640 ReshapedMatrices++;
641 } else if (!ShapeMap.contains(Val: MatrixVal)) {
642 LLVM_DEBUG(
643 dbgs()
644 << "splitting a " << SI << " matrix with " << SplitVecs.size()
645 << " shuffles beacuse we do not have a shape-aware lowering for "
646 "its def:\n"
647 << *Inst << '\n');
648 (void)Inst;
649 SplitMatrices++;
650 } else {
651 // The ShapeMap has it, so it's a case where we're being lowered
652 // before the def, and we expect that InstCombine will clean things up
653 // afterward.
654 }
655 }
656
657 return {SplitVecs};
658 }
659
660 /// If \p V already has a known shape return false. Otherwise set the shape
661 /// for instructions that support it.
662 bool setShapeInfo(Value *V, ShapeInfo Shape) {
663 assert(Shape && "Shape not set");
664 if (isa<UndefValue>(Val: V) || !supportsShapeInfo(V))
665 return false;
666
667 auto SIter = ShapeMap.find(Val: V);
668 if (SIter != ShapeMap.end()) {
669 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
670 SIter->second.NumColumns != Shape.NumColumns)) {
671 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
672 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
673 << Shape.NumColumns << ") for " << *V << "\n";
674 report_fatal_error(
675 reason: "Matrix shape verification failed, compilation aborted!");
676 }
677
678 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
679 << SIter->second.NumRows << " "
680 << SIter->second.NumColumns << " for " << *V << "\n");
681 return false;
682 }
683
684 ShapeMap.insert(KV: {V, Shape});
685 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
686 << " for " << *V << "\n");
687 return true;
688 }
689
690 /// Returns true if shape information can be used for \p V. The supported
691 /// instructions must match the instructions that can be lowered by this pass.
692 bool supportsShapeInfo(Value *V) {
693 Instruction *Inst = dyn_cast<Instruction>(Val: V);
694 if (!Inst)
695 return false;
696
697 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: Inst);
698 if (II)
699 switch (II->getIntrinsicID()) {
700 case Intrinsic::matrix_multiply:
701 case Intrinsic::matrix_transpose:
702 case Intrinsic::matrix_column_major_load:
703 case Intrinsic::matrix_column_major_store:
704 return true;
705 default:
706 return isUniformShape(V: II);
707 }
708 return isUniformShape(V) || isa<StoreInst>(Val: V) || isa<LoadInst>(Val: V) ||
709 isa<SelectInst>(Val: V);
710 }
711
712 /// Propagate the shape information of instructions to their users.
713 /// The work list contains instructions for which we can compute the shape,
714 /// either based on the information provided by matrix intrinsics or known
715 /// shapes of operands.
716 SmallVector<Instruction *, 32>
717 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
718 SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
720 // operand shapes. Add the shape for this and then add users to the work
721 // list.
722 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
723 while (!WorkList.empty()) {
724 Instruction *Inst = WorkList.pop_back_val();
725
726 // New entry, set the value and insert operands
727 bool Propagate = false;
728 if (auto SI = computeShapeInfoForInst(I: Inst, ShapeMap))
729 Propagate = setShapeInfo(V: Inst, Shape: *SI);
730
731 if (Propagate) {
732 NewWorkList.push_back(Elt: Inst);
733 for (auto *User : Inst->users())
734 if (ShapeMap.count(Val: User) == 0)
735 WorkList.push_back(Elt: cast<Instruction>(Val: User));
736 }
737 }
738
739 return NewWorkList;
740 }
741
742 /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains instructions for which we already know the shape.
744 SmallVector<Instruction *, 32>
745 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
746 SmallVector<Instruction *, 32> NewWorkList;
747
748 auto pushInstruction = [](Value *V,
749 SmallVectorImpl<Instruction *> &WorkList) {
750 Instruction *I = dyn_cast<Instruction>(Val: V);
751 if (I)
752 WorkList.push_back(Elt: I);
753 };
    // Pop an element with a known shape. Traverse its operands; if an
    // operand's shape derives from the result shape and is unknown, set it and
    // add the operand to the worklist.
757 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
758 while (!WorkList.empty()) {
759 Value *V = WorkList.pop_back_val();
760
761 size_t BeforeProcessingV = WorkList.size();
762 if (!isa<Instruction>(Val: V))
763 continue;
764
765 Value *MatrixA;
766 Value *MatrixB;
767 Value *M;
768 Value *N;
769 Value *K;
770 if (match(V, P: m_Intrinsic<Intrinsic::matrix_multiply>(
771 Op0: m_Value(V&: MatrixA), Op1: m_Value(V&: MatrixB), Op2: m_Value(V&: M),
772 Op3: m_Value(V&: N), Op4: m_Value(V&: K)))) {
773 if (setShapeInfo(V: MatrixA, Shape: {M, N}))
774 pushInstruction(MatrixA, WorkList);
775
776 if (setShapeInfo(V: MatrixB, Shape: {N, K}))
777 pushInstruction(MatrixB, WorkList);
778
779 } else if (match(V, P: m_Intrinsic<Intrinsic::matrix_transpose>(
780 Op0: m_Value(V&: MatrixA), Op1: m_Value(V&: M), Op2: m_Value(V&: N)))) {
781 // Flip dimensions.
782 if (setShapeInfo(V: MatrixA, Shape: {M, N}))
783 pushInstruction(MatrixA, WorkList);
784 } else if (match(V, P: m_Intrinsic<Intrinsic::matrix_column_major_store>(
785 Op0: m_Value(V&: MatrixA), Op1: m_Value(), Op2: m_Value(), Op3: m_Value(),
786 Op4: m_Value(V&: M), Op5: m_Value(V&: N)))) {
787 if (setShapeInfo(V: MatrixA, Shape: {M, N})) {
788 pushInstruction(MatrixA, WorkList);
789 }
790 } else if (isa<LoadInst>(Val: V) ||
791 match(V, P: m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
792 // Nothing to do, no matrix input.
793 } else if (isa<StoreInst>(Val: V)) {
794 // Nothing to do. We forward-propagated to this so we would just
795 // backward propagate to an instruction with an already known shape.
796 } else if (isUniformShape(V) || isa<SelectInst>(Val: V)) {
797 auto Ops = cast<Instruction>(Val: V)->operands();
798 auto ShapedOps = isa<SelectInst>(Val: V) ? drop_begin(RangeOrContainer&: Ops) : Ops;
799 // Propagate to all operands.
800 ShapeInfo Shape = ShapeMap[V];
801 for (Use &U : ShapedOps) {
802 if (setShapeInfo(V: U.get(), Shape))
803 pushInstruction(U.get(), WorkList);
804 }
805 }
806 // After we discovered new shape info for new instructions in the
807 // worklist, we use their users as seeds for the next round of forward
808 // propagation.
809 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
810 for (User *U : WorkList[I]->users())
811 if (isa<Instruction>(Val: U) && V != U)
812 NewWorkList.push_back(Elt: cast<Instruction>(Val: U));
813 }
814 return NewWorkList;
815 }
816
817 /// (Op0 op Op1)^T -> Op0^T op Op1^T
818 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
819 /// them on both sides of \p Operation.
820 Instruction *distributeTransposes(
821 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
822 MatrixBuilder &Builder,
823 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
824 Operation) {
825 Value *T0 = Builder.CreateMatrixTranspose(
826 Matrix: Op0, Rows: Shape0.NumRows, Columns: Shape0.NumColumns, Name: Op0->getName() + "_t");
    // We are being run after shape prop; add shape for newly created
828 // instructions so that we lower them later.
829 setShapeInfo(V: T0, Shape: Shape0.t());
830 Value *T1 = Builder.CreateMatrixTranspose(
831 Matrix: Op1, Rows: Shape1.NumRows, Columns: Shape1.NumColumns, Name: Op1->getName() + "_t");
832 setShapeInfo(V: T1, Shape: Shape1.t());
833 return Operation(T0, Shape0.t(), T1, Shape1.t());
834 }
835
836 /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
837 /// itself.
838 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
839 ShapeMap.erase(Val: Inst);
840 Inst->eraseFromParent();
841 }
842
  /// Erase \p V from \p BB and move \p II forward to avoid invalidating
844 /// iterators.
845 void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
846 BasicBlock &BB) {
847 auto *Inst = cast<Instruction>(Val: V);
848 // Still used, don't erase.
849 if (!Inst->use_empty())
850 return;
851 if (II != BB.rend() && Inst == &*II)
852 ++II;
853 eraseFromParentAndRemoveFromShapeMap(Inst);
854 }
855
856 /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
857 /// entry for \p Old and replace all uses of \p Old with \p New.
858 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
859 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo, so we insert
861 // it conditionally instead.
862 auto S = ShapeMap.find(Val: &Old);
863 if (S != ShapeMap.end()) {
864 ShapeMap.erase(I: S);
865 if (supportsShapeInfo(V: New))
866 ShapeMap.insert(KV: {New, S->second});
867 }
868 Old.replaceAllUsesWith(V: New);
869 }
870
871 /// Sink a top-level transpose inside matmuls and adds.
872 /// This creates and erases instructions as needed, and returns the newly
873 /// created instruction while updating the iterator to avoid invalidation. If
874 /// this returns nullptr, no new instruction was created.
875 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II,
876 bool &Changed) {
877 BasicBlock &BB = *I.getParent();
878 IRBuilder<> IB(&I);
879 MatrixBuilder Builder(IB);
880
881 Value *TA, *TAMA, *TAMB;
882 ConstantInt *R, *K, *C;
883 if (!match(V: &I, P: m_Intrinsic<Intrinsic::matrix_transpose>(
884 Op0: m_Value(V&: TA), Op1: m_ConstantInt(CI&: R), Op2: m_ConstantInt(CI&: C))))
885 return nullptr;
886
887 // Transpose of a transpose is a nop when the shapes match.
888 Value *TATA;
889 if (match(V: TA, P: m_Intrinsic<Intrinsic::matrix_transpose>(
890 Op0: m_Value(V&: TATA), Op1: m_Specific(V: C), Op2: m_Specific(V: R)))) {
891 updateShapeAndReplaceAllUsesWith(Old&: I, New: TATA);
892 eraseFromParentAndMove(V: &I, II, BB);
893 eraseFromParentAndMove(V: TA, II, BB);
894 Changed = true;
895 return nullptr;
896 }
897
898 // k^T -> k
899 if (isSplat(V: TA)) {
900 updateShapeAndReplaceAllUsesWith(Old&: I, New: TA);
901 eraseFromParentAndMove(V: &I, II, BB);
902 Changed = true;
903 return nullptr;
904 }
905
906 // (A * B)^t -> B^t * A^t
907 // RxK KxC CxK KxR
908 if (match(V: TA, P: m_Intrinsic<Intrinsic::matrix_multiply>(
909 Op0: m_Value(V&: TAMA), Op1: m_Value(V&: TAMB), Op2: m_ConstantInt(CI&: R),
910 Op3: m_ConstantInt(CI&: K), Op4: m_ConstantInt(CI&: C)))) {
911 auto NewInst = distributeTransposes(
912 Op0: TAMB, Shape0: {K, C}, Op1: TAMA, Shape1: {R, K}, Builder,
913 Operation: [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
914 return Builder.CreateMatrixMultiply(LHS: T0, RHS: T1, LHSRows: Shape0.NumRows,
915 LHSColumns: Shape0.NumColumns,
916 RHSColumns: Shape1.NumColumns, Name: "mmul");
917 });
918 updateShapeAndReplaceAllUsesWith(Old&: I, New: NewInst);
919 eraseFromParentAndMove(V: &I, II, BB);
920 eraseFromParentAndMove(V: TA, II, BB);
921 Changed = true;
922 return NewInst;
923 }
924
925 // Same as above, but with a mul, which occurs when multiplied
926 // with a scalar.
927 // (A * k)^t -> A^t * k
928 // R x C RxC
929 if (match(V: TA, P: m_AnyMul(L: m_Value(V&: TAMA), R: m_Value(V&: TAMB))) &&
930 (isSplat(V: TAMA) || isSplat(V: TAMB))) {
931 IRBuilder<> LocalBuilder(&I);
932 // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
934 auto NewInst = distributeTransposes(
935 Op0: TAMA, Shape0: {R, C}, Op1: TAMB, Shape1: {R, C}, Builder,
936 Operation: [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
937 bool IsFP = I.getType()->isFPOrFPVectorTy();
938 auto *Mul = IsFP ? LocalBuilder.CreateFMul(L: T0, R: T1, Name: "mmul")
939 : LocalBuilder.CreateMul(LHS: T0, RHS: T1, Name: "mmul");
940 auto *Result = cast<Instruction>(Val: Mul);
941 setShapeInfo(V: Result, Shape: Shape0);
942 return Result;
943 });
944 updateShapeAndReplaceAllUsesWith(Old&: I, New: NewInst);
945 eraseFromParentAndMove(V: &I, II, BB);
946 eraseFromParentAndMove(V: TA, II, BB);
947 Changed = true;
948 return NewInst;
949 }
950
951 // (A + B)^t -> A^t + B^t
952 // RxC RxC CxR CxR
953 if (match(V: TA, P: m_AnyAdd(L: m_Value(V&: TAMA), R: m_Value(V&: TAMB)))) {
954 IRBuilder<> LocalBuilder(&I);
955 auto NewInst = distributeTransposes(
956 Op0: TAMA, Shape0: {R, C}, Op1: TAMB, Shape1: {R, C}, Builder,
957 Operation: [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
958 bool IsFP = I.getType()->isFPOrFPVectorTy();
959 auto *Add = IsFP ? LocalBuilder.CreateFAdd(L: T0, R: T1, Name: "madd")
960 : LocalBuilder.CreateAdd(LHS: T0, RHS: T1, Name: "madd");
961
962 auto *Result = cast<Instruction>(Val: Add);
963 setShapeInfo(V: Result, Shape: Shape0);
964 return Result;
965 });
966 updateShapeAndReplaceAllUsesWith(Old&: I, New: NewInst);
967 eraseFromParentAndMove(V: &I, II, BB);
968 eraseFromParentAndMove(V: TA, II, BB);
969 Changed = true;
970 return NewInst;
971 }
972
973 return nullptr;
974 }
975
976 bool liftTranspose(Instruction &I) {
977 // Erase dead Instructions after lifting transposes from binops.
978 auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
979 if (T.use_empty())
980 eraseFromParentAndRemoveFromShapeMap(Inst: &T);
981 if (A->use_empty())
982 eraseFromParentAndRemoveFromShapeMap(Inst: cast<Instruction>(Val: A));
983 if (A != B && B->use_empty())
984 eraseFromParentAndRemoveFromShapeMap(Inst: cast<Instruction>(Val: B));
985 };
986
987 Value *A, *B, *AT, *BT;
988 ConstantInt *R, *K, *C;
989 // A^t * B ^t -> (B * A)^t
990 if (match(V: &I, P: m_Intrinsic<Intrinsic::matrix_multiply>(
991 Op0: m_Value(V&: A), Op1: m_Value(V&: B), Op2: m_ConstantInt(CI&: R),
992 Op3: m_ConstantInt(CI&: K), Op4: m_ConstantInt(CI&: C))) &&
993 match(V: A, P: m_Intrinsic<Intrinsic::matrix_transpose>(Op0: m_Value(V&: AT))) &&
994 match(V: B, P: m_Intrinsic<Intrinsic::matrix_transpose>(Op0: m_Value(V&: (BT))))) {
995 IRBuilder<> IB(&I);
996 MatrixBuilder Builder(IB);
997 Value *M = Builder.CreateMatrixMultiply(
998 LHS: BT, RHS: AT, LHSRows: C->getZExtValue(), LHSColumns: K->getZExtValue(), RHSColumns: R->getZExtValue());
999 setShapeInfo(V: M, Shape: {C, R});
1000 Instruction *NewInst = Builder.CreateMatrixTranspose(Matrix: M, Rows: C->getZExtValue(),
1001 Columns: R->getZExtValue());
1002 updateShapeAndReplaceAllUsesWith(Old&: I, New: NewInst);
1003 CleanupBinOp(I, A, B);
1004 return true;
1005 }
1006 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
1007 // the shape of the second transpose is different, there's a shape conflict
1008 // which gets resolved by picking the shape of the first operand.
1009 else if (match(V: &I, P: m_FAdd(L: m_Value(V&: A), R: m_Value(V&: B))) &&
1010 match(V: A, P: m_Intrinsic<Intrinsic::matrix_transpose>(
1011 Op0: m_Value(V&: AT), Op1: m_ConstantInt(CI&: R), Op2: m_ConstantInt(CI&: C))) &&
1012 match(V: B, P: m_Intrinsic<Intrinsic::matrix_transpose>(
1013 Op0: m_Value(V&: BT), Op1: m_ConstantInt(), Op2: m_ConstantInt()))) {
1014 IRBuilder<> Builder(&I);
1015 auto *Add = Builder.CreateFAdd(L: AT, R: BT, Name: "mfadd");
1016 MatrixBuilder MBuilder(Builder);
1017 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
1018 Matrix: Add, Rows: R->getZExtValue(), Columns: C->getZExtValue(), Name: "mfadd_t");
1019 updateShapeAndReplaceAllUsesWith(Old&: I, New: NewInst);
1020 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
1021 computeShapeInfoForInst(&I, ShapeMap) &&
1022 "Shape of new instruction doesn't match original shape.");
1023 CleanupBinOp(I, A, B);
1024 if (auto *AddI = dyn_cast<Instruction>(Val: Add)) {
1025 setShapeInfo(V: AddI, Shape: {R, C});
1026 assert(
1027 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
1028 ShapeMap[AddI] &&
1029 "Shape of updated addition doesn't match cached shape.");
1030 }
1031 return true;
1032 }
1033 return false;
1034 }
1035
1036 /// Try moving transposes in order to fold them away or into multiplies.
1037 bool optimizeTransposes() {
1038 bool Changed = false;
1039 // First sink all transposes inside matmuls and adds, hoping that we end up
1040 // with NN, NT or TN variants.
1041 for (BasicBlock &BB : reverse(C&: Func)) {
1042 for (auto II = BB.rbegin(); II != BB.rend();) {
1043 Instruction &I = *II;
1044 // We may remove II. By default continue on the next/prev instruction.
1045 ++II;
1046 if (Instruction *NewInst = sinkTranspose(I, II, Changed))
1047 II = std::next(x: BasicBlock::reverse_iterator(NewInst));
1048 }
1049 }
1050
1051 // If we have a TT matmul or a TT add, lift the transpose. We may be able
1052 // to fold into consuming multiply or add.
1053 for (BasicBlock &BB : Func) {
1054 for (Instruction &I : llvm::make_early_inc_range(Range&: BB)) {
1055 Changed |= liftTranspose(I);
1056 }
1057 }
1058 return Changed;
1059 }
1060
1061 bool Visit() {
1062 SmallVector<Instruction *, 32> WorkList;
1063
1064 // Initially only the shape of matrix intrinsics is known.
1065 // Initialize the work list with ops carrying shape information.
1066 for (BasicBlock &BB : Func)
1067 for (Instruction &Inst : BB) {
1068 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: &Inst);
1069 if (!II)
1070 continue;
1071
1072 switch (II->getIntrinsicID()) {
1073 case Intrinsic::matrix_multiply:
1074 case Intrinsic::matrix_transpose:
1075 case Intrinsic::matrix_column_major_load:
1076 case Intrinsic::matrix_column_major_store:
1077 WorkList.push_back(Elt: &Inst);
1078 break;
1079 default:
1080 break;
1081 }
1082 }
1083
1084 // Avoid unnecessary work if there are no matrix intrinsics in the function.
1085 if (WorkList.empty())
1086 return false;
1087
1088 if (AM) {
1089 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(IR&: Func);
1090 AA = &AM->getResult<AAManager>(IR&: Func);
1091 DT = &AM->getResult<DominatorTreeAnalysis>(IR&: Func);
1092 LI = &AM->getResult<LoopAnalysis>(IR&: Func);
1093 }
1094
1095 // Propagate shapes until nothing changes any longer.
1096 while (!WorkList.empty()) {
1097 WorkList = propagateShapeForward(WorkList);
1098 WorkList = propagateShapeBackward(WorkList);
1099 }
1100
1101 bool Changed = false;
1102 if (!isMinimal()) {
1103 Changed |= optimizeTransposes();
1104 if (PrintAfterTransposeOpt) {
1105 dbgs() << "Dump after matrix transpose optimization:\n";
1106 Func.print(OS&: dbgs());
1107 }
1108 }
1109
1110 SmallVector<CallInst *, 16> MaybeFusableInsts;
1111 SmallVector<Instruction *, 16> MatrixInsts;
1112 SmallVector<IntrinsicInst *, 16> LifetimeEnds;
1113
1114 // First, collect all instructions with shape information and candidates for
1115 // fusion (currently only matrix multiplies).
1116 ReversePostOrderTraversal<Function *> RPOT(&Func);
1117 for (auto *BB : RPOT)
1118 for (Instruction &I : *BB) {
1119 if (match(V: &I, P: m_Intrinsic<Intrinsic::lifetime_end>()))
1120 LifetimeEnds.push_back(Elt: cast<IntrinsicInst>(Val: &I));
1121 if (!ShapeMap.contains(Val: &I))
1122 continue;
1123 if (match(V: &I, P: m_Intrinsic<Intrinsic::matrix_multiply>()))
1124 MaybeFusableInsts.push_back(Elt: cast<CallInst>(Val: &I));
1125 MatrixInsts.push_back(Elt: &I);
1126 }
1127
    // Second, try to lower any dot products.
1129 SmallPtrSet<Instruction *, 16> FusedInsts;
1130 for (CallInst *CI : MaybeFusableInsts)
1131 lowerDotProduct(MatMul: CI, FusedInsts, FMF: getFastMathFlags(Inst: CI));
1132
1133 // Third, try to fuse candidates.
1134 for (CallInst *CI : MaybeFusableInsts)
1135 if (!FusedInsts.contains(Ptr: CI))
1136 LowerMatrixMultiplyFused(MatMul: CI, FusedInsts, LifetimeEnds);
1137
1138 Changed |= !FusedInsts.empty();
1139
1140 // Fourth, pre-process all the PHINode's. The incoming values will be
1141 // assigned later in VisitPHI.
1142 for (Instruction *Inst : MatrixInsts) {
1143 if (FusedInsts.count(Ptr: Inst))
1144 continue;
1145
1146 auto *PHI = dyn_cast<PHINode>(Val: Inst);
1147 if (!PHI)
1148 continue;
1149
1150 const ShapeInfo &SI = ShapeMap.at(Val: Inst);
1151 auto *EltTy = cast<FixedVectorType>(Val: PHI->getType())->getElementType();
1152 MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);
1153
1154 IRBuilder<> Builder(Inst);
1155 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
1156 PhiM.setVector(i: VI, V: Builder.CreatePHI(Ty: PhiM.getVectorTy(),
1157 NumReservedValues: PHI->getNumIncomingValues(),
1158 Name: PHI->getName()));
1159 assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");
1160 Inst2ColumnMatrix[PHI] = PhiM;
1161 }
1162
1163 // Fifth, lower remaining instructions with shape information.
1164 for (Instruction *Inst : MatrixInsts) {
1165 if (FusedInsts.count(Ptr: Inst))
1166 continue;
1167
1168 const ShapeInfo &SI = ShapeMap.at(Val: Inst);
1169
1170 Value *Op1;
1171 Value *Op2;
1172 MatrixTy Result;
1173 IRBuilder<> Builder(Inst);
1174 if (auto *BinOp = dyn_cast<BinaryOperator>(Val: Inst))
1175 Result = VisitBinaryOperator(Inst: BinOp, SI, Builder);
1176 else if (auto *Cast = dyn_cast<CastInst>(Val: Inst))
1177 Result = VisitCastInstruction(Inst: Cast, Shape: SI, Builder);
1178 else if (auto *UnOp = dyn_cast<UnaryOperator>(Val: Inst))
1179 Result = VisitUnaryOperator(Inst: UnOp, SI, Builder);
1180 else if (auto *Intr = dyn_cast<IntrinsicInst>(Val: Inst))
1181 Result = VisitIntrinsicInst(Inst: Intr, SI, Builder);
1182 else if (auto *Select = dyn_cast<SelectInst>(Val: Inst))
1183 Result = VisitSelectInst(Inst: Select, Shape: SI, Builder);
1184 else if (match(V: Inst, P: m_Load(Op: m_Value(V&: Op1))))
1185 Result = VisitLoad(Inst: cast<LoadInst>(Val: Inst), SI, Ptr: Op1, Builder);
1186 else if (match(V: Inst, P: m_Store(ValueOp: m_Value(V&: Op1), PointerOp: m_Value(V&: Op2))))
1187 Result = VisitStore(Inst: cast<StoreInst>(Val: Inst), SI, StoredVal: Op1, Ptr: Op2, Builder);
1188 else if (auto *PHI = dyn_cast<PHINode>(Val: Inst))
1189 Result = VisitPHI(Inst: PHI, SI, Builder);
1190 else
1191 continue;
1192
1193 finalizeLowering(Inst, Matrix: Result, Builder);
1194 Changed = true;
1195 }
1196
1197 if (ORE) {
1198 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1199 RemarkGen.emitRemarks();
1200 }
1201
1202 // Delete the instructions backwards, as it has a reduced likelihood of
1203 // having to update as many def-use and use-def chains.
1204 //
1205 // Because we add to ToRemove during fusion we can't guarantee that defs
1206 // are before uses. Change uses to poison temporarily as these should get
1207 // removed as well.
1208 //
1209 // For verification, we keep track of where we changed uses to poison in
1210 // PoisonedInsts and then check that we in fact remove them.
1211 SmallSet<Instruction *, 16> PoisonedInsts;
1212 for (auto *Inst : reverse(C&: ToRemove)) {
1213 for (Use &U : llvm::make_early_inc_range(Range: Inst->uses())) {
1214 if (auto *Poisoned = dyn_cast<Instruction>(Val: U.getUser()))
1215 PoisonedInsts.insert(Ptr: Poisoned);
1216 U.set(PoisonValue::get(T: Inst->getType()));
1217 }
1218 Inst->eraseFromParent();
1219 PoisonedInsts.erase(Ptr: Inst);
1220 }
1221 if (!PoisonedInsts.empty()) {
1222 // If we didn't remove all poisoned instructions, it's a hard error.
1223 dbgs() << "Poisoned but present instructions:\n";
1224 for (auto *I : PoisonedInsts)
1225 dbgs() << *I << "\n";
1226 llvm_unreachable("Poisoned but instruction not removed");
1227 }
1228
1229 return Changed;
1230 }
1231
1232 /// Replace intrinsic calls.
1233 MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
1234 IRBuilder<> &Builder) {
1235 assert(Inst->getCalledFunction() &&
1236 Inst->getCalledFunction()->isIntrinsic());
1237
1238 switch (Inst->getCalledFunction()->getIntrinsicID()) {
1239 case Intrinsic::matrix_multiply:
1240 return LowerMultiply(MatMul: Inst, Builder);
1241 case Intrinsic::matrix_transpose:
1242 return LowerTranspose(Inst, Builder);
1243 case Intrinsic::matrix_column_major_load:
1244 return LowerColumnMajorLoad(Inst, Builder);
1245 case Intrinsic::matrix_column_major_store:
1246 return LowerColumnMajorStore(Inst, Builder);
1247 case Intrinsic::abs:
1248 case Intrinsic::fabs: {
1249 MatrixTy Result;
1250 MatrixTy M = getMatrix(MatrixVal: Inst->getOperand(i_nocapture: 0), SI, Builder);
1251 Builder.setFastMathFlags(getFastMathFlags(Inst));
1252
1253 for (auto *Vector : M.vectors()) {
1254 switch (Inst->getIntrinsicID()) {
1255 case Intrinsic::abs:
1256 Result.addVector(V: Builder.CreateBinaryIntrinsic(ID: Intrinsic::abs, LHS: Vector,
1257 RHS: Inst->getOperand(i_nocapture: 1)));
1258 continue;
1259 case Intrinsic::fabs:
1260 Result.addVector(
1261 V: Builder.CreateUnaryIntrinsic(ID: Inst->getIntrinsicID(), V: Vector));
1262 continue;
1263 default:
1264 llvm_unreachable("unexpected intrinsic");
1265 }
1266 }
1267
1268 return Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
1269 Result.getNumVectors());
1270 }
1271 default:
1272 break;
1273 }
1274 llvm_unreachable(
1275 "only intrinsics supporting shape info should be seen here");
1276 }
1277
1278 /// Compute the alignment for a column/row \p Idx with \p Stride between them.
1279 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
1280 /// ConstantInt, reduce the initial alignment based on the byte offset. For
1281 /// non-ConstantInt strides, return the common alignment of the initial
1282 /// alignment and the element size in bytes.
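  /// For example (illustrative values): with an initial alignment of 16,
  /// double elements and a constant stride of 3 elements, the vector at
  /// \p Idx == 1 starts at byte offset 24, so the returned alignment is
  /// commonAlignment(16, 24) = 8.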
1283 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
1284 MaybeAlign A) const {
1285 Align InitialAlign = DL.getValueOrABITypeAlignment(Alignment: A, Ty: ElementTy);
1286 if (Idx == 0)
1287 return InitialAlign;
1288
1289 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(Ty: ElementTy);
1290 if (auto *ConstStride = dyn_cast<ConstantInt>(Val: Stride)) {
1291 uint64_t StrideInBytes =
1292 ConstStride->getZExtValue() * ElementSizeInBits / 8;
1293 return commonAlignment(A: InitialAlign, Offset: Idx * StrideInBytes);
1294 }
1295 return commonAlignment(A: InitialAlign, Offset: ElementSizeInBits / 8);
1296 }
1297
1298 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1299 /// vectors.
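  /// For example (illustrative shapes), loading a 3x2 column-major matrix of
  /// floats with a stride of 5 elements emits two <3 x float> loads, at
  /// element offsets 0 and 5 from \p Ptr.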
1300 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1301 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1302 auto *VType = cast<FixedVectorType>(Val: Ty);
1303 Type *EltTy = VType->getElementType();
1304 Type *VecTy = FixedVectorType::get(ElementType: EltTy, NumElts: Shape.getStride());
1305 Value *EltPtr = Ptr;
1306 MatrixTy Result;
1307 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1308 Value *GEP = computeVectorAddr(
1309 BasePtr: EltPtr, VecIdx: Builder.getIntN(N: Stride->getType()->getScalarSizeInBits(), C: I),
1310 Stride, NumElements: Shape.getStride(), EltType: EltTy, Builder);
1311 Value *Vector = Builder.CreateAlignedLoad(
1312 Ty: VecTy, Ptr: GEP, Align: getAlignForIndex(Idx: I, Stride, ElementTy: EltTy, A: MAlign),
1313 isVolatile: IsVolatile, Name: "col.load");
1314
1315 Result.addVector(V: Vector);
1316 }
1317 return Result.addNumLoads(N: getNumOps(VT: Result.getVectorTy()) *
1318 Result.getNumVectors());
1319 }
1320
1321 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1322 /// starting at \p MatrixPtr[I][J].
1323 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
1324 ShapeInfo MatrixShape, Value *I, Value *J,
1325 ShapeInfo ResultShape, Type *EltTy,
1326 IRBuilder<> &Builder) {
1327 Value *Offset = Builder.CreateAdd(
1328 LHS: Builder.CreateMul(LHS: J, RHS: Builder.getInt64(C: MatrixShape.getStride())), RHS: I);
1329
1330 Value *TileStart = Builder.CreateGEP(Ty: EltTy, Ptr: MatrixPtr, IdxList: Offset);
1331 auto *TileTy = FixedVectorType::get(ElementType: EltTy, NumElts: ResultShape.NumRows *
1332 ResultShape.NumColumns);
1333
1334 return loadMatrix(Ty: TileTy, Ptr: TileStart, MAlign: Align,
1335 Stride: Builder.getInt64(C: MatrixShape.getStride()), IsVolatile,
1336 Shape: ResultShape, Builder);
1337 }
1338
1339 /// Lower a load instruction with shape information.
1340 MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
1341 Value *Stride, bool IsVolatile, ShapeInfo Shape,
1342 IRBuilder<> &Builder) {
1343 return loadMatrix(Ty: Inst->getType(), Ptr, MAlign: Align, Stride, IsVolatile, Shape,
1344 Builder);
1345 }
1346
1347 /// Lowers llvm.matrix.column.major.load.
1348 ///
1349 /// The intrinsic loads a matrix from memory using a stride between columns.
1350 MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) {
1351 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1352 "Intrinsic only supports column-major layout!");
1353 Value *Ptr = Inst->getArgOperand(i: 0);
1354 Value *Stride = Inst->getArgOperand(i: 1);
1355 return LowerLoad(Inst, Ptr, Align: Inst->getParamAlign(ArgNo: 0), Stride,
1356 IsVolatile: cast<ConstantInt>(Val: Inst->getArgOperand(i: 2))->isOne(),
1357 Shape: {Inst->getArgOperand(i: 3), Inst->getArgOperand(i: 4)}, Builder);
1358 }
1359
1360 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1361 /// MatrixPtr[I][J].
1362 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1363 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1364 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1365 Value *Offset = Builder.CreateAdd(
1366 LHS: Builder.CreateMul(LHS: J, RHS: Builder.getInt64(C: MatrixShape.getStride())), RHS: I);
1367
1368 Value *TileStart = Builder.CreateGEP(Ty: EltTy, Ptr: MatrixPtr, IdxList: Offset);
1369 auto *TileTy = FixedVectorType::get(ElementType: EltTy, NumElts: StoreVal.getNumRows() *
1370 StoreVal.getNumColumns());
1371
1372 storeMatrix(Ty: TileTy, StoreVal, Ptr: TileStart, MAlign,
1373 Stride: Builder.getInt64(C: MatrixShape.getStride()), IsVolatile, Builder);
1374 }
1375
1376 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1377 /// vectors.
1378 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1379 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1380 IRBuilder<> &Builder) {
1381 auto *VType = cast<FixedVectorType>(Val: Ty);
1382 Value *EltPtr = Ptr;
1383 for (auto Vec : enumerate(First: StoreVal.vectors())) {
1384 Value *GEP = computeVectorAddr(
1385 BasePtr: EltPtr,
1386 VecIdx: Builder.getIntN(N: Stride->getType()->getScalarSizeInBits(),
1387 C: Vec.index()),
1388 Stride, NumElements: StoreVal.getStride(), EltType: VType->getElementType(), Builder);
1389 Builder.CreateAlignedStore(Val: Vec.value(), Ptr: GEP,
1390 Align: getAlignForIndex(Idx: Vec.index(), Stride,
1391 ElementTy: VType->getElementType(),
1392 A: MAlign),
1393 isVolatile: IsVolatile);
1394 }
1395 return MatrixTy().addNumStores(N: getNumOps(VT: StoreVal.getVectorTy()) *
1396 StoreVal.getNumVectors());
1397 }
1398
1399 /// Lower a store instruction with shape information.
1400 MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
1401 MaybeAlign A, Value *Stride, bool IsVolatile,
1402 ShapeInfo Shape, IRBuilder<> &Builder) {
1403 auto StoreVal = getMatrix(MatrixVal: Matrix, SI: Shape, Builder);
1404 return storeMatrix(Ty: Matrix->getType(), StoreVal, Ptr, MAlign: A, Stride, IsVolatile,
1405 Builder);
1406 }
1407
1408 /// Lowers llvm.matrix.column.major.store.
1409 ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
1411 MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) {
1412 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1413 "Intrinsic only supports column-major layout!");
1414 Value *Matrix = Inst->getArgOperand(i: 0);
1415 Value *Ptr = Inst->getArgOperand(i: 1);
1416 Value *Stride = Inst->getArgOperand(i: 2);
1417 return LowerStore(Inst, Matrix, Ptr, A: Inst->getParamAlign(ArgNo: 1), Stride,
1418 IsVolatile: cast<ConstantInt>(Val: Inst->getArgOperand(i: 3))->isOne(),
1419 Shape: {Inst->getArgOperand(i: 4), Inst->getArgOperand(i: 5)},
1420 Builder);
1421 }
1422
1423 // Set elements I..I+NumElts-1 to Block
1424 Value *insertVector(Value *Col, unsigned I, Value *Block,
1425 IRBuilder<> &Builder) {
1426
1427 // First, bring Block to the same size as Col
1428 unsigned BlockNumElts =
1429 cast<FixedVectorType>(Val: Block->getType())->getNumElements();
1430 unsigned NumElts = cast<FixedVectorType>(Val: Col->getType())->getNumElements();
1431 assert(NumElts >= BlockNumElts && "Too few elements for current block");
1432
1433 Block = Builder.CreateShuffleVector(
1434 V: Block, Mask: createSequentialMask(Start: 0, NumInts: BlockNumElts, NumUndefs: NumElts - BlockNumElts));
1435
1436 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1437 // 8, 4, 5, 6
1438 SmallVector<int, 16> Mask;
1439 unsigned i;
1440 for (i = 0; i < I; i++)
1441 Mask.push_back(Elt: i);
1442
1443 unsigned VecNumElts =
1444 cast<FixedVectorType>(Val: Col->getType())->getNumElements();
1445 for (; i < I + BlockNumElts; i++)
1446 Mask.push_back(Elt: i - I + VecNumElts);
1447
1448 for (; i < VecNumElts; i++)
1449 Mask.push_back(Elt: i);
1450
1451 return Builder.CreateShuffleVector(V1: Col, V2: Block, Mask);
1452 }
1453
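/// Emit A * B if \p Sum is null, and Sum + A * B otherwise, using integer or
/// floating-point operations depending on \p UseFPOp. If contraction is
/// allowed, llvm.fmuladd is emitted and the backend decides whether to use an
/// FMA. \p NumComputeOps is updated with the number of emitted operations.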
1454 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1455 IRBuilder<> &Builder, bool AllowContraction,
1456 unsigned &NumComputeOps) {
1457 NumComputeOps += getNumOps(VT: A->getType());
1458 if (!Sum)
1459 return UseFPOp ? Builder.CreateFMul(L: A, R: B) : Builder.CreateMul(LHS: A, RHS: B);
1460
1461 if (UseFPOp) {
1462 if (AllowContraction) {
1463 // Use fmuladd for floating point operations and let the backend decide
1464 // if that's profitable.
1465 return Builder.CreateIntrinsic(ID: Intrinsic::fmuladd, Types: A->getType(),
1466 Args: {A, B, Sum});
1467 }
1468 NumComputeOps += getNumOps(VT: A->getType());
1469 Value *Mul = Builder.CreateFMul(L: A, R: B);
1470 return Builder.CreateFAdd(L: Sum, R: Mul);
1471 }
1472
1473 NumComputeOps += getNumOps(VT: A->getType());
1474 Value *Mul = Builder.CreateMul(LHS: A, RHS: B);
1475 return Builder.CreateAdd(LHS: Sum, RHS: Mul);
1476 }
1477
1478 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1479 /// users with shape information, there's nothing to do: they will use the
1480 /// cached value when they are lowered. For other users, \p Matrix is
1481 /// flattened and the uses are updated to use it. Also marks \p Inst for
1482 /// deletion.
1483 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1484 IRBuilder<> &Builder) {
1485 auto inserted = Inst2ColumnMatrix.insert(KV: std::make_pair(x&: Inst, y&: Matrix));
1486 (void)inserted;
1487 assert((inserted.second || isa<PHINode>(Inst)) &&
1488 "multiple matrix lowering mapping");
1489
1490 ToRemove.push_back(Elt: Inst);
1491 Value *Flattened = nullptr;
1492 for (Use &U : llvm::make_early_inc_range(Range: Inst->uses())) {
1493 if (ShapeMap.contains(Val: U.getUser()))
1494 continue;
1495
1496 if (!Flattened) {
1497 Flattened = Matrix.embedInVector(Builder);
1498 LLVM_DEBUG(
1499 if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs()
1500 << "flattening a " << Matrix.shape() << " matrix:\n"
1501 << *Inst
1502 << "\nbecause we do not have a shape-aware lowering for its "
1503 "user:\n"
1504 << *User << '\n';);
1505 FlattenedMatrices++;
1506 }
1507 U.set(Flattened);
1508 }
1509 }
1510
1511 /// Special case for MatMul lowering. Prevents scalar loads of row-major
1512 /// vectors. Lowers to a vector reduction add instead of sequential adds if
1513 /// reassociation is enabled.
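/// For illustration, a (1 x N) * (N x 1) multiply of floating-point operands
/// is lowered roughly to an fmul of the two row/column vectors followed by
/// llvm.vector.reduce.fadd (mul + llvm.vector.reduce.add for integers), with
/// the scalar result inserted back into a 1x1 matrix value.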
1514 void lowerDotProduct(CallInst *MatMul,
1515 SmallPtrSet<Instruction *, 16> &FusedInsts,
1516 FastMathFlags FMF) {
1517 if (FusedInsts.contains(Ptr: MatMul) ||
1518 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1519 return;
1520 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1521 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1522
1523 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1524 return;
1525
1526 Value *LHS = MatMul->getArgOperand(i: 0);
1527 Value *RHS = MatMul->getArgOperand(i: 1);
1528
1529 Type *ElementType = cast<FixedVectorType>(Val: LHS->getType())->getElementType();
1530 bool IsIntVec = ElementType->isIntegerTy();
1531
1532 // Floating point reductions require reassociation.
1533 if (!IsIntVec && !FMF.allowReassoc())
1534 return;
1535
1536 auto CanBeFlattened = [](Value *Op) {
1537 if (match(V: Op, P: m_BinOp()))
1538 return true;
1539 return match(
1540 V: Op, P: m_OneUse(SubPattern: m_CombineOr(
1541 L: m_Load(Op: m_Value()),
1542 R: m_CombineOr(L: m_Intrinsic<Intrinsic::matrix_transpose>(),
1543 R: m_Intrinsic<Intrinsic::matrix_column_major_load>(
1544 Op0: m_Value(), Op1: m_SpecificInt(V: 1))))));
1545 };
1546 // Returns the cost benefit of using \p Op with the dot product lowering. If
1547 // the returned cost is < 0, the argument is cheaper to use in the
1548 // dot-product lowering.
1549 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1550 if (!ShapeMap.contains(Val: Op))
1551 return InstructionCost::getInvalid();
1552
1553 if (!isa<Instruction>(Val: Op))
1554 return InstructionCost(0);
1555
1556 FixedVectorType *VecTy = cast<FixedVectorType>(Val: Op->getType());
1557 Type *EltTy = VecTy->getElementType();
1558
1559 if (!CanBeFlattened(Op)) {
1560 InstructionCost EmbedCost(0);
1561 // Roughly estimate the cost for embedding the columns into a vector.
1562 for (unsigned I = 1; I < N; ++I)
1563 EmbedCost += TTI.getShuffleCost(
1564 Kind: TTI::SK_Splice, DstTy: FixedVectorType::get(ElementType: EltTy, NumElts: 1),
1565 SrcTy: FixedVectorType::get(ElementType: EltTy, NumElts: 1), Mask: {}, CostKind: TTI::TCK_RecipThroughput);
1566 return EmbedCost;
1567 }
1568
1569 if (match(V: Op, P: m_BinOp()) && ShapeMap.contains(Val: Op)) {
1570 InstructionCost OriginalCost =
1571 TTI.getArithmeticInstrCost(Opcode: cast<Instruction>(Val: Op)->getOpcode(),
1572 Ty: EltTy) *
1573 N;
1574 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1575 Opcode: cast<Instruction>(Val: Op)->getOpcode(), Ty: VecTy);
1576 return NewCost - OriginalCost;
1577 }
1578
1579 if (match(V: Op, P: m_Intrinsic<Intrinsic::matrix_transpose>())) {
1580 // The transpose can be skipped for the dot product lowering, roughly
1581 // estimate the savings as the cost of embedding the columns in a
1582 // vector.
1583 InstructionCost EmbedCost(0);
1584 for (unsigned I = 1; I < N; ++I)
1585 EmbedCost -= TTI.getShuffleCost(
1586 Kind: TTI::SK_Splice, DstTy: FixedVectorType::get(ElementType: EltTy, NumElts: 1),
1587 SrcTy: FixedVectorType::get(ElementType: EltTy, NumElts: 1), Mask: {}, CostKind: TTI::TCK_RecipThroughput);
1588 return EmbedCost;
1589 }
1590
1591 // Costs for loads.
1592 if (N == 1)
1593 return InstructionCost(0);
1594
1595 return TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: Align(1), AddressSpace: 0) -
1596 N * TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: EltTy, Alignment: Align(1), AddressSpace: 0);
1597 };
1598
1599 // Iterate over LHS and operations feeding LHS and check if it is profitable
1600 // to flatten the visited ops. For each op, we compute the difference
1601 // between the flattened and matrix versions.
1602 SmallPtrSet<Value *, 4> Seen;
1603 SmallVector<Value *> WorkList;
1604 SmallVector<Value *> ToFlatten;
1605 WorkList.push_back(Elt: LHS);
1606 InstructionCost LHSCost(0);
1607 while (!WorkList.empty()) {
1608 Value *Op = WorkList.pop_back_val();
1609 if (!Seen.insert(Ptr: Op).second)
1610 continue;
1611
1612 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1613 if (OpCost + LHSCost >= LHSCost)
1614 continue;
1615
1616 LHSCost += OpCost;
1617 ToFlatten.push_back(Elt: Op);
1618 if (auto *I = dyn_cast<Instruction>(Val: Op))
1619 WorkList.append(in_start: I->op_begin(), in_end: I->op_end());
1620 }
1621
1622 // We compare the costs of a vector.reduce.add to sequential add.
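// That is, ReductionCost = reduce-add over the whole vector plus one vector
// multiply, while SequentialAddCost = (N - 1) scalar adds + N scalar
// multiplies (N = LShape.NumColumns), matching the computation below.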
1623 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1624 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1625 InstructionCost ReductionCost =
1626 TTI.getArithmeticReductionCost(
1627 Opcode: AddOpCode, Ty: cast<FixedVectorType>(Val: LHS->getType()),
1628 FMF: IsIntVec ? std::nullopt : std::optional(FMF)) +
1629 TTI.getArithmeticInstrCost(Opcode: MulOpCode, Ty: LHS->getType());
1630 InstructionCost SequentialAddCost =
1631 TTI.getArithmeticInstrCost(Opcode: AddOpCode, Ty: ElementType) *
1632 (LShape.NumColumns - 1) +
1633 TTI.getArithmeticInstrCost(Opcode: MulOpCode, Ty: ElementType) *
1634 (LShape.NumColumns);
1635 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1636 return;
1637
1638 FusedInsts.insert(Ptr: MatMul);
1639 IRBuilder<> Builder(MatMul);
1640 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1641 this](Value *Op) {
1642 // Matmul must be the only user of loads because we don't use LowerLoad
1643 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1644 // instead of single vector load).
1645 if (!CanBeFlattened(Op))
1646 return;
1647
1648 if (match(V: Op, P: m_BinOp())) {
1649 auto It = ShapeMap.find(Val: Op);
1650 if (It != ShapeMap.end()) {
1651 It->second = It->second.t();
1652 return;
1653 }
1654 }
1655
1656 FusedInsts.insert(Ptr: cast<Instruction>(Val: Op));
1657 // If the operand is a column-major load with unit stride, lower it to a plain LoadInst.
1658 Value *Arg;
1659 if (match(V: Op, P: m_Intrinsic<Intrinsic::matrix_column_major_load>(
1660 Op0: m_Value(V&: Arg)))) {
1661 auto *NewLoad = Builder.CreateLoad(Ty: Op->getType(), Ptr: Arg);
1662 Op->replaceAllUsesWith(V: NewLoad);
1663 eraseFromParentAndRemoveFromShapeMap(Inst: cast<Instruction>(Val: Op));
1664 return;
1665 } else if (match(V: Op, P: m_Intrinsic<Intrinsic::matrix_transpose>(
1666 Op0: m_Value(V&: Arg)))) {
1667 ToRemove.push_back(Elt: cast<Instruction>(Val: Op));
1668 Op->replaceAllUsesWith(V: Arg);
1669 return;
1670 }
1671 };
1672
1673 for (auto *V : ToFlatten)
1674 FlattenArg(V);
1675
1676 LHS = MatMul->getArgOperand(i: 0);
1677
1678 // Insert mul/fmul and llvm.vector.reduce.fadd
1679 Value *Mul =
1680 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(L: LHS, R: RHS);
1681
1682 Value *Result;
1683 if (IsIntVec)
1684 Result = Builder.CreateAddReduce(Src: Mul);
1685 else {
1686 Result = Builder.CreateFAddReduce(
1687 Acc: ConstantFP::get(
1688 Ty: cast<FixedVectorType>(Val: LHS->getType())->getElementType(), V: 0.0),
1689 Src: Mul);
1690 cast<Instruction>(Val: Result)->setFastMathFlags(FMF);
1691 }
1692
1693 // Pack the scalar result back into a 1x1 matrix and replace the matmul.
1694 Result = Builder.CreateInsertElement(Vec: PoisonValue::get(T: MatMul->getType()),
1695 NewElt: Result, Idx: uint64_t(0));
1696 MatMul->replaceAllUsesWith(V: Result);
1697 FusedInsts.insert(Ptr: MatMul);
1698 ToRemove.push_back(Elt: MatMul);
1699 }
1700
1701 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1702 /// addition.
1703 ///
1704 /// We can fold a transpose into the operand that is used to extract scalars.
1705 /// This is the first operand with row-major and the second with
1706 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1707 /// operand is transposed.
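///
/// For column-major operands this emits, for each result column J and each
/// K, roughly Result[:, J] += A[:, K] * splat(B[K, J]), vectorized over the
/// rows in chunks of at most VF elements (and symmetrically over rows for
/// row-major operands).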
1708 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1709 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1710 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1711 const unsigned VF = std::max<unsigned>(
1712 a: TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector)
1713 .getFixedValue() /
1714 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1715 b: 1U);
1716 unsigned R = Result.getNumRows();
1717 unsigned C = Result.getNumColumns();
1718 unsigned M = A.getNumColumns();
1719
1720 bool IsFP = Result.getElementType()->isFloatingPointTy();
1721 assert(A.isColumnMajor() == B.isColumnMajor() &&
1722 Result.isColumnMajor() == A.isColumnMajor() &&
1723 "operands must agree on matrix layout");
1724 unsigned NumComputeOps = 0;
1725
1726 Builder.setFastMathFlags(FMF);
1727
1728 if (A.isColumnMajor()) {
1729 // Multiply columns from the first operand with scalars from the second
1730 // operand. Then move along the K axes and accumulate the columns. With
1731 // this the adds can be vectorized without reassociation.
1732 for (unsigned J = 0; J < C; ++J) {
1733 unsigned BlockSize = VF;
1734 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1735 bool isSumZero = isa<ConstantAggregateZero>(Val: Result.getColumn(i: J));
1736
1737 for (unsigned I = 0; I < R; I += BlockSize) {
1738 // Gradually lower the vectorization factor to cover the remainder.
1739 while (I + BlockSize > R)
1740 BlockSize /= 2;
1741
1742 Value *Sum = IsTiled ? Result.extractVector(I, J, NumElts: BlockSize, Builder)
1743 : nullptr;
1744 for (unsigned K = 0; K < M; ++K) {
1745 Value *L = A.extractVector(I, J: K, NumElts: BlockSize, Builder);
1746 Value *RH = Builder.CreateExtractElement(
1747 Vec: B.getColumn(i: IsScalarMatrixTransposed ? K : J),
1748 Idx: IsScalarMatrixTransposed ? J : K);
1749 Value *Splat = Builder.CreateVectorSplat(NumElts: BlockSize, V: RH, Name: "splat");
1750 Sum =
1751 createMulAdd(Sum: isSumZero && K == 0 ? nullptr : Sum, A: L, B: Splat,
1752 UseFPOp: IsFP, Builder, AllowContraction: FMF.allowContract(), NumComputeOps);
1753 }
1754 Result.setVector(i: J,
1755 V: insertVector(Col: Result.getVector(i: J), I, Block: Sum, Builder));
1756 }
1757 }
1758 } else {
1759 // Multiply rows from the second operand with scalars from the first
1760 // operand. Then move along the K axes and accumulate the rows. With this
1761 // the adds can be vectorized without reassociation.
1762 for (unsigned I = 0; I < R; ++I) {
1763 unsigned BlockSize = VF;
1764 bool isSumZero = isa<ConstantAggregateZero>(Val: Result.getRow(i: I));
1765 for (unsigned J = 0; J < C; J += BlockSize) {
1766 // Gradually lower the vectorization factor to cover the remainder.
1767 while (J + BlockSize > C)
1768 BlockSize /= 2;
1769
1770 Value *Sum = nullptr;
1771 for (unsigned K = 0; K < M; ++K) {
1772 Value *R = B.extractVector(I: K, J, NumElts: BlockSize, Builder);
1773 Value *LH = Builder.CreateExtractElement(
1774 Vec: A.getVector(i: IsScalarMatrixTransposed ? K : I),
1775 Idx: IsScalarMatrixTransposed ? I : K);
1776 Value *Splat = Builder.CreateVectorSplat(NumElts: BlockSize, V: LH, Name: "splat");
1777 Sum =
1778 createMulAdd(Sum: isSumZero && K == 0 ? nullptr : Sum, A: Splat, B: R,
1779 UseFPOp: IsFP, Builder, AllowContraction: FMF.allowContract(), NumComputeOps);
1780 }
1781 Result.setVector(i: I,
1782 V: insertVector(Col: Result.getVector(i: I), I: J, Block: Sum, Builder));
1783 }
1784 }
1785 }
1786 Result.addNumComputeOps(N: NumComputeOps);
1787 }
1788
1789 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1790 /// copying it to a new location. Returns the new location if a copy was
1791 /// made, otherwise the original location.
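///
/// The emitted runtime check has roughly the following shape, using the block
/// names created below: the current block branches on load.begin < store.end
/// to "alias_cont", which branches on store.begin < load.end to "copy" (which
/// memcpys the loaded memory into a fresh alloca); all paths join in
/// "no_alias", where a PHI selects either the original pointer or the alloca.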
1792 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1793 CallInst *MatMul) {
1794 MemoryLocation StoreLoc = MemoryLocation::get(SI: Store);
1795 MemoryLocation LoadLoc = MemoryLocation::get(LI: Load);
1796
1797 // If we can statically determine noalias we're good.
1798 if (AA->isNoAlias(LocA: LoadLoc, LocB: StoreLoc))
1799 return Load->getPointerOperand();
1800
1801 // Create code to check if the memory locations of the Load and Store
1802 // overlap and if they do, copy Load's operand to a new buffer.
1803
1804 // First, create new blocks for the 2nd part of the check and the copy.
1805 BasicBlock *Check0 = MatMul->getParent();
1806 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1807 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1808 // as we adjust Check0 and Check1's branches.
1809 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1810 for (BasicBlock *Succ : successors(BB: Check0))
1811 DTUpdates.push_back(Elt: {DT->Delete, Check0, Succ});
1812
1813 BasicBlock *Check1 =
1814 SplitBlock(Old: MatMul->getParent(), SplitPt: MatMul, DTU: (DomTreeUpdater *)nullptr, LI,
1815 MSSAU: nullptr, BBName: "alias_cont");
1816 BasicBlock *Copy =
1817 SplitBlock(Old: MatMul->getParent(), SplitPt: MatMul, DTU: (DomTreeUpdater *)nullptr, LI,
1818 MSSAU: nullptr, BBName: "copy");
1819 BasicBlock *Fusion =
1820 SplitBlock(Old: MatMul->getParent(), SplitPt: MatMul, DTU: (DomTreeUpdater *)nullptr, LI,
1821 MSSAU: nullptr, BBName: "no_alias");
1822
1823 // Check if the loaded memory location begins before the end of the store
1824 // location. If the condition holds, they might overlap, otherwise they are
1825 // guaranteed to not overlap.
1826 IRBuilder<> Builder(MatMul);
1827 Check0->getTerminator()->eraseFromParent();
1828 Builder.SetInsertPoint(Check0);
1829 Type *IntPtrTy = Builder.getIntPtrTy(DL: Load->getDataLayout());
1830 Value *StoreBegin = Builder.CreatePtrToInt(
1831 V: const_cast<Value *>(StoreLoc.Ptr), DestTy: IntPtrTy, Name: "store.begin");
1832 Value *StoreEnd = Builder.CreateAdd(
1833 LHS: StoreBegin, RHS: ConstantInt::get(Ty: IntPtrTy, V: StoreLoc.Size.getValue()),
1834 Name: "store.end", HasNUW: true, HasNSW: true);
1835 Value *LoadBegin = Builder.CreatePtrToInt(V: const_cast<Value *>(LoadLoc.Ptr),
1836 DestTy: IntPtrTy, Name: "load.begin");
1837 Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: LoadBegin, RHS: StoreEnd), True: Check1,
1838 False: Fusion);
1839
1840 // Check if the store begins before the end of the load location. If the
1841 // condition holds, they alias, otherwise they are guaranteed to not
1842 // overlap.
1843 Check1->getTerminator()->eraseFromParent();
1844 Builder.SetInsertPoint(TheBB: Check1, IP: Check1->begin());
1845 Value *LoadEnd = Builder.CreateAdd(
1846 LHS: LoadBegin, RHS: ConstantInt::get(Ty: IntPtrTy, V: LoadLoc.Size.getValue()),
1847 Name: "load.end", HasNUW: true, HasNSW: true);
1848 Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: StoreBegin, RHS: LoadEnd), True: Copy,
1849 False: Fusion);
1850
1851 // Copy load operand to new alloca.
1852 Builder.SetInsertPoint(TheBB: Copy, IP: Copy->begin());
1853 auto *VT = cast<FixedVectorType>(Val: Load->getType());
1854 // Use an array type for the alloca, to avoid potentially huge alignment
1855 // requirements for large vector types.
1856 auto *ArrayTy = ArrayType::get(ElementType: VT->getElementType(), NumElements: VT->getNumElements());
1857 AllocaInst *Alloca =
1858 Builder.CreateAlloca(Ty: ArrayTy, AddrSpace: Load->getPointerAddressSpace());
1859
1860 Builder.CreateMemCpy(Dst: Alloca, DstAlign: Alloca->getAlign(), Src: Load->getPointerOperand(),
1861 SrcAlign: Load->getAlign(), Size: LoadLoc.Size.getValue());
1862 Builder.SetInsertPoint(TheBB: Fusion, IP: Fusion->begin());
1863 PHINode *PHI = Builder.CreatePHI(Ty: Load->getPointerOperandType(), NumReservedValues: 3);
1864 PHI->addIncoming(V: Load->getPointerOperand(), BB: Check0);
1865 PHI->addIncoming(V: Load->getPointerOperand(), BB: Check1);
1866 PHI->addIncoming(V: Alloca, BB: Copy);
1867
1868 // Adjust DT.
1869 DTUpdates.push_back(Elt: {DT->Insert, Check0, Check1});
1870 DTUpdates.push_back(Elt: {DT->Insert, Check0, Fusion});
1871 DTUpdates.push_back(Elt: {DT->Insert, Check1, Copy});
1872 DTUpdates.push_back(Elt: {DT->Insert, Check1, Fusion});
1873 DT->applyUpdates(Updates: DTUpdates);
1874 return PHI;
1875 }
1876
1877 bool isFusionProfitable(CallInst *MatMul) {
1878 if (ForceFusion)
1879 return true;
1880
1881 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1882 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1883
1884 const unsigned R = LShape.NumRows;
1885 const unsigned C = RShape.NumColumns;
1886 const unsigned M = LShape.NumColumns;
1887 auto *EltType = cast<FixedVectorType>(Val: MatMul->getType())->getElementType();
1888
1889 const unsigned VF = std::max<unsigned>(
1890 a: TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector)
1891 .getFixedValue() /
1892 EltType->getPrimitiveSizeInBits().getFixedValue(),
1893 b: 1U);
1894
1895 // Cost model for tiling
1896 //
1897 // For tiling to be beneficial, we need reuse either along the R or
1898 // the C axis. We vectorize along the R axis so that means at least
1899 // 3 elements.
1900 // TODO: Also consider cost of copying if operands alias.
1901 if (R <= VF && C == 1)
1902 return false;
1903 // Then we need enough elements to exceed the number of vector
1904 // registers we have. Note that this is an oversimplification since
1905 // fusing also takes some extra loads which may exceed the number of
1906 // reloads necessary.
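// For illustration (the numbers are an example, not target data): with
// VF = 4 and R = C = M = 8, Op0Regs = Op1Regs = 2 * 8 = 16, so fusion is
// only attempted if the target has fewer than 32 vector registers.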
1907 unsigned Op0Regs = (R + VF - 1) / VF * M;
1908 unsigned Op1Regs = (M + VF - 1) / VF * C;
1909 return Op0Regs + Op1Regs >
1910 TTI.getNumberOfRegisters(ClassID: TTI.getRegisterClassForType(Vector: true));
1911 }
1912
1913 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1914 MatrixTy Res;
1915 auto *ColumnType = FixedVectorType::get(ElementType: EltType, NumElts: R);
1916 for (unsigned I = 0; I < C; ++I)
1917 Res.addVector(V: ConstantAggregateZero::get(Ty: ColumnType));
1918 return Res;
1919 }
1920
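/// Emit a loop nest iterating over the rows, columns and the inner dimension,
/// multiplying TileSize x TileSize tiles of the operands and accumulating the
/// result tile in PHI nodes across the inner (K) loop.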
1921 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1922 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1923 auto *EltType = cast<FixedVectorType>(Val: MatMul->getType())->getElementType();
1924
1925 // Create the main tiling loop nest.
1926 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1927 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1928 Instruction *InsertI = cast<Instruction>(Val: MatMul);
1929 BasicBlock *Start = InsertI->getParent();
1930 BasicBlock *End =
1931 SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DT, LI, MSSAU: nullptr, BBName: "continue");
1932 IRBuilder<> Builder(MatMul);
1933 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, B&: Builder, DTU, LI&: *LI);
1934
1935 Type *TileVecTy =
1936 FixedVectorType::get(ElementType: MatMul->getType()->getScalarType(), NumElts: TileSize);
1937 MatrixTy TileResult;
1938 // Insert in the inner loop header.
1939 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1940 // Create PHI nodes for the result columns to accumulate across iterations.
1941 SmallVector<PHINode *, 4> ColumnPhis;
1942 for (unsigned I = 0; I < TileSize; I++) {
1943 auto *Phi = Builder.CreatePHI(Ty: TileVecTy, NumReservedValues: 2, Name: "result.vec." + Twine(I));
1944 Phi->addIncoming(V: ConstantAggregateZero::get(Ty: TileVecTy),
1945 BB: TI.RowLoop.Header->getSingleSuccessor());
1946 TileResult.addVector(V: Phi);
1947 ColumnPhis.push_back(Elt: Phi);
1948 }
1949
1950 // Insert in the inner loop body, which computes
1951 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1952 Builder.SetInsertPoint(InnerBody->getTerminator());
1953 // Load tiles of the operands.
1954 MatrixTy A =
1955 loadMatrix(MatrixPtr: LPtr, Align: {}, IsVolatile: false, MatrixShape: LShape, I: TI.RowLoop.Index, J: TI.KLoop.Index,
1956 ResultShape: {TileSize, TileSize}, EltTy: EltType, Builder);
1957 MatrixTy B =
1958 loadMatrix(MatrixPtr: RPtr, Align: {}, IsVolatile: false, MatrixShape: RShape, I: TI.KLoop.Index, J: TI.ColumnLoop.Index,
1959 ResultShape: {TileSize, TileSize}, EltTy: EltType, Builder);
1960 emitMatrixMultiply(Result&: TileResult, A, B, Builder, IsTiled: true, IsScalarMatrixTransposed: false,
1961 FMF: getFastMathFlags(Inst: MatMul));
1962 // Store result after the inner loop is done.
1963 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1964 storeMatrix(StoreVal: TileResult, MatrixPtr: Store->getPointerOperand(), MAlign: Store->getAlign(),
1965 IsVolatile: Store->isVolatile(), MatrixShape: {LShape.NumRows, RShape.NumColumns},
1966 I: TI.RowLoop.Index, J: TI.ColumnLoop.Index, EltTy: EltType, Builder);
1967
1968 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1969 ColumnPhis[I]->addIncoming(V: TileResult.getVector(i: I), BB: TI.KLoop.Latch);
1970
1971 // Force unrolling of a few iterations of the inner loop, to make sure there
1972 // is enough work per iteration.
1973 // FIXME: The unroller should make this decision directly instead, but
1974 // currently the cost-model is not up to the task.
1975 unsigned InnerLoopUnrollCount = std::min(a: 10u, b: LShape.NumColumns / TileSize);
1976 addStringMetadataToLoop(TheLoop: LI->getLoopFor(BB: TI.KLoop.Header),
1977 MDString: "llvm.loop.unroll.count", V: InnerLoopUnrollCount);
1978 }
1979
1980 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1981 StoreInst *Store,
1982 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1983 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1984 "Tiling only supported for column-major matrixes at the moment!");
1985 if (!isFusionProfitable(MatMul))
1986 return;
1987
1988 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1989 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1990
1991 const unsigned R = LShape.NumRows;
1992 const unsigned C = RShape.NumColumns;
1993 const unsigned M = LShape.NumColumns;
1994 auto *EltType = cast<FixedVectorType>(Val: MatMul->getType())->getElementType();
1995
1996 Value *APtr = getNonAliasingPointer(Load: LoadOp0, Store, MatMul);
1997 Value *BPtr = getNonAliasingPointer(Load: LoadOp1, Store, MatMul);
1998 Value *CPtr = Store->getPointerOperand();
1999
2000 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
2001 createTiledLoops(MatMul, LPtr: APtr, LShape, RPtr: BPtr, RShape, Store);
2002 else {
2003 IRBuilder<> Builder(Store);
2004 for (unsigned J = 0; J < C; J += TileSize)
2005 for (unsigned I = 0; I < R; I += TileSize) {
2006 const unsigned TileR = std::min(a: R - I, b: unsigned(TileSize));
2007 const unsigned TileC = std::min(a: C - J, b: unsigned(TileSize));
2008 MatrixTy Res = getZeroMatrix(EltType, R: TileR, C: TileC);
2009
2010 for (unsigned K = 0; K < M; K += TileSize) {
2011 const unsigned TileM = std::min(a: M - K, b: unsigned(TileSize));
2012 MatrixTy A =
2013 loadMatrix(MatrixPtr: APtr, Align: LoadOp0->getAlign(), IsVolatile: LoadOp0->isVolatile(),
2014 MatrixShape: LShape, I: Builder.getInt64(C: I), J: Builder.getInt64(C: K),
2015 ResultShape: {TileR, TileM}, EltTy: EltType, Builder);
2016 MatrixTy B =
2017 loadMatrix(MatrixPtr: BPtr, Align: LoadOp1->getAlign(), IsVolatile: LoadOp1->isVolatile(),
2018 MatrixShape: RShape, I: Builder.getInt64(C: K), J: Builder.getInt64(C: J),
2019 ResultShape: {TileM, TileC}, EltTy: EltType, Builder);
2020 emitMatrixMultiply(Result&: Res, A, B, Builder, IsTiled: true, IsScalarMatrixTransposed: false,
2021 FMF: getFastMathFlags(Inst: MatMul));
2022 }
2023 storeMatrix(StoreVal: Res, MatrixPtr: CPtr, MAlign: Store->getAlign(), IsVolatile: Store->isVolatile(), MatrixShape: {R, M},
2024 I: Builder.getInt64(C: I), J: Builder.getInt64(C: J), EltTy: EltType,
2025 Builder);
2026 }
2027 }
2028
2029 // Mark eliminated instructions as fused and remove them.
2030 FusedInsts.insert(Ptr: Store);
2031 FusedInsts.insert(Ptr: MatMul);
2032 eraseFromParentAndRemoveFromShapeMap(Inst: Store);
2033 eraseFromParentAndRemoveFromShapeMap(Inst: MatMul);
2034 if (LoadOp0->use_empty()) {
2035 FusedInsts.insert(Ptr: LoadOp0);
2036 eraseFromParentAndRemoveFromShapeMap(Inst: LoadOp0);
2037 }
2038 if (LoadOp1 != LoadOp0 && LoadOp1->use_empty()) {
2039 FusedInsts.insert(Ptr: LoadOp1);
2040 eraseFromParentAndRemoveFromShapeMap(Inst: LoadOp1);
2041 }
2042 }
2043
2044 /// Try to lower matrix multiply chains by fusing operations.
2045 ///
2046 /// Call finalizeLowering on lowered instructions. Instructions that are
2047 /// completely eliminated by fusion are added to \p FusedInsts.
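///
/// Two fusion opportunities are handled: folding a matrix_transpose operand
/// directly into emitMatrixMultiply, and tiling {load, load} -> multiply ->
/// store chains (see emitSIMDTiling).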
2048 void
2049 LowerMatrixMultiplyFused(CallInst *MatMul,
2050 SmallPtrSetImpl<Instruction *> &FusedInsts,
2051 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
2052 if (!FuseMatrix || !DT)
2053 return;
2054
2055 assert(AA && LI && "Analyses should be available");
2056
2057 Value *A = MatMul->getArgOperand(i: 0);
2058 Value *B = MatMul->getArgOperand(i: 1);
2059
2060 // We can fold the transpose into the operand that is used to fetch scalars.
2061 Value *T;
2062 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
2063 ? match(V: B, P: m_Intrinsic<Intrinsic::matrix_transpose>(Op0: m_Value(V&: T)))
2064 : match(V: A, P: m_Intrinsic<Intrinsic::matrix_transpose>(Op0: m_Value(V&: T)))) {
2065 IRBuilder<> Builder(MatMul);
2066 auto *EltType =
2067 cast<FixedVectorType>(Val: MatMul->getType())->getElementType();
2068 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
2069 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
2070 const unsigned R = LShape.NumRows;
2071 const unsigned M = LShape.NumColumns;
2072 const unsigned C = RShape.NumColumns;
2073
2074 MatrixTy MA;
2075 MatrixTy MB;
2076
2077 Value *Transpose;
2078 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
2079 MA = getMatrix(MatrixVal: A, SI: ShapeInfo(R, M), Builder);
2080 MB = getMatrix(MatrixVal: T, SI: ShapeInfo(C, M), Builder);
2081 Transpose = B;
2082 } else {
2083 MA = getMatrix(MatrixVal: T, SI: ShapeInfo(R, M), Builder);
2084 MB = getMatrix(MatrixVal: B, SI: ShapeInfo(C, M), Builder);
2085 Transpose = A;
2086 }
2087
2088 // Initialize the output
2089 MatrixTy Result(R, C, EltType);
2090
2091 emitMatrixMultiply(Result, A: MA, B: MB, Builder, IsTiled: false, IsScalarMatrixTransposed: true,
2092 FMF: getFastMathFlags(Inst: MatMul));
2093
2094 FusedInsts.insert(Ptr: MatMul);
2095 if (Transpose->hasOneUse()) {
2096 FusedInsts.insert(Ptr: cast<Instruction>(Val: Transpose));
2097 ToRemove.push_back(Elt: cast<Instruction>(Val: Transpose));
2098 // TODO: add a fake entry for the folded instruction so that this is
2099 // included in the expression in the remark.
2100 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
2101 }
2102 finalizeLowering(Inst: MatMul, Matrix: Result, Builder);
2103 return;
2104 }
2105
2106 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
2107 return;
2108
2109 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
2110 // since the single store user will be lowered as part of this.
2111 auto *LoadOp0 = dyn_cast<LoadInst>(Val: A);
2112 auto *LoadOp1 = dyn_cast<LoadInst>(Val: B);
2113 auto *Store = dyn_cast<StoreInst>(Val: *MatMul->user_begin());
2114 if (LoadOp0 && LoadOp1 && Store) {
2115 // The store address must dominate the MatMul instruction, otherwise
2116 // we create invalid IR.
2117 SetVector<Value *> WorkList;
2118 WorkList.insert(X: Store->getOperand(i_nocapture: 1));
2119 SmallVector<Instruction *> ToHoist;
2120 for (unsigned I = 0; I != WorkList.size(); ++I) {
2121 Value *Current = WorkList[I];
2122 auto *CurrI = dyn_cast<Instruction>(Val: Current);
2123 if (!CurrI)
2124 continue;
2125 if (isa<PHINode>(Val: CurrI))
2126 return;
2127 if (DT->dominates(Def: CurrI, User: MatMul))
2128 continue;
2129 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
2130 return;
2131 ToHoist.push_back(Elt: CurrI);
2132 WorkList.insert_range(R: CurrI->operands());
2133 }
2134
2135 sort(C&: ToHoist, Comp: [this](Instruction *A, Instruction *B) {
2136 return DT->dominates(Def: A, User: B);
2137 });
2138 for (Instruction *I : ToHoist)
2139 I->moveBefore(InsertPos: MatMul->getIterator());
2140
2141 // Deal with lifetime.end calls that might be between Load0/Load1 and the
2142 // store. To avoid introducing loads to dead objects (i.e. after the
2143 // lifetime has been terminated by @llvm.lifetime.end), either sink them
2144 // after the store if in the same block, or remove the lifetime.end marker
2145 // otherwise. This might pessimize further optimizations, by extending the
2146 // lifetime of the object until the function returns, but should be
2147 // conservatively correct.
2148 MemoryLocation Load0Loc = MemoryLocation::get(LI: LoadOp0);
2149 MemoryLocation Load1Loc = MemoryLocation::get(LI: LoadOp1);
2150 BasicBlock *StoreParent = Store->getParent();
2151 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
2152 LoadOp1->getParent() == StoreParent;
2153 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
2154 IntrinsicInst *End = LifetimeEnds[Idx];
2155 auto Inc = make_scope_exit(F: [&Idx]() { Idx++; });
2156 // If the lifetime.end is guaranteed to be before the loads or after the
2157 // store, it won't interfere with fusion.
2158 if (DT->dominates(Def: End, User: LoadOp0) && DT->dominates(Def: End, User: LoadOp1))
2159 continue;
2160 if (DT->dominates(Def: Store, User: End))
2161 continue;
2162 // If all fusable ops are in the same block and the lifetime.end is in a
2163 // different block, it won't interfere with fusion.
2164 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2165 continue;
2166
2167 // If the loads don't alias the lifetime.end, it won't interfere with
2168 // fusion.
2169 MemoryLocation EndLoc = MemoryLocation::getForArgument(Call: End, ArgIdx: 1, TLI: nullptr);
2170 if (!EndLoc.Ptr)
2171 continue;
2172 if (AA->isNoAlias(LocA: Load0Loc, LocB: EndLoc) && AA->isNoAlias(LocA: Load1Loc, LocB: EndLoc))
2173 continue;
2174
2175 // If both lifetime.end and the store are in the same block, extend the
2176 // lifetime until after the store, so the new lifetime covers the loads
2177 // we introduce later.
2178 if (End->getParent() == StoreParent) {
2179 End->moveAfter(MovePos: Store);
2180 continue;
2181 }
2182
2183 // Otherwise remove the conflicting lifetime.end marker.
2184 ToRemove.push_back(Elt: End);
2185 std::swap(a&: LifetimeEnds[Idx], b&: LifetimeEnds.back());
2186 LifetimeEnds.pop_back();
2187 Inc.release();
2188 }
2189
2190 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2191 return;
2192 }
2193 }
2194
2195 /// Lowers llvm.matrix.multiply.
2196 MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) {
2197 auto *EltType = cast<FixedVectorType>(Val: MatMul->getType())->getElementType();
2198 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
2199 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
2200
2201 const MatrixTy &Lhs = getMatrix(MatrixVal: MatMul->getArgOperand(i: 0), SI: LShape, Builder);
2202 const MatrixTy &Rhs = getMatrix(MatrixVal: MatMul->getArgOperand(i: 1), SI: RShape, Builder);
2203 assert(Lhs.getElementType() == Rhs.getElementType() &&
2204 "Matrix multiply argument element types do not match.");
2205
2206 const unsigned R = LShape.NumRows;
2207 const unsigned C = RShape.NumColumns;
2208 assert(LShape.NumColumns == RShape.NumRows);
2209
2210 // Initialize the output
2211 MatrixTy Result(R, C, EltType);
2212 assert(Lhs.getElementType() == Result.getElementType() &&
2213 "Matrix multiply result element type does not match arguments.");
2214
2215 emitMatrixMultiply(Result, A: Lhs, B: Rhs, Builder, IsTiled: false, IsScalarMatrixTransposed: false,
2216 FMF: getFastMathFlags(Inst: MatMul));
2217 return Result;
2218 }
2219
2220 /// Lowers llvm.matrix.transpose.
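/// Lowered by building each result vector I from element I of every input
/// vector, i.e. element I of input vector J becomes element J of result
/// vector I.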
2221 MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) {
2222 MatrixTy Result;
2223 Value *InputVal = Inst->getArgOperand(i: 0);
2224 FixedVectorType *VectorTy = cast<FixedVectorType>(Val: InputVal->getType());
2225 ShapeInfo ArgShape(Inst->getArgOperand(i: 1), Inst->getArgOperand(i: 2));
2226 MatrixTy InputMatrix = getMatrix(MatrixVal: InputVal, SI: ArgShape, Builder);
2227
2228 const unsigned NewNumVecs =
2229 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2230 const unsigned NewNumElts =
2231 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2232
2233 for (unsigned I = 0; I < NewNumVecs; ++I) {
2234 // Build a single result vector. First initialize it.
2235 Value *ResultVector = PoisonValue::get(
2236 T: FixedVectorType::get(ElementType: VectorTy->getElementType(), NumElts: NewNumElts));
2237 // Go through the old elements and insert them into the resulting vector.
2238 for (auto J : enumerate(First: InputMatrix.vectors())) {
2239 Value *Elt = Builder.CreateExtractElement(Vec: J.value(), Idx: I);
2240 // Row and column indices are transposed.
2241 ResultVector =
2242 Builder.CreateInsertElement(Vec: ResultVector, NewElt: Elt, Idx: J.index());
2243 }
2244 Result.addVector(V: ResultVector);
2245 }
2246
2247 // TODO: Improve estimate of operations needed for transposes. Currently we
2248 // just count the insertelement/extractelement instructions, but do not
2249 // account for later simplifications/combines.
2250 return Result.addNumComputeOps(N: 2 * ArgShape.NumRows * ArgShape.NumColumns)
2251 .addNumExposedTransposes(N: 1);
2252 }
2253
2254 /// Lower load instructions.
2255 MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2256 IRBuilder<> &Builder) {
2257 return LowerLoad(Inst, Ptr, Align: Inst->getAlign(),
2258 Stride: Builder.getInt64(C: SI.getStride()), IsVolatile: Inst->isVolatile(), Shape: SI,
2259 Builder);
2260 }
2261
2262 MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2263 Value *Ptr, IRBuilder<> &Builder) {
2264 return LowerStore(Inst, Matrix: StoredVal, Ptr, A: Inst->getAlign(),
2265 Stride: Builder.getInt64(C: SI.getStride()), IsVolatile: Inst->isVolatile(), Shape: SI,
2266 Builder);
2267 }
2268
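/// Lower PHIs with shape information to one PHI per vector of the matrix;
/// each incoming value is split into matching vectors, which are added as
/// incoming values to the new PHIs.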
2269 MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2270 auto BlockIP = Inst->getParent()->getFirstInsertionPt();
2271 Builder.SetInsertPoint(BlockIP);
2272 MatrixTy PhiM = getMatrix(MatrixVal: Inst, SI, Builder);
2273
2274 for (auto [IncomingV, IncomingB] :
2275 llvm::zip_equal(t: Inst->incoming_values(), u: Inst->blocks())) {
2276 // getMatrix() may insert some instructions to help with reshaping. The
2277 // safest place for those is at the top of the block, after the rest of the
2278 // PHIs. Even better if we can put them in the incoming block.
2279 Builder.SetInsertPoint(BlockIP);
2280 if (auto *IncomingInst = dyn_cast<Instruction>(Val&: IncomingV))
2281 if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef())
2282 Builder.SetInsertPoint(*MaybeIP);
2283
2284 MatrixTy OpM = getMatrix(MatrixVal: IncomingV, SI, Builder);
2285
2286 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
2287 PHINode *NewPHI = cast<PHINode>(Val: PhiM.getVector(i: VI));
2288 NewPHI->addIncoming(V: OpM.getVector(i: VI), BB: IncomingB);
2289 }
2290 }
2291
2292 // finalizeLowering() may also insert instructions in some cases. The safe
2293 // place for those is at the end of the initial block of PHIs.
2294 Builder.SetInsertPoint(BlockIP);
2295 return PhiM;
2296 }
2297
2298 /// Lower binary operators.
2299 MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
2300 IRBuilder<> &Builder) {
2301 Value *Lhs = Inst->getOperand(i_nocapture: 0);
2302 Value *Rhs = Inst->getOperand(i_nocapture: 1);
2303
2304 MatrixTy Result;
2305 MatrixTy A = getMatrix(MatrixVal: Lhs, SI, Builder);
2306 MatrixTy B = getMatrix(MatrixVal: Rhs, SI, Builder);
2307 assert(A.isColumnMajor() == B.isColumnMajor() &&
2308 Result.isColumnMajor() == A.isColumnMajor() &&
2309 "operands must agree on matrix layout");
2310
2311 Builder.setFastMathFlags(getFastMathFlags(Inst));
2312
2313 for (auto [AV, BV] : llvm::zip_equal(t: A.vectors(), u: B.vectors()))
2314 Result.addVector(V: Builder.CreateBinOp(Opc: Inst->getOpcode(), LHS: AV, RHS: BV));
2315
2316 return Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
2317 Result.getNumVectors());
2318 }
2319
2320 /// Lower unary operators.
2321 MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
2322 IRBuilder<> &Builder) {
2323 Value *Op = Inst->getOperand(i_nocapture: 0);
2324
2325 MatrixTy Result;
2326 MatrixTy M = getMatrix(MatrixVal: Op, SI, Builder);
2327
2328 Builder.setFastMathFlags(getFastMathFlags(Inst));
2329
2330 // Helper to perform unary op on vectors.
2331 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2332 switch (Inst->getOpcode()) {
2333 case Instruction::FNeg:
2334 return Builder.CreateFNeg(V: Op);
2335 default:
2336 llvm_unreachable("Unsupported unary operator for matrix");
2337 }
2338 };
2339
2340 for (auto *Vector : M.vectors())
2341 Result.addVector(V: BuildVectorOp(Vector));
2342
2343 return Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
2344 Result.getNumVectors());
2345 }
2346
2347 /// Lower cast instructions.
2348 MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
2349 IRBuilder<> &Builder) {
2350 Value *Op = Inst->getOperand(i_nocapture: 0);
2351
2352 MatrixTy Result;
2353 MatrixTy M = getMatrix(MatrixVal: Op, SI: Shape, Builder);
2354
2355 Builder.setFastMathFlags(getFastMathFlags(Inst));
2356
2357 auto *OrigVTy = cast<VectorType>(Val: Inst->getType());
2358 auto *NewVTy = VectorType::get(ElementType: OrigVTy->getElementType(),
2359 EC: ElementCount::getFixed(MinVal: M.getStride()));
2360
2361 for (auto *Vector : M.vectors())
2362 Result.addVector(V: Builder.CreateCast(Op: Inst->getOpcode(), V: Vector, DestTy: NewVTy));
2363
2364 return Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
2365 Result.getNumVectors());
2366 }
2367
2368 /// Lower selects.
2369 MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,
2370 IRBuilder<> &Builder) {
2371 Value *Cond = Inst->getOperand(i_nocapture: 0);
2372 Value *OpA = Inst->getOperand(i_nocapture: 1);
2373 Value *OpB = Inst->getOperand(i_nocapture: 2);
2374
2375 MatrixTy Result;
2376 MatrixTy A = getMatrix(MatrixVal: OpA, SI: Shape, Builder);
2377 MatrixTy B = getMatrix(MatrixVal: OpB, SI: Shape, Builder);
2378
2379 SmallVector<Value*> CondV;
2380 if (isa<FixedVectorType>(Val: Cond->getType())) {
2381 MatrixTy C = getMatrix(MatrixVal: Cond, SI: Shape, Builder);
2382 llvm::copy(Range: C.vectors(), Out: std::back_inserter(x&: CondV));
2383 } else {
2384 CondV.resize(N: A.getNumVectors());
2385 llvm::fill(Range&: CondV, Value&: Cond);
2386 }
2387
2388 for (auto [CV, AV, BV] : llvm::zip_equal(t&: CondV, u: A.vectors(), args: B.vectors()))
2389 Result.addVector(V: Builder.CreateSelect(C: CV, True: AV, False: BV));
2390
2391 return Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
2392 Result.getNumVectors());
2393 }
2394
2395 /// Helper to linearize a matrix expression tree into a string. Currently
2396 /// matrix expressions are linearized by starting at an expression leaf and
2397 /// linearizing bottom up.
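///
/// For illustration only (the exact spelling is produced by writeFnName and
/// write below), a linearized transpose-and-store expression looks roughly
/// like:
///   store(
///    transpose.2x6.double(load(addr %A)),
///    addr %B)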
2398 struct ExprLinearizer {
2399 unsigned LengthToBreak = 100;
2400 std::string Str;
2401 raw_string_ostream Stream;
2402 unsigned LineLength = 0;
2403 const DataLayout &DL;
2404
2405 /// Mapping from instructions to matrixes. It is used to identify
2406 /// matrix instructions.
2407 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2408
2409 /// Mapping from values to the leaves of all expressions that the value is
2410 /// part of.
2411 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2412
2413 /// Set of matrix expressions in the scope of a given DISubprogram.
2414 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2415
2416 /// Leaf node of the expression to linearize.
2417 Value *Leaf;
2418
2419 /// Used to keep track of sub-expressions that get reused while linearizing
2420 /// the expression. Re-used sub-expressions are marked as (reused).
2421 SmallPtrSet<Value *, 8> ReusedExprs;
2422
2423 ExprLinearizer(const DataLayout &DL,
2424 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2425 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2426 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2427 Value *Leaf)
2428 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2429 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2430
2431 void indent(unsigned N) {
2432 LineLength += N;
2433 for (unsigned i = 0; i < N; i++)
2434 Stream << " ";
2435 }
2436
2437 void lineBreak() {
2438 Stream << "\n";
2439 LineLength = 0;
2440 }
2441
2442 void maybeIndent(unsigned Indent) {
2443 if (LineLength >= LengthToBreak)
2444 lineBreak();
2445
2446 if (LineLength == 0)
2447 indent(N: Indent);
2448 }
2449
2450 void write(StringRef S) {
2451 LineLength += S.size();
2452 Stream << S;
2453 }
2454
2455 Value *getUnderlyingObjectThroughLoads(Value *V) {
2456 if (Value *Ptr = getPointerOperand(V))
2457 return getUnderlyingObjectThroughLoads(V: Ptr);
2458 else if (V->getType()->isPointerTy())
2459 return getUnderlyingObject(V);
2460 return V;
2461 }
2462
2463 /// Returns true if \p V is a matrix value in the given subprogram.
2464 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(key: V); }
2465
2466 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2467 /// \p SS.
2468 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2469 auto M = Inst2Matrix.find(Key: V);
2470 if (M == Inst2Matrix.end())
2471 SS << "unknown";
2472 else {
2473 SS << M->second.getNumRows();
2474 SS << "x";
2475 SS << M->second.getNumColumns();
2476 }
2477 }
2478
2479 /// Write the called function name. Handles calls to llvm.matrix.*
2480 /// specially: we write the name, followed by the dimensions of the input
2481 /// matrixes, followed by the scalar type name.
2482 void writeFnName(CallInst *CI) {
2483 if (!CI->getCalledFunction())
2484 write(S: "<no called fn>");
2485 else {
2486 StringRef Name = CI->getCalledFunction()->getName();
2487 if (!Name.starts_with(Prefix: "llvm.matrix")) {
2488 write(S: Name);
2489 return;
2490 }
2491 auto *II = cast<IntrinsicInst>(Val: CI);
2492 write(S: Intrinsic::getBaseName(id: II->getIntrinsicID())
2493 .drop_front(N: StringRef("llvm.matrix.").size()));
2494 write(S: ".");
2495 std::string Tmp;
2496 raw_string_ostream SS(Tmp);
2497
2498 switch (II->getIntrinsicID()) {
2499 case Intrinsic::matrix_multiply:
2500 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 0), SS);
2501 SS << ".";
2502 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 1), SS);
2503 SS << "." << *II->getType()->getScalarType();
2504 break;
2505 case Intrinsic::matrix_transpose:
2506 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 0), SS);
2507 SS << "." << *II->getType()->getScalarType();
2508 break;
2509 case Intrinsic::matrix_column_major_load:
2510 prettyPrintMatrixType(V: II, SS);
2511 SS << "." << *II->getType()->getScalarType();
2512 break;
2513 case Intrinsic::matrix_column_major_store:
2514 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 0), SS);
2515 SS << "." << *II->getOperand(i_nocapture: 0)->getType()->getScalarType();
2516 break;
2517 default:
2518 llvm_unreachable("Unhandled case");
2519 }
2520 write(S: Tmp);
2521 }
2522 }
2523
2524 unsigned getNumShapeArgs(CallInst *CI) const {
2525 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: CI)) {
2526 switch (II->getIntrinsicID()) {
2527 case Intrinsic::matrix_multiply:
2528 return 3;
2529 case Intrinsic::matrix_transpose:
2530 return 2;
2531 case Intrinsic::matrix_column_major_load:
2532 case Intrinsic::matrix_column_major_store:
2533 return 3;
2534 default:
2535 return 0;
2536 }
2537 }
2538 return 0;
2539 }
2540
2541 /// Special printing for values: for pointers, we print whether they refer to
2542 /// a (function-)external address or a stack address; for other values we
2543 /// either print the constant or "scalar"/"matrix".
2544 void write(Value *V) {
2545 V = getUnderlyingObjectThroughLoads(V);
2546 if (V->getType()->isPointerTy()) {
2547 if (isa<AllocaInst>(Val: V)) {
2548 Stream << "stack addr";
2549 LineLength += StringRef("stack addr").size();
2550 } else {
2551 Stream << "addr";
2552 LineLength += StringRef("addr").size();
2553 }
2554 if (!V->getName().empty()) {
2555 Stream << " %" << V->getName() << "";
2556 LineLength += V->getName().size() + 2;
2557 }
2558 return;
2559 }
2560
2561 std::string Tmp;
2562 raw_string_ostream TmpStream(Tmp);
2563
2564 if (auto *CI = dyn_cast<ConstantInt>(Val: V))
2565 TmpStream << CI->getValue();
2566 else if (isa<Constant>(Val: V))
2567 TmpStream << "constant";
2568 else {
2569 if (isMatrix(V))
2570 TmpStream << "matrix";
2571 else
2572 TmpStream << "scalar";
2573 }
2574 Tmp = std::string(StringRef(Tmp).trim());
2575 LineLength += Tmp.size();
2576 Stream << Tmp;
2577 }
2578
2579 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2580 /// Expressions that are re-used multiple times are prefixed with (reused)
2581 /// at the re-used root instruction.
2582 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2583 bool ParentShared) {
2584 auto *I = cast<Instruction>(Val: Expr);
2585 maybeIndent(Indent);
2586 SmallVector<Value *, 8> Ops;
2587
2588 // Is Expr shared with other expression leaves?
2589 bool ExprShared = false;
2590
2591 // Deal with shared subtrees. Mark them as shared, if required.
2592 if (!ParentShared) {
2593 auto SI = Shared.find(Val: Expr);
2594 assert(SI != Shared.end() && SI->second.count(Leaf));
2595
2596 for (Value *S : SI->second) {
2597 if (S == Leaf)
2598 continue;
2599 DebugLoc DL = cast<Instruction>(Val: S)->getDebugLoc();
2600 write(S: "shared with remark at line " + std::to_string(val: DL.getLine()) +
2601 " column " + std::to_string(val: DL.getCol()) + " (");
2602 }
2603 ExprShared = SI->second.size() > 1;
2604 }
2605
2606 bool Reused = !ReusedExprs.insert(Ptr: Expr).second;
2607 if (Reused && !ParentReused)
2608 write(S: "(reused) ");
2609
2610 if (auto *CI = dyn_cast<CallInst>(Val: I)) {
2611 writeFnName(CI);
2612
2613 Ops.append(in_start: CI->arg_begin(), in_end: CI->arg_end() - getNumShapeArgs(CI));
2614 } else if (isa<BitCastInst>(Val: Expr)) {
2615 // Special case bitcasts, which are used to materialize matrixes from
2616 // non-matrix ops.
2617 write(S: "matrix");
2618 return;
2619 } else {
2620 Ops.append(in_start: I->value_op_begin(), in_end: I->value_op_end());
2621 write(S: I->getOpcodeName());
2622 }
2623
2624 write(S: "(");
2625
2626 unsigned NumOpsToBreak = 1;
2627 if (match(V: Expr, P: m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2628 NumOpsToBreak = 2;
2629
2630 for (Value *Op : Ops) {
2631 if (Ops.size() > NumOpsToBreak)
2632 lineBreak();
2633
2634 maybeIndent(Indent: Indent + 1);
2635 if (isMatrix(V: Op))
2636 linearizeExpr(Expr: Op, Indent: Indent + 1, ParentReused: Reused, ParentShared: ExprShared);
2637 else
2638 write(V: Op);
2639 if (Op != Ops.back())
2640 write(S: ", ");
2641 }
2642
2643 write(S: ")");
2644 }
2645
2646 const std::string &getResult() {
2647 return Str;
2648 }
2649 };
2650
2651 /// Generate remarks for matrix operations in a function. To generate remarks
2652 /// for matrix expressions, the following approach is used:
2653 /// 1. Use the inlined-at debug information to group matrix operations to the
2654 /// DISubprograms they are contained in.
2655 /// 2. Collect leaves of matrix expressions (done in
2656 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2657 /// mapping. Leaves are lowered matrix instructions without other matrix
2658 /// users (like stores) in the current subprogram.
2659 /// 3. For each leaf, create a remark containing a linearized version of the
2660 /// matrix expression. The expression is linearized by a recursive
2661 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2662 /// that multiple leaves can share sub-expressions. Shared subexpressions
2663 /// are explicitly marked as shared().
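///
/// Each remark starts with a summary of the form "Lowered with <N> stores,
/// <M> loads, <K> compute ops, <T> exposed transposes", optionally followed
/// by counts shared with other expressions, and then the linearized
/// expression itself.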
2664 struct RemarkGenerator {
2665 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2666 OptimizationRemarkEmitter &ORE;
2667 Function &Func;
2668 const DataLayout &DL;
2669
2670 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2671 OptimizationRemarkEmitter &ORE, Function &Func)
2672 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2673 DL(Func.getDataLayout()) {}
2674
2675 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2676 /// instructions in Inst2Matrix returning void or without any users in
2677 /// \p ExprsInSubprogram. Currently that should only include stores.
2678 SmallVector<Value *, 4>
2679 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2680 SmallVector<Value *, 4> Leaves;
2681 for (auto *Expr : ExprsInSubprogram)
2682 if (Expr->getType()->isVoidTy() ||
2683 !any_of(Range: Expr->users(), P: [&ExprsInSubprogram](User *U) {
2684 return ExprsInSubprogram.count(key: U);
2685 }))
2686 Leaves.push_back(Elt: Expr);
2687 return Leaves;
2688 }
2689
2690 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2691 /// to all visited expressions in \p Shared. Limit the matrix operations to
2692 /// the ones in \p ExprsInSubprogram.
2693 void collectSharedInfo(Value *Leaf, Value *V,
2694 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2695 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2696
2697 if (!ExprsInSubprogram.count(key: V))
2698 return;
2699
2700 Shared[V].insert(Ptr: Leaf);
2701
2702 for (Value *Op : cast<Instruction>(Val: V)->operand_values())
2703 collectSharedInfo(Leaf, V: Op, ExprsInSubprogram, Shared);
2704 }
2705
2706 /// Calculate the number of exclusive and shared op counts for expression
2707 /// starting at \p V. Expressions used multiple times are counted once.
2708 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2709 std::pair<OpInfoTy, OpInfoTy>
2710 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2711 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2712 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2713 if (!ExprsInSubprogram.count(key: Root))
2714 return {};
2715
2716 // Already counted this expression. Stop.
2717 if (!ReusedExprs.insert(Ptr: Root).second)
2718 return {};
2719
2720 OpInfoTy SharedCount;
2721 OpInfoTy Count;
2722
2723 auto I = Shared.find(Val: Root);
2724 auto CM = Inst2Matrix.find(Key: Root);
2725 if (I->second.size() == 1)
2726 Count = CM->second.getOpInfo();
2727 else
2728 SharedCount = CM->second.getOpInfo();
2729
2730 for (Value *Op : cast<Instruction>(Val: Root)->operand_values()) {
2731 auto C = sumOpInfos(Root: Op, ReusedExprs, ExprsInSubprogram, Shared);
2732 Count += C.first;
2733 SharedCount += C.second;
2734 }
2735 return {Count, SharedCount};
2736 }
2737
2738 void emitRemarks() {
2739 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2740 return;
2741
2742 // Map matrix operations to their containing subprograms, by traversing
2743 // the inlinedAt chain. If the function does not have a DISubprogram, we
2744 // only map them to the containing function.
2745 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2746 for (const auto &KV : Inst2Matrix) {
2747 if (Func.getSubprogram()) {
2748 auto *I = cast<Instruction>(Val: KV.first);
2749 DILocation *Context = I->getDebugLoc();
2750 while (Context) {
2751 Subprog2Exprs[getSubprogram(Scope: Context->getScope())].push_back(
2752 Elt: KV.first);
2753 Context = DebugLoc(Context).getInlinedAt();
2754 }
2755 } else {
2756 Subprog2Exprs[nullptr].push_back(Elt: KV.first);
2757 }
2758 }
2759 for (auto &KV : Subprog2Exprs) {
2760 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2761 KV.second.end());
2762 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2763
2764 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2765 for (Value *Leaf : Leaves)
2766 collectSharedInfo(Leaf, V: Leaf, ExprsInSubprogram, Shared);
2767
2768 // Generate remarks for each leaf.
2769 for (auto *L : Leaves) {
2770
2771 DebugLoc Loc = cast<Instruction>(Val: L)->getDebugLoc();
2772 DILocation *Context = cast<Instruction>(Val: L)->getDebugLoc();
2773 while (Context) {
2774 if (getSubprogram(Scope: Context->getScope()) == KV.first) {
2775 Loc = Context;
2776 break;
2777 }
2778 Context = DebugLoc(Context).getInlinedAt();
2779 }
2780
2781 SmallPtrSet<Value *, 8> ReusedExprs;
2782 OpInfoTy Counts, SharedCounts;
2783 std::tie(args&: Counts, args&: SharedCounts) =
2784 sumOpInfos(Root: L, ReusedExprs, ExprsInSubprogram, Shared);
2785
2786 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2787 cast<Instruction>(Val: L)->getParent());
2788
2789 Rem << "Lowered with ";
2790 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2791 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2792 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2793 << " compute ops, "
2794 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2795 << " exposed transposes";
2796
2797 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2798 SharedCounts.NumComputeOps > 0) {
2799 Rem << ",\nadditionally "
2800 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2801 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2802 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2803 << " compute ops"
2804 << " are shared with other expressions";
2805 }
2806
2807 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2808 ORE.emit(OptDiag&: Rem);
2809 }
2810 }
2811 }
2812
2813 std::string
2814 linearize(Value *L,
2815 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2816 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2817 const DataLayout &DL) {
2818 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2819 Lin.linearizeExpr(Expr: L, Indent: 0, ParentReused: false, ParentShared: false);
2820 return Lin.getResult();
2821 }
2822 };
2823};
2824} // namespace
2825
2826PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2827 FunctionAnalysisManager &AM) {
2828 auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
2829
2830 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
2831 if (LMT.Visit()) {
2832 PreservedAnalyses PA;
2833 if (!Minimal) {
2834 PA.preserve<LoopAnalysis>();
2835 PA.preserve<DominatorTreeAnalysis>();
2836 }
2837 return PA;
2838 }
2839 return PreservedAnalyses::all();
2840}
2841
2842void LowerMatrixIntrinsicsPass::printPipeline(
2843 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2844 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2845 OS, MapClassName2PassName);
2846 OS << '<';
2847 if (Minimal)
2848 OS << "minimal";
2849 OS << '>';
2850}
2851