1//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utilities for generating tiled loops for matrix operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Utils/MatrixUtils.h"
14#include "llvm/Analysis/DomTreeUpdater.h"
15#include "llvm/Analysis/LoopInfo.h"
16#include "llvm/IR/BasicBlock.h"
17#include "llvm/IR/Dominators.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/IR/ProfDataUtils.h"
21#include "llvm/IR/Type.h"
22#include "llvm/Support/CommandLine.h"
23
24using namespace llvm;
25
26namespace llvm {
27extern cl::opt<bool> ProfcheckDisableMetadataFixes;
28} // end namespace llvm
29
30BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
31 ConstantInt *Bound, ConstantInt *Step,
32 StringRef Name, IRBuilderBase &B,
33 DomTreeUpdater &DTU, Loop *L, LoopInfo &LI) {
34 LLVMContext &Ctx = Preheader->getContext();
35 BasicBlock *Header = BasicBlock::Create(
36 Context&: Preheader->getContext(), Name: Name + ".header", Parent: Preheader->getParent(), InsertBefore: Exit);
37 BasicBlock *Body = BasicBlock::Create(Context&: Header->getContext(), Name: Name + ".body",
38 Parent: Header->getParent(), InsertBefore: Exit);
39 BasicBlock *Latch = BasicBlock::Create(Context&: Header->getContext(), Name: Name + ".latch",
40 Parent: Header->getParent(), InsertBefore: Exit);
41
42 Type *I32Ty = Type::getInt64Ty(C&: Ctx);
43 UncondBrInst::Create(IfTrue: Body, InsertBefore: Header);
44 UncondBrInst::Create(IfTrue: Latch, InsertBefore: Body);
45 PHINode *IV =
46 PHINode::Create(Ty: I32Ty, NumReservedValues: 2, NameStr: Name + ".iv", InsertBefore: Header->getTerminator()->getIterator());
47 IV->addIncoming(V: ConstantInt::get(Ty: I32Ty, V: 0), BB: Preheader);
48
49 B.SetInsertPoint(Latch);
50 Value *Inc = B.CreateAdd(LHS: IV, RHS: Step, Name: Name + ".step");
51 Value *Cond = B.CreateICmpNE(LHS: Inc, RHS: Bound, Name: Name + ".cond");
52 auto *BR = B.CreateCondBr(Cond, True: Header, False: Exit);
53 if (!ProfcheckDisableMetadataFixes) {
54 assert(Step->getZExtValue() != 0 &&
55 "Expected a non-zero step size. This is chosen by the pass and "
56 "should always be non-zero to imply a finite loop.");
57 MDBuilder MDB(Preheader->getContext());
58 setFittedBranchWeights(
59 I&: *BR, Weights: {Bound->getZExtValue() / Step->getZExtValue(), 1}, IsExpected: false);
60 }
61 IV->addIncoming(V: Inc, BB: Latch);
62
63 UncondBrInst *PreheaderBr = cast<UncondBrInst>(Val: Preheader->getTerminator());
64 BasicBlock *Tmp = PreheaderBr->getSuccessor();
65 PreheaderBr->setSuccessor(idx: 0, NewSucc: Header);
66 DTU.applyUpdatesPermissive(Updates: {
67 {DominatorTree::Delete, Preheader, Tmp},
68 {DominatorTree::Insert, Header, Body},
69 {DominatorTree::Insert, Body, Latch},
70 {DominatorTree::Insert, Latch, Header},
71 {DominatorTree::Insert, Latch, Exit},
72 {DominatorTree::Insert, Preheader, Header},
73 });
74
75 L->addBasicBlockToLoop(NewBB: Header, LI);
76 L->addBasicBlockToLoop(NewBB: Body, LI);
77 L->addBasicBlockToLoop(NewBB: Latch, LI);
78 return Body;
79}
80
81// Creates the following loop nest skeleton:
82// for C = 0; C < NumColumns; C += TileSize
83// for R = 0; R < NumRows; R += TileSize
84// for K = 0; K < Inner ; K += TileSize
85BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
86 IRBuilderBase &B, DomTreeUpdater &DTU,
87 LoopInfo &LI) {
88 Loop *ColumnLoopInfo = LI.AllocateLoop();
89 Loop *RowLoopInfo = LI.AllocateLoop();
90 Loop *KLoopInfo = LI.AllocateLoop();
91 RowLoopInfo->addChildLoop(NewChild: KLoopInfo);
92 ColumnLoopInfo->addChildLoop(NewChild: RowLoopInfo);
93 if (Loop *ParentL = LI.getLoopFor(BB: Start))
94 ParentL->addChildLoop(NewChild: ColumnLoopInfo);
95 else
96 LI.addTopLevelLoop(New: ColumnLoopInfo);
97
98 BasicBlock *ColBody =
99 CreateLoop(Preheader: Start, Exit: End, Bound: B.getInt64(C: NumColumns), Step: B.getInt64(C: TileSize),
100 Name: "cols", B, DTU, L: ColumnLoopInfo, LI);
101 ColumnLoop.Latch = ColBody->getSingleSuccessor();
102 BasicBlock *RowBody =
103 CreateLoop(Preheader: ColBody, Exit: ColumnLoop.Latch, Bound: B.getInt64(C: NumRows),
104 Step: B.getInt64(C: TileSize), Name: "rows", B, DTU, L: RowLoopInfo, LI);
105 RowLoop.Latch = RowBody->getSingleSuccessor();
106
107 BasicBlock *InnerBody =
108 CreateLoop(Preheader: RowBody, Exit: RowLoop.Latch, Bound: B.getInt64(C: NumInner),
109 Step: B.getInt64(C: TileSize), Name: "inner", B, DTU, L: KLoopInfo, LI);
110 KLoop.Latch = InnerBody->getSingleSuccessor();
111 ColumnLoop.Header = ColBody->getSinglePredecessor();
112 RowLoop.Header = RowBody->getSinglePredecessor();
113 KLoop.Header = InnerBody->getSinglePredecessor();
114 RowLoop.Index = &*RowLoop.Header->begin();
115 ColumnLoop.Index = &*ColumnLoop.Header->begin();
116 KLoop.Index = &*KLoop.Header->begin();
117
118 return InnerBody;
119}
120