//===-- X86LowerAMXIntrinsics.cpp - X86 Scalarize AMX Intrinsics ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled, but it only rewrites functions compiled at
/// -O0 or carrying the optnone attribute. In those cases the defs of the
/// shapes sit right next to the amx intrinsics that use them, and there is no
/// program point that post-dominates all the shape defs while dominating all
/// the amx intrinsics. To decouple the dependency on the shapes, we transform
/// the amx intrinsics to scalar operations so that compilation does not fail.
/// In the long term, fast register allocation should be improved to allocate
/// amx registers directly.
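///
/// For example, a tile zeroing intrinsic such as
///   %t = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
/// becomes a plain <256 x i32> zeroinitializer, while tile loads, stores,
/// and dot-products become scalar loops over the tile elements (see the
/// lowering helpers below).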
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"

#ifndef NDEBUG
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
#endif

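// The flag is hidden; pass -enable-x86-scalar-amx to llc (or
// -mllvm -enable-x86-scalar-amx through clang) to exercise this pass.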
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));

namespace {
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  bool visit();

private:
  DomTreeUpdater &DTU;
  LoopInfo *LI;
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  bool lowerTileZero(Instruction *TileZero);
};
} // anonymous namespace

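// Build a trivial counted loop:
//
//   Preheader --> Name.header --> Name.body --> Name.latch --> Exit
//                      ^                            |
//                      +----------------------------+
//
// An i16 induction variable starts at 0 in the header, is stepped by Step in
// the latch, and the loop exits once it reaches Bound. The body block is
// returned so callers can populate it.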
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV = PHINode::Create(I16Ty, 2, Name + ".iv",
                                Header->getTerminator()->getIterator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}

template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
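  // A tile row holds at most 16 i32 elements (64 bytes), so element (row,
  // col) of the tile lives at the flat index row * 16 + col of the
  // <256 x i32> vector that models the tile.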
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %eltptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}

template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }
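
  // Build a three-deep loop nest rows x cols x inner: the two outer loops
  // walk the elements of the accumulator tile C, and the inner loop runs the
  // K dimension of the dot product C[row][col] += sum over k of
  // A[row][k] * B[k][col].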
  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
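  // IdxA addresses A[row][inner] and IdxB addresses B[inner][col], using the
  // same row-major 16-elements-per-row layout as IdxC above.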
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
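    // The middle letters of the mnemonic give the signedness of the A and B
    // byte operands respectively (s = signed, u = unsigned); extend each
    // packed byte accordingly before multiplying.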
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
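    // Mask {2,0,3,1} interleaves a zero i16 below each bf16 lane; after the
    // bitcast each bf16 value occupies the upper half of an f32, which is
    // exactly the bf16-to-f32 widening.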
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}

template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
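  // Tile shapes follow the hardware convention: N (columns) and K are counted
  // in bytes, so shifting right by 2 converts them to i32-element counts.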
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // We cannot assume there is always a bitcast after the tiledp intrinsic, so
  // insert one as required.
  Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete the TileDP intrinsic and do some clean-up.
  for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}

template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
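  // As with the dot-product lowering, N and Stride arrive in bytes; the
  // shifts convert them to i32-element counts for the scalar loops.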
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // We cannot assume there is always a bitcast after the tileload, so insert
    // one as required.
    Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
    Value *ResAMX = Builder.CreateBitCast(
        ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete the tileloadd64 intrinsic and do some clean-up.
    for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
      Instruction *I = cast<Instruction>(U.getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}

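// Lower tilezero by folding every bitcast of its x86_amx result to a
// <256 x i32> zeroinitializer; no loop is needed.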
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  IRBuilder<> Builder(TileZero);
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
    }
  }
  TileZero->eraseFromParent();
  return true;
}

bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }

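  // Lower from a separate worklist: the lowering helpers split blocks and
  // erase instructions, which would invalidate the traversal above.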
  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }

  return C;
}

namespace {
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (!X86ScalarizeAMX)
      return false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
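    // Only lower at -O0 or for optnone functions; see the file comment for
    // the rationale.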
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOptLevel::None)
      return false;

    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};
} // namespace

static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}