//===-- X86LowerAMXIntrinsics.cpp - X86 Scalarize AMX Intrinsics ----------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled, but it does nothing unless the function is
/// compiled at -O0 or carries the optnone attribute. In those cases, the defs
/// of the shapes feeding the amx intrinsics sit right next to the intrinsics
/// themselves, so we cannot find a point that post-dominates all the shape
/// defs while dominating all the amx intrinsics. To decouple this dependency
/// on the shapes, we transform the amx intrinsics to scalar operations, so
/// that compilation doesn't fail. In the long term, we should improve fast
/// register allocation to allocate amx registers.
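///
/// For example (illustrative only), zeroing a tile
///   %t = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
///   %v = bitcast x86_amx %t to <256 x i32>
/// is lowered by replacing all uses of %v with <256 x i32> zeroinitializer,
/// while tile loads, stores, and dot-products are scalarized into loop nests
/// over the individual tile elements.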
17 | //===----------------------------------------------------------------------===// |
18 | // |
19 | #include "X86.h" |
20 | #include "llvm/Analysis/DomTreeUpdater.h" |
21 | #include "llvm/Analysis/LoopInfo.h" |
22 | #include "llvm/Analysis/TargetTransformInfo.h" |
23 | #include "llvm/CodeGen/Passes.h" |
24 | #include "llvm/CodeGen/TargetPassConfig.h" |
25 | #include "llvm/CodeGen/ValueTypes.h" |
26 | #include "llvm/IR/DataLayout.h" |
27 | #include "llvm/IR/Function.h" |
28 | #include "llvm/IR/IRBuilder.h" |
29 | #include "llvm/IR/Instructions.h" |
30 | #include "llvm/IR/IntrinsicInst.h" |
31 | #include "llvm/IR/IntrinsicsX86.h" |
32 | #include "llvm/IR/PatternMatch.h" |
33 | #include "llvm/InitializePasses.h" |
34 | #include "llvm/Pass.h" |
35 | #include "llvm/Support/CommandLine.h" |
36 | #include "llvm/Target/TargetMachine.h" |
37 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
38 | #include "llvm/Transforms/Utils/LoopUtils.h" |
39 | |
40 | using namespace llvm; |
41 | using namespace PatternMatch; |
42 | |
43 | #define DEBUG_TYPE "lower-amx-intrinsics" |
44 | |
45 | #ifndef NDEBUG |
46 | static bool isV256I32Ty(Type *Ty) { |
47 | if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) |
48 | return FVT->getNumElements() == 256 && |
49 | FVT->getElementType()->isIntegerTy(32); |
50 | return false; |
51 | } |
52 | #endif |
53 | |
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));
57 | |
58 | namespace { |
59 | class X86LowerAMXIntrinsics { |
60 | Function &Func; |
61 | |
62 | public: |
63 | X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) |
64 | : Func(F), DTU(DomTU), LI(LoopI) {} |
65 | bool visit(); |
66 | |
67 | private: |
68 | DomTreeUpdater &DTU; |
69 | LoopInfo *LI; |
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
73 | template <bool IsTileLoad> |
74 | Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, |
75 | IRBuilderBase &B, Value *Row, Value *Col, |
76 | Value *Ptr, Value *Stride, Value *Tile); |
77 | template <Intrinsic::ID IntrID> |
78 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
79 | IntrID == Intrinsic::x86_tdpbsud_internal || |
80 | IntrID == Intrinsic::x86_tdpbusd_internal || |
81 | IntrID == Intrinsic::x86_tdpbuud_internal || |
82 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
83 | Value *> |
84 | createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, |
85 | Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, |
86 | Value *RHS); |
87 | template <bool IsTileLoad> |
88 | bool lowerTileLoadStore(Instruction *TileLoadStore); |
89 | template <Intrinsic::ID IntrID> |
90 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
91 | IntrID == Intrinsic::x86_tdpbsud_internal || |
92 | IntrID == Intrinsic::x86_tdpbusd_internal || |
93 | IntrID == Intrinsic::x86_tdpbuud_internal || |
94 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
95 | bool> |
96 | lowerTileDP(Instruction *TileDP); |
97 | bool lowerTileZero(Instruction *TileZero); |
98 | }; |
99 | } // anonymous namespace |
100 | |
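// Create the header/body/latch skeleton of one scalarization loop: an i16
// induction variable counts from 0 up to Bound in steps of Step, the
// preheader's unconditional branch is redirected into the new header, and the
// (still empty) body block is returned for the caller to populate.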
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV = PHINode::Create(I16Ty, 2, Name + ".iv",
                                Header->getTerminator()->getIterator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
137 | if (LI) { |
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
141 | } |
142 | return Body; |
143 | } |
144 | |
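// Build a two-deep loop nest (rows x cols) that moves one i32 element per
// iteration between memory and the <256 x i32> vector backing the tile: a
// tileload accumulates elements into the vector through insertelement PHIs,
// while a tilestore extracts them from the bitcast tile operand.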
145 | template <bool IsTileLoad> |
146 | Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( |
147 | BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, |
148 | Value *Col, Value *Ptr, Value *Stride, Value *Tile) { |
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
150 | Loop *RowLoop = nullptr; |
151 | Loop *ColLoop = nullptr; |
152 | if (LI) { |
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
160 | } |
161 | |
  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore.
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
187 | if (IsTileLoad) { |
188 | // tileload.scalarize.rows.header: |
189 | // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, |
190 | // %tileload.scalarize.rows.latch ] |
191 | B.SetInsertPoint(RowLoopHeader->getTerminator()); |
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %eltptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
213 | |
214 | return ResVec; |
215 | } else { |
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %elt = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %elt, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
228 | return nullptr; |
229 | } |
230 | } |
231 | |
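// Build a three-deep loop nest (rows x cols x inner) that computes the tile
// dot-product one accumulator element at a time: each inner iteration
// multiplies one packed pair of A and B elements (4 x i8, or 2 x bf16 for
// tdpbf16ps), reduces the products, and adds the result into C[row][col].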
232 | template <Intrinsic::ID IntrID> |
233 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
234 | IntrID == Intrinsic::x86_tdpbsud_internal || |
235 | IntrID == Intrinsic::x86_tdpbusd_internal || |
236 | IntrID == Intrinsic::x86_tdpbuud_internal || |
237 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
238 | Value *> |
239 | X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, |
240 | IRBuilderBase &B, Value *Row, |
241 | Value *Col, Value *K, Value *Acc, |
242 | Value *LHS, Value *RHS) { |
243 | std::string IntrinName; |
244 | switch (IntrID) { |
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
260 | } |
261 | Loop *RowLoop = nullptr; |
262 | Loop *ColLoop = nullptr; |
263 | Loop *InnerLoop = nullptr; |
264 | if (LI) { |
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
274 | } |
275 | |
  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();
297 | |
  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO: else create a BitCast from x86amx to v256i32, store the x86amx
  // value to memory, and reload it as a vector. However, at -O0 this case
  // doesn't arise.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
311 | |
312 | // tiledpbssd.scalarize.rows.header: |
313 | // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, |
314 | // %tiledpbssd.scalarize.rows.latch ] |
315 | |
316 | // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ |
317 | // %NewVecD, %tiledpbssd.scalarize.rows.latch ] |
318 | B.SetInsertPoint(RowLoopHeader->getTerminator()); |
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);
324 | |
325 | // tiledpbssd.scalarize.cols.header: |
326 | // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, |
327 | // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, |
328 | // %tiledpbssd.scalarize.cols.latch ] |
329 | |
330 | // %vec.d.phi.col = phi <256 x i32> [ |
331 | // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, |
332 | // %tiledpbssd.scalarize.cols.latch ] |
333 | |
334 | // calculate idxc. |
335 | B.SetInsertPoint(ColLoopHeader->getTerminator()); |
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
342 | |
343 | // tiledpbssd.scalarize.inner.header: |
344 | // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, |
345 | // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, |
346 | // %tiledpbssd.scalarize.inner.latch ] |
347 | |
348 | B.SetInsertPoint(InnerLoopHeader->getTerminator()); |
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
357 | Value *NewVecC = nullptr; |
358 | |
359 | if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { |
360 | // tiledpbssd.scalarize.inner.body: |
361 | // calculate idxa, idxb |
362 | // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc |
363 | // %elta = extractelement <256 x i32> %veca, i16 %idxa |
364 | // %eltav4i8 = bitcast i32 %elta to <4 x i8> |
365 | // %eltb = extractelement <256 x i32> %vecb, i16 %idxb |
366 | // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> |
367 | // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> |
368 | // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> |
369 | // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 |
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
372 | // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, |
373 | // i16 %idxc |
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
406 | } else { |
407 | // tiledpbf16ps.scalarize.inner.body: |
408 | // calculate idxa, idxb, idxc |
409 | // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc |
410 | // %eltcf32 = bitcast i32 %eltc to float |
411 | // %elta = extractelement <256 x i32> %veca, i16 %idxa |
412 | // %eltav2i16 = bitcast i32 %elta to <2 x i16> |
413 | // %eltb = extractelement <256 x i32> %vecb, i16 %idxb |
414 | // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> |
415 | // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 |
416 | // x i32> <i32 2, i32 0, i32 3, i32 1> |
417 | // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> |
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer,
    // <4 x i32> <i32 2, i32 0, i32 3, i32 1>
420 | // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> |
421 | // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 |
422 | // %acc = call float |
423 | // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) |
424 | // %neweltc = bitcast float %acc to i32 |
425 | // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, |
426 | // i16 %idxc |
427 | // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, |
428 | // i16 %idxc |
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
447 | } |
448 | |
449 | // tiledpbssd.scalarize.cols.latch: |
450 | // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc |
451 | // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, |
452 | // i16 %idxc |
453 | B.SetInsertPoint(ColLoopLatch->getTerminator()); |
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
462 | |
463 | return NewVecD; |
464 | } |
465 | |
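// Lower one tile dot-product intrinsic: split the block at the intrinsic,
// emit the scalarization loop nest in between, then rewrite users of the
// intrinsic (and of any bitcast of its result) to use the computed
// <256 x i32> vector.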
466 | template <Intrinsic::ID IntrID> |
467 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
468 | IntrID == Intrinsic::x86_tdpbsud_internal || |
469 | IntrID == Intrinsic::x86_tdpbusd_internal || |
470 | IntrID == Intrinsic::x86_tdpbuud_internal || |
471 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
472 | bool> |
473 | X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { |
474 | Value *M, *N, *K, *C, *A, *B; |
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
477 | Instruction *InsertI = TileDP; |
478 | IRBuilder<> PreBuilder(TileDP); |
479 | PreBuilder.SetInsertPoint(TileDP); |
480 | // We visit the loop with (m, n/4, k/4): |
481 | // %n_dword = lshr i16 %n, 2 |
482 | // %k_dword = lshr i16 %k, 2 |
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
488 | IRBuilder<> Builder(TileDP); |
489 | Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, |
490 | KDWord, C, A, B); |
  // We cannot assume there is always a bitcast after the tiledp intrinsic,
  // so insert one ourselves as required.
  Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete the TileDP intrinsic and do some clean-up.
  for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
507 | return true; |
508 | } |
509 | |
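// Lower tileloadd64/tilestored64: split the block at the intrinsic and emit
// the row/column loops in between. For a load, the x86amx result is
// re-materialized as a bitcast from the computed vector for any remaining
// AMX-typed users.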
510 | template <bool IsTileLoad> |
511 | bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { |
512 | Value *M, *N, *Ptr, *Stride, *Tile; |
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));
521 | |
522 | Instruction *InsertI = TileLoadStore; |
523 | IRBuilder<> PreBuilder(TileLoadStore); |
524 | PreBuilder.SetInsertPoint(TileLoadStore); |
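  // We visit the loop with (m, n/4); the column count and the stride are
  // converted from bytes to i32 (dword) units:
  // %n_dword = lshr i16 %n, 2
  // %stride_dword = lshr i64 %stride, 2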
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
530 | IRBuilder<> Builder(TileLoadStore); |
531 | Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( |
532 | Start, End, Builder, M, NDWord, Ptr, StrideDWord, |
533 | IsTileLoad ? nullptr : Tile); |
534 | if (IsTileLoad) { |
    // We cannot assume there is always a bitcast after the tileload
    // intrinsic, so insert one ourselves as required.
    Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
    Value *ResAMX =
        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete the tileloadd64 intrinsic and do some clean-up.
    for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
      Instruction *I = cast<Instruction>(U.getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
550 | } |
551 | TileLoadStore->eraseFromParent(); |
552 | return true; |
553 | } |
554 | |
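// Lower tilezero: no loops are needed, since every bitcast of the zeroed
// tile to <256 x i32> can simply be replaced by zeroinitializer.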
555 | bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { |
556 | IRBuilder<> Builder(TileZero); |
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
565 | } |
566 | } |
567 | TileZero->eraseFromParent(); |
568 | return true; |
569 | } |
570 | |
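// Walk the function in CFG depth-first order, collect all AMX intrinsics,
// and lower each one in turn. Returns true if anything changed.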
571 | bool X86LowerAMXIntrinsics::visit() { |
572 | bool C = false; |
573 | SmallVector<IntrinsicInst *, 8> WorkList; |
  for (BasicBlock *BB : depth_first(&Func)) {
575 | for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { |
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
577 | switch (Inst->getIntrinsicID()) { |
578 | case Intrinsic::x86_tdpbssd_internal: |
579 | case Intrinsic::x86_tdpbsud_internal: |
580 | case Intrinsic::x86_tdpbusd_internal: |
581 | case Intrinsic::x86_tdpbuud_internal: |
582 | case Intrinsic::x86_tileloadd64_internal: |
583 | case Intrinsic::x86_tilestored64_internal: |
584 | case Intrinsic::x86_tilezero_internal: |
585 | case Intrinsic::x86_tdpbf16ps_internal: |
          WorkList.push_back(Inst);
587 | break; |
588 | default: |
589 | break; |
590 | } |
591 | } |
592 | } |
593 | } |
594 | |
595 | for (auto *Inst : WorkList) { |
596 | switch (Inst->getIntrinsicID()) { |
597 | case Intrinsic::x86_tdpbssd_internal: |
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsic!");
623 | } |
624 | } |
625 | |
626 | return C; |
627 | } |
628 | |
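// Legacy pass-manager wrapper. The pass is gated behind the
// -enable-x86-scalar-amx flag and only runs at -O0 or on optnone functions,
// where fast register allocation cannot yet handle AMX registers.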
629 | namespace { |
630 | class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { |
631 | public: |
632 | static char ID; |
633 | |
634 | X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {} |
635 | |
636 | bool runOnFunction(Function &F) override { |
637 | if (!X86ScalarizeAMX) |
638 | return false; |
639 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
641 | TM->getOptLevel() != CodeGenOptLevel::None) |
642 | return false; |
643 | |
644 | auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); |
645 | auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; |
646 | auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); |
647 | auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; |
648 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
649 | |
650 | X86LowerAMXIntrinsics LAT(F, DTU, LI); |
651 | return LAT.visit(); |
652 | } |
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }
654 | |
655 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
656 | AU.addPreserved<DominatorTreeWrapperPass>(); |
657 | AU.addPreserved<LoopInfoWrapperPass>(); |
658 | AU.addRequired<TargetPassConfig>(); |
659 | } |
660 | }; |
661 | } // namespace |
662 | |
static const char PassName[] = "Lower AMX intrinsics";
664 | char X86LowerAMXIntrinsicsLegacyPass::ID = 0; |
665 | INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, |
666 | false, false) |
667 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) |
668 | INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, |
669 | false, false) |
670 | |
671 | FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { |
672 | return new X86LowerAMXIntrinsicsLegacyPass(); |
673 | } |
674 | |