1//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled, but it only performs the transformation when
/// compiling at -O0 or when the function carries the optnone attribute. In
/// those cases the definitions of the shapes feeding the amx intrinsics stay
/// close to the intrinsics themselves, so we are unable to find a point that
/// post-dominates all the shape definitions while dominating all the amx
/// intrinsics. To break the dependency on the shapes, we transform the amx
/// intrinsics into scalar operations so that compilation does not fail. In
/// the long term, we should improve the fast register allocator so that it
/// can allocate amx registers directly.
17//===----------------------------------------------------------------------===//
18//
19#include "X86.h"
20#include "llvm/Analysis/DomTreeUpdater.h"
21#include "llvm/Analysis/LoopInfo.h"
22#include "llvm/Analysis/TargetTransformInfo.h"
23#include "llvm/CodeGen/Passes.h"
24#include "llvm/CodeGen/TargetPassConfig.h"
25#include "llvm/CodeGen/ValueTypes.h"
26#include "llvm/IR/Analysis.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/Dominators.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/Instructions.h"
32#include "llvm/IR/IntrinsicInst.h"
33#include "llvm/IR/IntrinsicsX86.h"
34#include "llvm/IR/MDBuilder.h"
35#include "llvm/IR/PassManager.h"
36#include "llvm/IR/PatternMatch.h"
37#include "llvm/IR/ProfDataUtils.h"
38#include "llvm/InitializePasses.h"
39#include "llvm/Pass.h"
40#include "llvm/Support/CommandLine.h"
41#include "llvm/Target/TargetMachine.h"
42#include "llvm/Transforms/Utils/BasicBlockUtils.h"
43#include "llvm/Transforms/Utils/LoopUtils.h"
44
45using namespace llvm;
46using namespace PatternMatch;
47
48namespace llvm {
49extern cl::opt<bool> ProfcheckDisableMetadataFixes;
50} // end namespace llvm
51
52#define DEBUG_TYPE "x86-lower-amx-intrinsics"
53
54#ifndef NDEBUG
55static bool isV256I32Ty(Type *Ty) {
56 if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
57 return FVT->getNumElements() == 256 &&
58 FVT->getElementType()->isIntegerTy(32);
59 return false;
60}
61#endif
62
// Command-line switch gating the whole pass; off by default, so AMX
// scalarization must be requested explicitly with -enable-x86-scalar-amx.
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(Val: false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarizition."));
66
namespace {
// Scalarizes the supported AMX intrinsics within one function by expanding
// each of them into explicit row/column loops that operate on <256 x i32>
// vectors instead of x86amx tiles.
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  // Walks Func, collects the AMX intrinsics and lowers each one.
  // Returns true if anything was changed.
  bool visit();

private:
  DomTreeUpdater &DTU; // Keeps the dominator tree in sync with new blocks.
  LoopInfo *LI;        // May be null; updated with the loops we create.
  // Builds a header/body/latch loop skeleton between Preheader and Exit with
  // an i16 induction variable running from 0 to Bound in steps of Step.
  // Returns the loop body block.
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         ConstantInt *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  // Emits the row/col loop nest for tileload (IsTileLoad) or tilestore.
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  // Emits the row/col/inner loop nest for the tile dot-product intrinsics.
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  // Lowers a tileloadd64/tilestored64 intrinsic; always returns true.
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  // Lowers one of the tdpb* dot-product intrinsics; always returns true.
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  // Lowers a tilezero intrinsic; always returns true.
  bool lowerTileZero(Instruction *TileZero);
};
} // anonymous namespace
109
110BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
111 BasicBlock *Exit, Value *Bound,
112 ConstantInt *Step, StringRef Name,
113 IRBuilderBase &B, Loop *L) {
114 LLVMContext &Ctx = Preheader->getContext();
115 BasicBlock *Header =
116 BasicBlock::Create(Context&: Ctx, Name: Name + ".header", Parent: Preheader->getParent(), InsertBefore: Exit);
117 BasicBlock *Body =
118 BasicBlock::Create(Context&: Ctx, Name: Name + ".body", Parent: Header->getParent(), InsertBefore: Exit);
119 BasicBlock *Latch =
120 BasicBlock::Create(Context&: Ctx, Name: Name + ".latch", Parent: Header->getParent(), InsertBefore: Exit);
121
122 Type *I16Ty = Type::getInt16Ty(C&: Ctx);
123 BranchInst::Create(IfTrue: Body, InsertBefore: Header);
124 BranchInst::Create(IfTrue: Latch, InsertBefore: Body);
125 PHINode *IV =
126 PHINode::Create(Ty: I16Ty, NumReservedValues: 2, NameStr: Name + ".iv", InsertBefore: Header->getTerminator()->getIterator());
127 IV->addIncoming(V: ConstantInt::get(Ty: I16Ty, V: 0), BB: Preheader);
128
129 B.SetInsertPoint(Latch);
130 Value *Inc = B.CreateAdd(LHS: IV, RHS: Step, Name: Name + ".step");
131 Value *Cond = B.CreateICmpNE(LHS: Inc, RHS: Bound, Name: Name + ".cond");
132 auto *BR = BranchInst::Create(IfTrue: Header, IfFalse: Exit, Cond, InsertBefore: Latch);
133 if (!ProfcheckDisableMetadataFixes) {
134 if (auto *BoundInt = dyn_cast<ConstantInt>(Val: Bound)) {
135 assert(Step->getZExtValue() != 0 &&
136 "Expected a non-zero step size. This is chosen by the pass and "
137 "should always be non-zero to imply a finite loop.");
138 MDBuilder MDB(Preheader->getContext());
139 setFittedBranchWeights(
140 I&: *BR, Weights: {BoundInt->getZExtValue() / Step->getZExtValue(), 1}, IsExpected: false);
141 } else {
142 setExplicitlyUnknownBranchWeightsIfProfiled(I&: *BR, DEBUG_TYPE);
143 }
144 }
145 IV->addIncoming(V: Inc, BB: Latch);
146
147 BranchInst *PreheaderBr = cast<BranchInst>(Val: Preheader->getTerminator());
148 BasicBlock *Tmp = PreheaderBr->getSuccessor(i: 0);
149 PreheaderBr->setSuccessor(idx: 0, NewSucc: Header);
150 DTU.applyUpdatesPermissive(Updates: {
151 {DominatorTree::Delete, Preheader, Tmp},
152 {DominatorTree::Insert, Header, Body},
153 {DominatorTree::Insert, Body, Latch},
154 {DominatorTree::Insert, Latch, Header},
155 {DominatorTree::Insert, Latch, Exit},
156 {DominatorTree::Insert, Preheader, Header},
157 });
158 if (LI) {
159 L->addBasicBlockToLoop(NewBB: Header, LI&: *LI);
160 L->addBasicBlockToLoop(NewBB: Body, LI&: *LI);
161 L->addBasicBlockToLoop(NewBB: Latch, LI&: *LI);
162 }
163 return Body;
164}
165
// Emit a two-level (rows x cols) loop nest that scalarizes a tileload or a
// tilestore. For tileload the loops accumulate a <256 x i32> result via
// insertelement PHIs and that vector is returned; for tilestore the source
// vector (the operand of the bitcast feeding the intrinsic) is read with
// extractelement and stored element-by-element, and nullptr is returned.
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    // Register the new two-deep loop nest with LoopInfo, nesting it inside
    // whatever loop already contains Start (if any).
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(NewChild: ColLoop);
    if (Loop *ParentL = LI->getLoopFor(BB: Start))
      ParentL->addChildLoop(NewChild: RowLoop);
    else
      LI->addTopLevelLoop(New: RowLoop);
  }

  BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1),
                                   Name: IntrinName + ".scalarize.rows", B, L: RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1),
                                   Name: IntrinName + ".scalarize.cols", B, L: ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  // The induction variable is the first instruction in each loop header
  // (created by createLoop).
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: EltTy, NumElts: 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(V: CurrentRow, DestTy: Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(V: CurrentCol, DestTy: Stride->getType());
  Value *Offset =
      B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRowZExt, RHS: Stride), RHS: CurrentColZExt);
  Value *EltPtr = B.CreateGEP(Ty: EltTy, Ptr, IdxList: Offset);
  // Each tile row occupies 16 consecutive i32 lanes of the flat vector.
  Value *Idx = B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(Ty: V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi.row");
    VecCPhiRowLoop->addIncoming(V: VecZero, BB: Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi");
    VecPhi->addIncoming(V: VecCPhiRowLoop, BB: RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %ptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(Ty: EltTy, Ptr: EltPtr);
    Value *ResVec = B.CreateInsertElement(Vec: VecPhi, NewElt: Elt, Idx);
    VecPhi->addIncoming(V: ResVec, BB: ColLoopLatch);
    VecCPhiRowLoop->addIncoming(V: ResVec, BB: RowLatch);

    return ResVec;
  } else {
    // The stored tile must come in as a bitcast from <256 x i32>; read the
    // elements straight out of the original vector operand.
    auto *BitCast = cast<BitCastInst>(Val: Tile);
    Value *Vec = BitCast->getOperand(i_nocapture: 0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <16 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Val: Elt, Ptr: EltPtr);
    return nullptr;
  }
}
252
// Emit a three-level (rows x cols x inner) loop nest that scalarizes one of
// the tile dot-product intrinsics. The accumulator \p Acc, \p LHS and \p RHS
// must each be a bitcast from <256 x i32>. The integer variants widen 4 x i8
// sub-elements (sign/zero extended per intrinsic) and accumulate a 4-way
// add-reduction; the bf16 variant shuffles 2 x i16 halves into floats and
// accumulates an fadd-reduction. Returns the <256 x i32> result vector
// (NewVecD), which holds only the elements actually written by the loops.
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  // Pick the IR name prefix matching the intrinsic being lowered.
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  // Register the new three-deep loop nest with LoopInfo, nesting it inside
  // whatever loop already contains Start (if any).
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(NewChild: InnerLoop);
    RowLoop->addChildLoop(NewChild: ColLoop);
    if (Loop *ParentL = LI->getLoopFor(BB: Start))
      ParentL->addChildLoop(NewChild: RowLoop);
    else
      LI->addTopLevelLoop(New: RowLoop);
  }

  BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1),
                                   Name: IntrinName + ".scalarize.rows", B, L: RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1),
                                   Name: IntrinName + ".scalarize.cols", B, L: ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(Preheader: ColBody, Exit: ColLoopLatch, Bound: K, Step: B.getInt16(C: 1),
                 Name: IntrinName + ".scalarize.inner", B, L: InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  // The induction variable is the first instruction in each loop header
  // (created by createLoop).
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 256);
  auto *BitCastAcc = cast<BitCastInst>(Val: Acc);
  Value *VecC = BitCastAcc->getOperand(i_nocapture: 0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(Val: LHS);
  Value *VecA = BitCastLHS->getOperand(i_nocapture: 0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(Val: RHS);
  Value *VecB = BitCastRHS->getOperand(i_nocapture: 0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(V: VecC, BB: Start);
  Value *VecZero = Constant::getNullValue(Ty: V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(V: VecZero, BB: Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(V: VecCPhiRowLoop, BB: RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(V: VecDPhiRowLoop, BB: RowBody);
  Value *IdxC =
      B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.inner.phi");
  VecCPhi->addIncoming(V: VecCPhiColLoop, BB: ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentInner);
  Value *IdxB =
      B.CreateAdd(LHS: B.CreateMul(LHS: CurrentInner, RHS: B.getInt16(C: 16)), RHS: CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
    // %neweltc = add i32 %elt, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(ElementType: B.getInt8Ty(), NumElts: 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 4);
    Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC);
    Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA);
    Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V4I8Ty);
    Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB);
    Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    // The s/u letters in the intrinsic name dictate whether each of the two
    // operands is sign- or zero-extended before multiplication.
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty);
      SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty);
      SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty);
      SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty);
      SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(Src: B.CreateMul(LHS: SEXTSubVecA, RHS: SEXTSubVecB));
    Value *ResElt = B.CreateAdd(LHS: EltC, RHS: SubVecR);
    NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
    // i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(ElementType: B.getInt16Ty(), NumElts: 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(ElementType: B.getFloatTy(), NumElts: 2);
    Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC);
    Value *EltCF32 = B.CreateBitCast(V: EltC, DestTy: B.getFloatTy());
    Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA);
    Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V2I16Ty);
    Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB);
    Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(Ty: V2I16Ty);
    // The shuffle widens each bf16 half into the high 16 bits of an i32 lane
    // (low bits zero), which bitcasts to the corresponding float value.
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        V: B.CreateShuffleVector(V1: SubVecA, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        V: B.CreateShuffleVector(V1: SubVecB, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(Acc: EltCF32, Src: B.CreateFMul(L: AV2F32, R: BV2F32));
    Value *ResElt = B.CreateBitCast(V: SubVecR, DestTy: B.getInt32Ty());
    NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator())
;
  Value *NewEltC = B.CreateExtractElement(Vec: NewVecC, Idx: IdxC);
  Value *NewVecD = B.CreateInsertElement(Vec: VecDPhiColLoop, NewElt: NewEltC, Idx: IdxC);

  // Close the PHI cycles for the C (accumulator) and D (result) vectors.
  VecCPhi->addIncoming(V: NewVecC, BB: InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(V: NewVecC, BB: RowLatch);
  VecCPhiColLoop->addIncoming(V: NewVecC, BB: ColLoopLatch);
  VecDPhiRowLoop->addIncoming(V: NewVecD, BB: RowLatch);
  VecDPhiColLoop->addIncoming(V: NewVecD, BB: ColLoopLatch);

  return NewVecD;
}
486
// Lower one tile dot-product intrinsic: split the block at the intrinsic,
// emit the scalarized loop nest between the two halves, then rewire users.
// Bitcast users are redirected straight to the <256 x i32> result; any
// remaining users get a freshly inserted bitcast back to x86amx.
// Always returns true (the function is changed unconditionally).
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(V&: M), m_Value(V&: N), m_Value(V&: K),
                                    m_Value(V&: C), m_Value(V&: A), m_Value(V&: B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2));
  Value *KDWord = PreBuilder.CreateLShr(LHS: K, RHS: PreBuilder.getInt16(C: 2));
  // The loop nest is emitted between Start and the new "continue" block.
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // we cannot assume there always be bitcast after tiledpbssd. So we need to
  // insert one bitcast as required
  Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt());
  Value *ResAMX =
      Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext()));
  // Delete TileDP intrinsic and do some clean-up.
  for (Use &U : llvm::make_early_inc_range(Range: TileDP->uses())) {
    Instruction *I = cast<Instruction>(Val: U.getUser());
    Value *Vec;
    if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) {
      I->replaceAllUsesWith(V: ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(V: ResAMX);
  TileDP->eraseFromParent();
  return true;
}
530
// Lower a tileloadd64 (IsTileLoad) or tilestored64 intrinsic: split the block
// at the intrinsic and emit the scalarized row/col loops between the halves.
// For a load, bitcast users are redirected to the <256 x i32> result and the
// remaining users get a fresh bitcast back to x86amx; a store produces no
// value. Always returns true (the function is changed unconditionally).
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(V: TileLoadStore,
          P: m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr), Op3: m_Value(V&: Stride)));
  else
    match(V: TileLoadStore, P: m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                              Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr),
                              Op3: m_Value(V&: Stride), Op4: m_Value(V&: Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  // Columns and stride are byte counts; the loops index dwords, so shift
  // both right by two.
  Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2));
  Value *StrideDWord = PreBuilder.CreateLShr(LHS: Stride, RHS: PreBuilder.getInt64(C: 2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // we cannot assume there always be bitcast after tileload. So we need to
    // insert one bitcast as required
    Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt());
    Value *ResAMX =
        Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext()));
    // Delete tileloadd6 intrinsic and do some clean-up
    for (Use &U : llvm::make_early_inc_range(Range: TileLoadStore->uses())) {
      Instruction *I = cast<Instruction>(Val: U.getUser());
      Value *Vec;
      if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) {
        I->replaceAllUsesWith(V: ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(V: ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}
575
576bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
577 IRBuilder<> Builder(TileZero);
578 FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: Builder.getInt32Ty(), NumElts: 256);
579 Value *VecZero = Constant::getNullValue(Ty: V256I32Ty);
580 for (Use &U : llvm::make_early_inc_range(Range: TileZero->uses())) {
581 Instruction *I = cast<Instruction>(Val: U.getUser());
582 Value *Vec;
583 if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) {
584 I->replaceAllUsesWith(V: VecZero);
585 I->eraseFromParent();
586 }
587 }
588 TileZero->eraseFromParent();
589 return true;
590}
591
// Walk the function, collect all supported AMX intrinsics, then lower each
// one. The collection happens before any lowering so that the traversal is
// not disturbed by the block splitting the lowering performs.
// Returns true if anything was changed.
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(G: &Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(Val: &*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Elt: Inst);
          break;
        default:
          break;
        }
      }
    }
  }

  // Dispatch each collected intrinsic to the matching lowering routine.
  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(TileDP: Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(TileDP: Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(TileDP: Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(TileDP: Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(TileDP: Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(TileLoadStore: Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(TileLoadStore: Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(TileZero: Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }

  return C;
}
649
650namespace {
651bool shouldRunLowerAMXIntrinsics(const Function &F, const TargetMachine *TM) {
652 return X86ScalarizeAMX && (F.hasFnAttribute(Kind: Attribute::OptimizeNone) ||
653 TM->getOptLevel() == CodeGenOptLevel::None);
654}
655
656bool runLowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI) {
657 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
658
659 X86LowerAMXIntrinsics LAT(F, DTU, LI);
660 return LAT.visit();
661}
662} // namespace
663
664PreservedAnalyses X86LowerAMXIntrinsicsPass::run(Function &F,
665 FunctionAnalysisManager &FAM) {
666 if (!shouldRunLowerAMXIntrinsics(F, TM))
667 return PreservedAnalyses::all();
668
669 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
670 LoopInfo &LI = FAM.getResult<LoopAnalysis>(IR&: F);
671 bool Changed = runLowerAMXIntrinsics(F, DT: &DT, LI: &LI);
672 if (!Changed)
673 return PreservedAnalyses::all();
674
675 PreservedAnalyses PA = PreservedAnalyses::none();
676 PA.preserve<DominatorTreeAnalysis>();
677 PA.preserve<LoopAnalysis>();
678 return PA;
679}
680
namespace {
// Legacy pass-manager wrapper around the AMX intrinsic lowering.
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (!shouldRunLowerAMXIntrinsics(F, TM))
      return false;

    // DT and LI are optional here: the lowering tolerates null for either,
    // so only use them if another pass already computed them.
    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    return runLowerAMXIntrinsics(F, DT, LI);
  }
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Both analyses are kept up to date during lowering.
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};
} // namespace
708
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

// Factory used when building the legacy codegen pipeline.
FunctionPass *llvm::createX86LowerAMXIntrinsicsLegacyPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
720