1//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utilities for generating tiled loops for matrix operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Utils/MatrixUtils.h"
14#include "llvm/Analysis/DomTreeUpdater.h"
15#include "llvm/Analysis/LoopInfo.h"
16#include "llvm/IR/BasicBlock.h"
17#include "llvm/IR/Dominators.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/Type.h"
20
21using namespace llvm;
22
23BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24 Value *Bound, Value *Step, StringRef Name,
25 IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26 LoopInfo &LI) {
27 LLVMContext &Ctx = Preheader->getContext();
28 BasicBlock *Header = BasicBlock::Create(
29 Context&: Preheader->getContext(), Name: Name + ".header", Parent: Preheader->getParent(), InsertBefore: Exit);
30 BasicBlock *Body = BasicBlock::Create(Context&: Header->getContext(), Name: Name + ".body",
31 Parent: Header->getParent(), InsertBefore: Exit);
32 BasicBlock *Latch = BasicBlock::Create(Context&: Header->getContext(), Name: Name + ".latch",
33 Parent: Header->getParent(), InsertBefore: Exit);
34
35 Type *I32Ty = Type::getInt64Ty(C&: Ctx);
36 BranchInst::Create(IfTrue: Body, InsertBefore: Header);
37 BranchInst::Create(IfTrue: Latch, InsertBefore: Body);
38 PHINode *IV =
39 PHINode::Create(Ty: I32Ty, NumReservedValues: 2, NameStr: Name + ".iv", InsertBefore: Header->getTerminator()->getIterator());
40 IV->addIncoming(V: ConstantInt::get(Ty: I32Ty, V: 0), BB: Preheader);
41
42 B.SetInsertPoint(Latch);
43 Value *Inc = B.CreateAdd(LHS: IV, RHS: Step, Name: Name + ".step");
44 Value *Cond = B.CreateICmpNE(LHS: Inc, RHS: Bound, Name: Name + ".cond");
45 BranchInst::Create(IfTrue: Header, IfFalse: Exit, Cond, InsertBefore: Latch);
46 IV->addIncoming(V: Inc, BB: Latch);
47
48 BranchInst *PreheaderBr = cast<BranchInst>(Val: Preheader->getTerminator());
49 BasicBlock *Tmp = PreheaderBr->getSuccessor(i: 0);
50 PreheaderBr->setSuccessor(idx: 0, NewSucc: Header);
51 DTU.applyUpdatesPermissive(Updates: {
52 {DominatorTree::Delete, Preheader, Tmp},
53 {DominatorTree::Insert, Header, Body},
54 {DominatorTree::Insert, Body, Latch},
55 {DominatorTree::Insert, Latch, Header},
56 {DominatorTree::Insert, Latch, Exit},
57 {DominatorTree::Insert, Preheader, Header},
58 });
59
60 L->addBasicBlockToLoop(NewBB: Header, LI);
61 L->addBasicBlockToLoop(NewBB: Body, LI);
62 L->addBasicBlockToLoop(NewBB: Latch, LI);
63 return Body;
64}
65
66// Creates the following loop nest skeleton:
67// for C = 0; C < NumColumns; C += TileSize
68// for R = 0; R < NumRows; R += TileSize
69// for K = 0; K < Inner ; K += TileSize
70BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71 IRBuilderBase &B, DomTreeUpdater &DTU,
72 LoopInfo &LI) {
73 Loop *ColumnLoopInfo = LI.AllocateLoop();
74 Loop *RowLoopInfo = LI.AllocateLoop();
75 Loop *KLoopInfo = LI.AllocateLoop();
76 RowLoopInfo->addChildLoop(NewChild: KLoopInfo);
77 ColumnLoopInfo->addChildLoop(NewChild: RowLoopInfo);
78 if (Loop *ParentL = LI.getLoopFor(BB: Start))
79 ParentL->addChildLoop(NewChild: ColumnLoopInfo);
80 else
81 LI.addTopLevelLoop(New: ColumnLoopInfo);
82
83 BasicBlock *ColBody =
84 CreateLoop(Preheader: Start, Exit: End, Bound: B.getInt64(C: NumColumns), Step: B.getInt64(C: TileSize),
85 Name: "cols", B, DTU, L: ColumnLoopInfo, LI);
86 ColumnLoop.Latch = ColBody->getSingleSuccessor();
87 BasicBlock *RowBody =
88 CreateLoop(Preheader: ColBody, Exit: ColumnLoop.Latch, Bound: B.getInt64(C: NumRows),
89 Step: B.getInt64(C: TileSize), Name: "rows", B, DTU, L: RowLoopInfo, LI);
90 RowLoop.Latch = RowBody->getSingleSuccessor();
91
92 BasicBlock *InnerBody =
93 CreateLoop(Preheader: RowBody, Exit: RowLoop.Latch, Bound: B.getInt64(C: NumInner),
94 Step: B.getInt64(C: TileSize), Name: "inner", B, DTU, L: KLoopInfo, LI);
95 KLoop.Latch = InnerBody->getSingleSuccessor();
96 ColumnLoop.Header = ColBody->getSinglePredecessor();
97 RowLoop.Header = RowBody->getSinglePredecessor();
98 KLoop.Header = InnerBody->getSinglePredecessor();
99 RowLoop.Index = &*RowLoop.Header->begin();
100 ColumnLoop.Index = &*ColumnLoop.Header->begin();
101 KLoop.Index = &*KLoop.Header->begin();
102
103 return InnerBody;
104}
105