1//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to transform AMX intrinsics to scalar operations.
/// This pass is always enabled, but it skips any function that is neither
/// compiled at -O0 nor marked with the optnone attribute. With -O0 or the
/// optnone attribute, the definition of the shape used by an AMX intrinsic is
/// close to the intrinsic itself, so we are not able to find a point that
/// post-dominates all the shape definitions and dominates all AMX intrinsics.
/// To decouple the dependency on the shape, we transform AMX intrinsics into
/// scalar operations so that compilation doesn't fail. In the long term, we
/// should improve fast register allocation to allocate AMX registers.
17//===----------------------------------------------------------------------===//
18//
19#include "X86.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/Analysis/DomTreeUpdater.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/Analysis/OptimizationRemarkEmitter.h"
24#include "llvm/Analysis/TargetTransformInfo.h"
25#include "llvm/CodeGen/Passes.h"
26#include "llvm/CodeGen/TargetPassConfig.h"
27#include "llvm/CodeGen/ValueTypes.h"
28#include "llvm/IR/DataLayout.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/Instructions.h"
32#include "llvm/IR/IntrinsicInst.h"
33#include "llvm/IR/IntrinsicsX86.h"
34#include "llvm/IR/PatternMatch.h"
35#include "llvm/InitializePasses.h"
36#include "llvm/Pass.h"
37#include "llvm/Support/CommandLine.h"
38#include "llvm/Target/TargetMachine.h"
39#include "llvm/Transforms/Utils/BasicBlockUtils.h"
40#include "llvm/Transforms/Utils/LoopUtils.h"
41
42using namespace llvm;
43using namespace PatternMatch;
44
45#define DEBUG_TYPE "lower-amx-intrinsics"
46
47#ifndef NDEBUG
48static bool isV256I32Ty(Type *Ty) {
49 if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
50 return FVT->getNumElements() == 256 &&
51 FVT->getElementType()->isIntegerTy(32);
52 return false;
53}
54#endif
55
56static cl::opt<bool>
57 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(Val: false), cl::Hidden,
58 cl::desc("X86: enable AMX scalarizition."));
59
60namespace {
61class X86LowerAMXIntrinsics {
62 Function &Func;
63
64public:
65 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
66 : Func(F), DTU(DomTU), LI(LoopI) {}
67 bool visit();
68
69private:
70 DomTreeUpdater &DTU;
71 LoopInfo *LI;
72 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
73 Value *Step, StringRef Name, IRBuilderBase &B,
74 Loop *L);
75 template <bool IsTileLoad>
76 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
77 IRBuilderBase &B, Value *Row, Value *Col,
78 Value *Ptr, Value *Stride, Value *Tile);
79 template <Intrinsic::ID IntrID>
80 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
81 IntrID == Intrinsic::x86_tdpbsud_internal ||
82 IntrID == Intrinsic::x86_tdpbusd_internal ||
83 IntrID == Intrinsic::x86_tdpbuud_internal ||
84 IntrID == Intrinsic::x86_tdpbf16ps_internal,
85 Value *>
86 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
87 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
88 Value *RHS);
89 template <bool IsTileLoad>
90 bool lowerTileLoadStore(Instruction *TileLoadStore);
91 template <Intrinsic::ID IntrID>
92 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
93 IntrID == Intrinsic::x86_tdpbsud_internal ||
94 IntrID == Intrinsic::x86_tdpbusd_internal ||
95 IntrID == Intrinsic::x86_tdpbuud_internal ||
96 IntrID == Intrinsic::x86_tdpbf16ps_internal,
97 bool>
98 lowerTileDP(Instruction *TileDP);
99 bool lowerTileZero(Instruction *TileZero);
100};
101} // anonymous namespace
102
103BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
104 BasicBlock *Exit, Value *Bound,
105 Value *Step, StringRef Name,
106 IRBuilderBase &B, Loop *L) {
107 LLVMContext &Ctx = Preheader->getContext();
108 BasicBlock *Header =
109 BasicBlock::Create(Context&: Ctx, Name: Name + ".header", Parent: Preheader->getParent(), InsertBefore: Exit);
110 BasicBlock *Body =
111 BasicBlock::Create(Context&: Ctx, Name: Name + ".body", Parent: Header->getParent(), InsertBefore: Exit);
112 BasicBlock *Latch =
113 BasicBlock::Create(Context&: Ctx, Name: Name + ".latch", Parent: Header->getParent(), InsertBefore: Exit);
114
115 Type *I16Ty = Type::getInt16Ty(C&: Ctx);
116 BranchInst::Create(IfTrue: Body, InsertBefore: Header);
117 BranchInst::Create(IfTrue: Latch, InsertBefore: Body);
118 PHINode *IV =
119 PHINode::Create(Ty: I16Ty, NumReservedValues: 2, NameStr: Name + ".iv", InsertBefore: Header->getTerminator()->getIterator());
120 IV->addIncoming(V: ConstantInt::get(Ty: I16Ty, V: 0), BB: Preheader);
121
122 B.SetInsertPoint(Latch);
123 Value *Inc = B.CreateAdd(LHS: IV, RHS: Step, Name: Name + ".step");
124 Value *Cond = B.CreateICmpNE(LHS: Inc, RHS: Bound, Name: Name + ".cond");
125 BranchInst::Create(IfTrue: Header, IfFalse: Exit, Cond, InsertBefore: Latch);
126 IV->addIncoming(V: Inc, BB: Latch);
127
128 BranchInst *PreheaderBr = cast<BranchInst>(Val: Preheader->getTerminator());
129 BasicBlock *Tmp = PreheaderBr->getSuccessor(i: 0);
130 PreheaderBr->setSuccessor(idx: 0, NewSucc: Header);
131 DTU.applyUpdatesPermissive(Updates: {
132 {DominatorTree::Delete, Preheader, Tmp},
133 {DominatorTree::Insert, Header, Body},
134 {DominatorTree::Insert, Body, Latch},
135 {DominatorTree::Insert, Latch, Header},
136 {DominatorTree::Insert, Latch, Exit},
137 {DominatorTree::Insert, Preheader, Header},
138 });
139 if (LI) {
140 L->addBasicBlockToLoop(NewBB: Header, LI&: *LI);
141 L->addBasicBlockToLoop(NewBB: Body, LI&: *LI);
142 L->addBasicBlockToLoop(NewBB: Latch, LI&: *LI);
143 }
144 return Body;
145}
146
147template <bool IsTileLoad>
148Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
149 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
150 Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
151 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
152 Loop *RowLoop = nullptr;
153 Loop *ColLoop = nullptr;
154 if (LI) {
155 RowLoop = LI->AllocateLoop();
156 ColLoop = LI->AllocateLoop();
157 RowLoop->addChildLoop(NewChild: ColLoop);
158 if (Loop *ParentL = LI->getLoopFor(BB: Start))
159 ParentL->addChildLoop(NewChild: RowLoop);
160 else
161 LI->addTopLevelLoop(New: RowLoop);
162 }
163
164 BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1),
165 Name: IntrinName + ".scalarize.rows", B, L: RowLoop);
166 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
167
168 BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1),
169 Name: IntrinName + ".scalarize.cols", B, L: ColLoop);
170
171 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
172 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
173 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
174 Value *CurrentRow = &*RowLoopHeader->begin();
175 Value *CurrentCol = &*ColLoopHeader->begin();
176 Type *EltTy = B.getInt32Ty();
177 FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: EltTy, NumElts: 256);
178
179 // Common part for tileload and tilestore
180 // *.scalarize.cols.body:
181 // Calculate %idxmem and %idxvec
182 B.SetInsertPoint(ColBody->getTerminator());
183 Value *CurrentRowZExt = B.CreateZExt(V: CurrentRow, DestTy: Stride->getType());
184 Value *CurrentColZExt = B.CreateZExt(V: CurrentCol, DestTy: Stride->getType());
185 Value *Offset =
186 B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRowZExt, RHS: Stride), RHS: CurrentColZExt);
187 Value *EltPtr = B.CreateGEP(Ty: EltTy, Ptr, IdxList: Offset);
188 Value *Idx = B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol);
189 if (IsTileLoad) {
190 // tileload.scalarize.rows.header:
191 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
192 // %tileload.scalarize.rows.latch ]
193 B.SetInsertPoint(RowLoopHeader->getTerminator());
194 Value *VecZero = Constant::getNullValue(Ty: V256I32Ty);
195 PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi.row");
196 VecCPhiRowLoop->addIncoming(V: VecZero, BB: Start);
197
198 // tileload.scalarize.cols.header:
199 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
200 // ], [ %ResVec, %tileload.scalarize.cols.latch ]
201 B.SetInsertPoint(ColLoopHeader->getTerminator());
202 PHINode *VecPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi");
203 VecPhi->addIncoming(V: VecCPhiRowLoop, BB: RowBody);
204
205 // tileload.scalarize.cols.body:
206 // Calculate %idxmem and %idxvec
207 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
208 // %elt = load i32, i32* %ptr
209 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
210 B.SetInsertPoint(ColBody->getTerminator());
211 Value *Elt = B.CreateLoad(Ty: EltTy, Ptr: EltPtr);
212 Value *ResVec = B.CreateInsertElement(Vec: VecPhi, NewElt: Elt, Idx);
213 VecPhi->addIncoming(V: ResVec, BB: ColLoopLatch);
214 VecCPhiRowLoop->addIncoming(V: ResVec, BB: RowLatch);
215
216 return ResVec;
217 } else {
218 auto *BitCast = cast<BitCastInst>(Val: Tile);
219 Value *Vec = BitCast->getOperand(i_nocapture: 0);
220 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
221 // tilestore.scalarize.cols.body:
222 // %mul = mul i16 %row.iv, i16 16
223 // %idx = add i16 %mul, i16 %col.iv
224 // %vec = extractelement <16 x i32> %vec, i16 %idx
225 // store i32 %vec, i32* %ptr
226 B.SetInsertPoint(ColBody->getTerminator());
227 Value *Elt = B.CreateExtractElement(Vec, Idx);
228
229 B.CreateStore(Val: Elt, Ptr: EltPtr);
230 return nullptr;
231 }
232}
233
234template <Intrinsic::ID IntrID>
235std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
236 IntrID == Intrinsic::x86_tdpbsud_internal ||
237 IntrID == Intrinsic::x86_tdpbusd_internal ||
238 IntrID == Intrinsic::x86_tdpbuud_internal ||
239 IntrID == Intrinsic::x86_tdpbf16ps_internal,
240 Value *>
241X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
242 IRBuilderBase &B, Value *Row,
243 Value *Col, Value *K, Value *Acc,
244 Value *LHS, Value *RHS) {
245 std::string IntrinName;
246 switch (IntrID) {
247 case Intrinsic::x86_tdpbssd_internal:
248 IntrinName = "tiledpbssd";
249 break;
250 case Intrinsic::x86_tdpbsud_internal:
251 IntrinName = "tiledpbsud";
252 break;
253 case Intrinsic::x86_tdpbusd_internal:
254 IntrinName = "tiledpbusd";
255 break;
256 case Intrinsic::x86_tdpbuud_internal:
257 IntrinName = "tiledpbuud";
258 break;
259 case Intrinsic::x86_tdpbf16ps_internal:
260 IntrinName = "tiledpbf16ps";
261 break;
262 }
263 Loop *RowLoop = nullptr;
264 Loop *ColLoop = nullptr;
265 Loop *InnerLoop = nullptr;
266 if (LI) {
267 RowLoop = LI->AllocateLoop();
268 ColLoop = LI->AllocateLoop();
269 InnerLoop = LI->AllocateLoop();
270 ColLoop->addChildLoop(NewChild: InnerLoop);
271 RowLoop->addChildLoop(NewChild: ColLoop);
272 if (Loop *ParentL = LI->getLoopFor(BB: Start))
273 ParentL->addChildLoop(NewChild: RowLoop);
274 else
275 LI->addTopLevelLoop(New: RowLoop);
276 }
277
278 BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1),
279 Name: IntrinName + ".scalarize.rows", B, L: RowLoop);
280 BasicBlock *RowLatch = RowBody->getSingleSuccessor();
281
282 BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1),
283 Name: IntrinName + ".scalarize.cols", B, L: ColLoop);
284
285 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
286
287 B.SetInsertPoint(ColBody->getTerminator());
288 BasicBlock *InnerBody =
289 createLoop(Preheader: ColBody, Exit: ColLoopLatch, Bound: K, Step: B.getInt16(C: 1),
290 Name: IntrinName + ".scalarize.inner", B, L: InnerLoop);
291
292 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
293 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
294 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
295 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
296 Value *CurrentRow = &*RowLoopHeader->begin();
297 Value *CurrentCol = &*ColLoopHeader->begin();
298 Value *CurrentInner = &*InnerLoopHeader->begin();
299
300 FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 256);
301 auto *BitCastAcc = cast<BitCastInst>(Val: Acc);
302 Value *VecC = BitCastAcc->getOperand(i_nocapture: 0);
303 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
304 // TODO else create BitCast from x86amx to v256i32.
305 // Store x86amx to memory, and reload from memory
306 // to vector. However with -O0, it doesn't happen.
307 auto *BitCastLHS = cast<BitCastInst>(Val: LHS);
308 Value *VecA = BitCastLHS->getOperand(i_nocapture: 0);
309 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
310 auto *BitCastRHS = cast<BitCastInst>(Val: RHS);
311 Value *VecB = BitCastRHS->getOperand(i_nocapture: 0);
312 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
313
314 // tiledpbssd.scalarize.rows.header:
315 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
316 // %tiledpbssd.scalarize.rows.latch ]
317
318 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
319 // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
320 B.SetInsertPoint(RowLoopHeader->getTerminator());
321 PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.row");
322 VecCPhiRowLoop->addIncoming(V: VecC, BB: Start);
323 Value *VecZero = Constant::getNullValue(Ty: V256I32Ty);
324 PHINode *VecDPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.row");
325 VecDPhiRowLoop->addIncoming(V: VecZero, BB: Start);
326
327 // tiledpbssd.scalarize.cols.header:
328 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
329 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
330 // %tiledpbssd.scalarize.cols.latch ]
331
332 // %vec.d.phi.col = phi <256 x i32> [
333 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
334 // %tiledpbssd.scalarize.cols.latch ]
335
336 // calculate idxc.
337 B.SetInsertPoint(ColLoopHeader->getTerminator());
338 PHINode *VecCPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.col");
339 VecCPhiColLoop->addIncoming(V: VecCPhiRowLoop, BB: RowBody);
340 PHINode *VecDPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.col");
341 VecDPhiColLoop->addIncoming(V: VecDPhiRowLoop, BB: RowBody);
342 Value *IdxC =
343 B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol);
344
345 // tiledpbssd.scalarize.inner.header:
346 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
347 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
348 // %tiledpbssd.scalarize.inner.latch ]
349
350 B.SetInsertPoint(InnerLoopHeader->getTerminator());
351 PHINode *VecCPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.inner.phi");
352 VecCPhi->addIncoming(V: VecCPhiColLoop, BB: ColBody);
353
354 B.SetInsertPoint(InnerBody->getTerminator());
355 Value *IdxA =
356 B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentInner);
357 Value *IdxB =
358 B.CreateAdd(LHS: B.CreateMul(LHS: CurrentInner, RHS: B.getInt16(C: 16)), RHS: CurrentCol);
359 Value *NewVecC = nullptr;
360
361 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
362 // tiledpbssd.scalarize.inner.body:
363 // calculate idxa, idxb
364 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
365 // %elta = extractelement <256 x i32> %veca, i16 %idxa
366 // %eltav4i8 = bitcast i32 %elta to <4 x i8>
367 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
368 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
369 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
370 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
371 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
372 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
373 // %neweltc = add i32 %elt, %acc
374 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
375 // i16 %idxc
376 FixedVectorType *V4I8Ty = FixedVectorType::get(ElementType: B.getInt8Ty(), NumElts: 4);
377 FixedVectorType *V4I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 4);
378 Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC);
379 Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA);
380 Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V4I8Ty);
381 Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB);
382 Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V4I8Ty);
383 Value *SEXTSubVecB = nullptr;
384 Value *SEXTSubVecA = nullptr;
385 switch (IntrID) {
386 case Intrinsic::x86_tdpbssd_internal:
387 SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty);
388 SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty);
389 break;
390 case Intrinsic::x86_tdpbsud_internal:
391 SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty);
392 SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty);
393 break;
394 case Intrinsic::x86_tdpbusd_internal:
395 SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty);
396 SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty);
397 break;
398 case Intrinsic::x86_tdpbuud_internal:
399 SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty);
400 SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty);
401 break;
402 default:
403 llvm_unreachable("Invalid intrinsic ID!");
404 }
405 Value *SubVecR = B.CreateAddReduce(Src: B.CreateMul(LHS: SEXTSubVecA, RHS: SEXTSubVecB));
406 Value *ResElt = B.CreateAdd(LHS: EltC, RHS: SubVecR);
407 NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC);
408 } else {
409 // tiledpbf16ps.scalarize.inner.body:
410 // calculate idxa, idxb, idxc
411 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
412 // %eltcf32 = bitcast i32 %eltc to float
413 // %elta = extractelement <256 x i32> %veca, i16 %idxa
414 // %eltav2i16 = bitcast i32 %elta to <2 x i16>
415 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
416 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
417 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
418 // x i32> <i32 2, i32 0, i32 3, i32 1>
419 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
420 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
421 // i32> <i32 2, i32 0, i32 3, i32 1>
422 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
423 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
424 // %acc = call float
425 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
426 // %neweltc = bitcast float %acc to i32
427 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
428 // i16 %idxc
429 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
430 // i16 %idxc
431 FixedVectorType *V2I16Ty = FixedVectorType::get(ElementType: B.getInt16Ty(), NumElts: 2);
432 FixedVectorType *V2F32Ty = FixedVectorType::get(ElementType: B.getFloatTy(), NumElts: 2);
433 Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC);
434 Value *EltCF32 = B.CreateBitCast(V: EltC, DestTy: B.getFloatTy());
435 Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA);
436 Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V2I16Ty);
437 Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB);
438 Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V2I16Ty);
439 Value *ZeroV2I16 = Constant::getNullValue(Ty: V2I16Ty);
440 int ShuffleMask[4] = {2, 0, 3, 1};
441 auto ShuffleArray = ArrayRef(ShuffleMask);
442 Value *AV2F32 = B.CreateBitCast(
443 V: B.CreateShuffleVector(V1: SubVecA, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty);
444 Value *BV2F32 = B.CreateBitCast(
445 V: B.CreateShuffleVector(V1: SubVecB, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty);
446 Value *SubVecR = B.CreateFAddReduce(Acc: EltCF32, Src: B.CreateFMul(L: AV2F32, R: BV2F32));
447 Value *ResElt = B.CreateBitCast(V: SubVecR, DestTy: B.getInt32Ty());
448 NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC);
449 }
450
451 // tiledpbssd.scalarize.cols.latch:
452 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
453 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
454 // i16 %idxc
455 B.SetInsertPoint(ColLoopLatch->getTerminator());
456 Value *NewEltC = B.CreateExtractElement(Vec: NewVecC, Idx: IdxC);
457 Value *NewVecD = B.CreateInsertElement(Vec: VecDPhiColLoop, NewElt: NewEltC, Idx: IdxC);
458
459 VecCPhi->addIncoming(V: NewVecC, BB: InnerLoopLatch);
460 VecCPhiRowLoop->addIncoming(V: NewVecC, BB: RowLatch);
461 VecCPhiColLoop->addIncoming(V: NewVecC, BB: ColLoopLatch);
462 VecDPhiRowLoop->addIncoming(V: NewVecD, BB: RowLatch);
463 VecDPhiColLoop->addIncoming(V: NewVecD, BB: ColLoopLatch);
464
465 return NewVecD;
466}
467
468template <Intrinsic::ID IntrID>
469std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
470 IntrID == Intrinsic::x86_tdpbsud_internal ||
471 IntrID == Intrinsic::x86_tdpbusd_internal ||
472 IntrID == Intrinsic::x86_tdpbuud_internal ||
473 IntrID == Intrinsic::x86_tdpbf16ps_internal,
474 bool>
475X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
476 Value *M, *N, *K, *C, *A, *B;
477 match(TileDP, m_Intrinsic<IntrID>(m_Value(V&: M), m_Value(V&: N), m_Value(V&: K),
478 m_Value(V&: C), m_Value(V&: A), m_Value(V&: B)));
479 Instruction *InsertI = TileDP;
480 IRBuilder<> PreBuilder(TileDP);
481 PreBuilder.SetInsertPoint(TileDP);
482 // We visit the loop with (m, n/4, k/4):
483 // %n_dword = lshr i16 %n, 2
484 // %k_dword = lshr i16 %k, 2
485 Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2));
486 Value *KDWord = PreBuilder.CreateLShr(LHS: K, RHS: PreBuilder.getInt16(C: 2));
487 BasicBlock *Start = InsertI->getParent();
488 BasicBlock *End =
489 SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue");
490 IRBuilder<> Builder(TileDP);
491 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
492 KDWord, C, A, B);
493 // we cannot assume there always be bitcast after tiledpbssd. So we need to
494 // insert one bitcast as required
495 Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt());
496 Value *ResAMX =
497 Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext()));
498 // Delete TileDP intrinsic and do some clean-up.
499 for (Use &U : llvm::make_early_inc_range(Range: TileDP->uses())) {
500 Instruction *I = cast<Instruction>(Val: U.getUser());
501 Value *Vec;
502 if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) {
503 I->replaceAllUsesWith(V: ResVec);
504 I->eraseFromParent();
505 }
506 }
507 TileDP->replaceAllUsesWith(V: ResAMX);
508 TileDP->eraseFromParent();
509 return true;
510}
511
512template <bool IsTileLoad>
513bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
514 Value *M, *N, *Ptr, *Stride, *Tile;
515 if (IsTileLoad)
516 match(V: TileLoadStore,
517 P: m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
518 Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr), Op3: m_Value(V&: Stride)));
519 else
520 match(V: TileLoadStore, P: m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
521 Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr),
522 Op3: m_Value(V&: Stride), Op4: m_Value(V&: Tile)));
523
524 Instruction *InsertI = TileLoadStore;
525 IRBuilder<> PreBuilder(TileLoadStore);
526 PreBuilder.SetInsertPoint(TileLoadStore);
527 Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2));
528 Value *StrideDWord = PreBuilder.CreateLShr(LHS: Stride, RHS: PreBuilder.getInt64(C: 2));
529 BasicBlock *Start = InsertI->getParent();
530 BasicBlock *End =
531 SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue");
532 IRBuilder<> Builder(TileLoadStore);
533 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
534 Start, End, Builder, M, NDWord, Ptr, StrideDWord,
535 IsTileLoad ? nullptr : Tile);
536 if (IsTileLoad) {
537 // we cannot assume there always be bitcast after tileload. So we need to
538 // insert one bitcast as required
539 Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt());
540 Value *ResAMX =
541 Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext()));
542 // Delete tileloadd6 intrinsic and do some clean-up
543 for (Use &U : llvm::make_early_inc_range(Range: TileLoadStore->uses())) {
544 Instruction *I = cast<Instruction>(Val: U.getUser());
545 Value *Vec;
546 if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) {
547 I->replaceAllUsesWith(V: ResVec);
548 I->eraseFromParent();
549 }
550 }
551 TileLoadStore->replaceAllUsesWith(V: ResAMX);
552 }
553 TileLoadStore->eraseFromParent();
554 return true;
555}
556
557bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
558 IRBuilder<> Builder(TileZero);
559 FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: Builder.getInt32Ty(), NumElts: 256);
560 Value *VecZero = Constant::getNullValue(Ty: V256I32Ty);
561 for (Use &U : llvm::make_early_inc_range(Range: TileZero->uses())) {
562 Instruction *I = cast<Instruction>(Val: U.getUser());
563 Value *Vec;
564 if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) {
565 I->replaceAllUsesWith(V: VecZero);
566 I->eraseFromParent();
567 }
568 }
569 TileZero->eraseFromParent();
570 return true;
571}
572
573bool X86LowerAMXIntrinsics::visit() {
574 bool C = false;
575 SmallVector<IntrinsicInst *, 8> WorkList;
576 for (BasicBlock *BB : depth_first(G: &Func)) {
577 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
578 if (auto *Inst = dyn_cast<IntrinsicInst>(Val: &*II++)) {
579 switch (Inst->getIntrinsicID()) {
580 case Intrinsic::x86_tdpbssd_internal:
581 case Intrinsic::x86_tdpbsud_internal:
582 case Intrinsic::x86_tdpbusd_internal:
583 case Intrinsic::x86_tdpbuud_internal:
584 case Intrinsic::x86_tileloadd64_internal:
585 case Intrinsic::x86_tilestored64_internal:
586 case Intrinsic::x86_tilezero_internal:
587 case Intrinsic::x86_tdpbf16ps_internal:
588 WorkList.push_back(Elt: Inst);
589 break;
590 default:
591 break;
592 }
593 }
594 }
595 }
596
597 for (auto *Inst : WorkList) {
598 switch (Inst->getIntrinsicID()) {
599 case Intrinsic::x86_tdpbssd_internal:
600 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(TileDP: Inst) || C;
601 break;
602 case Intrinsic::x86_tdpbsud_internal:
603 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(TileDP: Inst) || C;
604 break;
605 case Intrinsic::x86_tdpbusd_internal:
606 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(TileDP: Inst) || C;
607 break;
608 case Intrinsic::x86_tdpbuud_internal:
609 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(TileDP: Inst) || C;
610 break;
611 case Intrinsic::x86_tdpbf16ps_internal:
612 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(TileDP: Inst) || C;
613 break;
614 case Intrinsic::x86_tileloadd64_internal:
615 C = lowerTileLoadStore<true>(TileLoadStore: Inst) || C;
616 break;
617 case Intrinsic::x86_tilestored64_internal:
618 C = lowerTileLoadStore<false>(TileLoadStore: Inst) || C;
619 break;
620 case Intrinsic::x86_tilezero_internal:
621 C = lowerTileZero(TileZero: Inst) || C;
622 break;
623 default:
624 llvm_unreachable("invalid amx intrinsics!");
625 }
626 }
627
628 return C;
629}
630
631namespace {
632class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
633public:
634 static char ID;
635
636 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
637 initializeX86LowerAMXIntrinsicsLegacyPassPass(
638 *PassRegistry::getPassRegistry());
639 }
640
641 bool runOnFunction(Function &F) override {
642 if (!X86ScalarizeAMX)
643 return false;
644 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
645 if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone) &&
646 TM->getOptLevel() != CodeGenOptLevel::None)
647 return false;
648
649 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
650 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
651 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
652 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
653 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
654
655 X86LowerAMXIntrinsics LAT(F, DTU, LI);
656 return LAT.visit();
657 }
658 StringRef getPassName() const override { return "Lower AMX intrinsics"; }
659
660 void getAnalysisUsage(AnalysisUsage &AU) const override {
661 AU.addPreserved<DominatorTreeWrapperPass>();
662 AU.addPreserved<LoopInfoWrapperPass>();
663 AU.addRequired<TargetPassConfig>();
664 }
665};
666} // namespace
667
668static const char PassName[] = "Lower AMX intrinsics";
669char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
670INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
671 false, false)
672INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
673INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
674 false, false)
675
676FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
677 return new X86LowerAMXIntrinsicsLegacyPass();
678}
679