1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to transform <256 x i32> load/store.
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx; basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from the vector <256 x i32>
/// and only AMX intrinsics can operate on that type, we need to transform
/// <256 x i32> load/store instructions into AMX load/store. If the bitcast
/// cannot be combined with a load/store, we transform the bitcast into an AMX
/// load/store plus a <256 x i32> store/load.
17///
/// If the front end does not use O0 but the middle/back end does (e.g.
/// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we must make sure the AMX data
/// is volatile, because that is necessary for AMX fast register allocation.
/// (In fast register allocation, registers are allocated before spill/reload,
/// so there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transfer to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/Analysis/TargetLibraryInfo.h"
45#include "llvm/Analysis/TargetTransformInfo.h"
46#include "llvm/CodeGen/Passes.h"
47#include "llvm/CodeGen/TargetPassConfig.h"
48#include "llvm/CodeGen/ValueTypes.h"
49#include "llvm/IR/Analysis.h"
50#include "llvm/IR/DataLayout.h"
51#include "llvm/IR/Function.h"
52#include "llvm/IR/IRBuilder.h"
53#include "llvm/IR/Instructions.h"
54#include "llvm/IR/IntrinsicInst.h"
55#include "llvm/IR/IntrinsicsX86.h"
56#include "llvm/IR/PassManager.h"
57#include "llvm/IR/PatternMatch.h"
58#include "llvm/InitializePasses.h"
59#include "llvm/Pass.h"
60#include "llvm/Target/TargetMachine.h"
61#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62#include "llvm/Transforms/Utils/Local.h"
63
64#include <map>
65
66using namespace llvm;
67using namespace PatternMatch;
68
69#define DEBUG_TYPE "x86-lower-amx-type"
70
71static bool isAMXCast(Instruction *II) {
72 return match(V: II,
73 P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value())) ||
74 match(V: II, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(Op0: m_Value()));
75}
76
77static bool isAMXIntrinsic(Value *I) {
78 auto *II = dyn_cast<IntrinsicInst>(Val: I);
79 if (!II)
80 return false;
81 if (isAMXCast(II))
82 return false;
83 // Check if return type or parameter is x86_amx. If it is x86_amx
84 // the intrinsic must be x86 amx intrinsics.
85 if (II->getType()->isX86_AMXTy())
86 return true;
87 for (Value *V : II->args()) {
88 if (V->getType()->isX86_AMXTy())
89 return true;
90 }
91
92 return false;
93}
94
95static bool containsAMXCode(Function &F) {
96 for (BasicBlock &BB : F)
97 for (Instruction &I : BB)
98 if (I.getType()->isX86_AMXTy())
99 return true;
100 return false;
101}
102
103static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
104 Type *Ty) {
105 Function &F = *BB->getParent();
106 const DataLayout &DL = F.getDataLayout();
107
108 LLVMContext &Ctx = Builder.getContext();
109 auto AllocaAlignment = DL.getPrefTypeAlign(Ty: Type::getX86_AMXTy(C&: Ctx));
110 unsigned AllocaAS = DL.getAllocaAddrSpace();
111 AllocaInst *AllocaRes =
112 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
113 AllocaRes->setAlignment(AllocaAlignment);
114 return AllocaRes;
115}
116
117static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
118 for (Instruction &I : F.getEntryBlock())
119 if (!isa<AllocaInst>(Val: &I))
120 return &I;
121 llvm_unreachable("No terminator in the entry block!");
122}
123
124static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
125 IRBuilder<> Builder(II);
126 Value *RealRow = nullptr;
127 if (isa<ConstantInt>(Val: V))
128 RealRow =
129 Builder.getInt16(C: (cast<ConstantInt>(Val: V)->getSExtValue()) / Granularity);
130 else if (isa<Instruction>(Val: V)) {
131 // When it is not a const value and it is not a function argument, we
132 // create Row after the definition of V instead of
133 // before II. For example, II is %118, we try to getshape for %117:
134 // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
135 // i32> %115).
136 // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
137 // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
138 // %117).
139 // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
140 // definition is after its user(new tileload for %117).
141 // So, the best choice is to create %row right after the definition of
142 // %106.
143 Builder.SetInsertPoint(cast<Instruction>(Val: V));
144 RealRow = Builder.CreateUDiv(LHS: V, RHS: Builder.getInt16(C: 4));
145 cast<Instruction>(Val: RealRow)->moveAfter(MovePos: cast<Instruction>(Val: V));
146 } else {
147 // When it is not a const value and it is a function argument, we create
148 // Row at the entry bb.
149 IRBuilder<> NewBuilder(
150 getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction()));
151 RealRow = NewBuilder.CreateUDiv(LHS: V, RHS: NewBuilder.getInt16(C: Granularity));
152 }
153 return RealRow;
154}
155
156// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
157std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
158 IRBuilder<> Builder(II);
159 Value *Row = nullptr, *Col = nullptr;
160 switch (II->getIntrinsicID()) {
161 default:
162 llvm_unreachable("Expect amx intrinsics");
163 case Intrinsic::x86_tileloadd64_internal:
164 case Intrinsic::x86_tileloaddt164_internal:
165 case Intrinsic::x86_tilestored64_internal:
166 case Intrinsic::x86_tileloaddrs64_internal:
167 case Intrinsic::x86_tileloaddrst164_internal: {
168 Row = II->getArgOperand(i: 0);
169 Col = II->getArgOperand(i: 1);
170 break;
171 }
172 // a * b + c
173 // The shape depends on which operand.
174 case Intrinsic::x86_tcmmimfp16ps_internal:
175 case Intrinsic::x86_tcmmrlfp16ps_internal:
176 case Intrinsic::x86_tdpbssd_internal:
177 case Intrinsic::x86_tdpbsud_internal:
178 case Intrinsic::x86_tdpbusd_internal:
179 case Intrinsic::x86_tdpbuud_internal:
180 case Intrinsic::x86_tdpbf16ps_internal:
181 case Intrinsic::x86_tdpfp16ps_internal:
182 case Intrinsic::x86_tmmultf32ps_internal:
183 case Intrinsic::x86_tdpbf8ps_internal:
184 case Intrinsic::x86_tdpbhf8ps_internal:
185 case Intrinsic::x86_tdphbf8ps_internal:
186 case Intrinsic::x86_tdphf8ps_internal: {
187 switch (OpNo) {
188 case 3:
189 Row = II->getArgOperand(i: 0);
190 Col = II->getArgOperand(i: 1);
191 break;
192 case 4:
193 Row = II->getArgOperand(i: 0);
194 Col = II->getArgOperand(i: 2);
195 break;
196 case 5:
197 Row = getRowFromCol(II, V: II->getArgOperand(i: 2), Granularity: 4);
198 Col = II->getArgOperand(i: 1);
199 break;
200 }
201 break;
202 }
203 case Intrinsic::x86_tcvtrowd2ps_internal:
204 case Intrinsic::x86_tcvtrowps2bf16h_internal:
205 case Intrinsic::x86_tcvtrowps2bf16l_internal:
206 case Intrinsic::x86_tcvtrowps2phh_internal:
207 case Intrinsic::x86_tcvtrowps2phl_internal:
208 case Intrinsic::x86_tilemovrow_internal: {
209 assert(OpNo == 2 && "Illegal Operand Number.");
210 Row = II->getArgOperand(i: 0);
211 Col = II->getArgOperand(i: 1);
212 break;
213 }
214 }
215
216 return std::make_pair(x&: Row, y&: Col);
217}
218
219static std::pair<Value *, Value *> getShape(PHINode *Phi) {
220 Use &U = *(Phi->use_begin());
221 unsigned OpNo = U.getOperandNo();
222 User *V = U.getUser();
223 // TODO We don't traverse all users. To make the algorithm simple, here we
224 // just traverse the first user. If we can find shape, then return the shape,
225 // otherwise just return nullptr and the optimization for undef/zero will be
226 // abandoned.
227 while (V) {
228 if (isAMXCast(II: dyn_cast<Instruction>(Val: V))) {
229 if (V->use_empty())
230 break;
231 Use &U = *(V->use_begin());
232 OpNo = U.getOperandNo();
233 V = U.getUser();
234 } else if (isAMXIntrinsic(I: V)) {
235 return getShape(II: cast<IntrinsicInst>(Val: V), OpNo);
236 } else if (isa<PHINode>(Val: V)) {
237 if (V->use_empty())
238 break;
239 Use &U = *(V->use_begin());
240 V = U.getUser();
241 } else {
242 break;
243 }
244 }
245
246 return std::make_pair(x: nullptr, y: nullptr);
247}
248
249namespace {
250class X86LowerAMXType {
251 Function &Func;
252
253 // In AMX intrinsics we let Shape = {Row, Col}, but the
254 // RealCol = Col / ElementSize. We may use the RealCol
255 // as a new Row for other new created AMX intrinsics.
256 std::map<Value *, Value *> Col2Row;
257
258public:
259 X86LowerAMXType(Function &F) : Func(F) {}
260 bool visit();
261 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
262 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
263 bool transformBitcast(BitCastInst *Bitcast);
264};
265
266// %src = load <256 x i32>, <256 x i32>* %addr, align 64
267// %2 = bitcast <256 x i32> %src to x86_amx
268// -->
269// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
270// i8* %addr, i64 %stride64)
271void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
272 Value *Row = nullptr, *Col = nullptr;
273 Use &U = *(Bitcast->use_begin());
274 unsigned OpNo = U.getOperandNo();
275 auto *II = cast<IntrinsicInst>(Val: U.getUser());
276 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
277 IRBuilder<> Builder(Bitcast);
278 // Use the maximun column as stride.
279 Value *Stride = Builder.getInt64(C: 64);
280 Value *I8Ptr = LD->getOperand(i_nocapture: 0);
281 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
282
283 Value *NewInst =
284 Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args);
285 Bitcast->replaceAllUsesWith(V: NewInst);
286}
287
288// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
289// %stride);
290// %13 = bitcast x86_amx %src to <256 x i32>
291// store <256 x i32> %13, <256 x i32>* %addr, align 64
292// -->
293// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
294// %stride64, %13)
295void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
296
297 Value *Tile = Bitcast->getOperand(i_nocapture: 0);
298 auto *II = cast<IntrinsicInst>(Val: Tile);
299 // Tile is output from AMX intrinsic. The first operand of the
300 // intrinsic is row, the second operand of the intrinsic is column.
301 Value *Row = II->getOperand(i_nocapture: 0);
302 Value *Col = II->getOperand(i_nocapture: 1);
303 IRBuilder<> Builder(ST);
304 // Use the maximum column as stride. It must be the same with load
305 // stride.
306 Value *Stride = Builder.getInt64(C: 64);
307 Value *I8Ptr = ST->getOperand(i_nocapture: 1);
308 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
309 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args);
310 if (Bitcast->hasOneUse())
311 return;
312 // %13 = bitcast x86_amx %src to <256 x i32>
313 // store <256 x i32> %13, <256 x i32>* %addr, align 64
314 // %add = <256 x i32> %13, <256 x i32> %src2
315 // -->
316 // %13 = bitcast x86_amx %src to <256 x i32>
317 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
318 // %stride64, %13)
319 // %14 = load <256 x i32>, %addr
320 // %add = <256 x i32> %14, <256 x i32> %src2
321 Value *Vec = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: ST->getOperand(i_nocapture: 1));
322 Bitcast->replaceAllUsesWith(V: Vec);
323}
324
325// transform bitcast to <store, load> instructions.
326bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
327 IRBuilder<> Builder(Bitcast);
328 AllocaInst *AllocaAddr;
329 Value *I8Ptr, *Stride;
330 auto *Src = Bitcast->getOperand(i_nocapture: 0);
331
332 auto Prepare = [&](Type *MemTy) {
333 AllocaAddr = createAllocaInstAtEntry(Builder, BB: Bitcast->getParent(), Ty: MemTy);
334 I8Ptr = AllocaAddr;
335 Stride = Builder.getInt64(C: 64);
336 };
337
338 if (Bitcast->getType()->isX86_AMXTy()) {
339 // %2 = bitcast <256 x i32> %src to x86_amx
340 // -->
341 // %addr = alloca <256 x i32>, align 64
342 // store <256 x i32> %src, <256 x i32>* %addr, align 64
343 // %addr2 = bitcast <256 x i32>* to i8*
344 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
345 // i8* %addr2,
346 // i64 64)
347 Use &U = *(Bitcast->use_begin());
348 unsigned OpNo = U.getOperandNo();
349 auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser());
350 if (!II)
351 return false; // May be bitcast from x86amx to <256 x i32>.
352 Prepare(Bitcast->getOperand(i_nocapture: 0)->getType());
353 Builder.CreateStore(Val: Src, Ptr: AllocaAddr);
354 // TODO we can pick an constant operand for the shape.
355 Value *Row = nullptr, *Col = nullptr;
356 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
357 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
358 Value *NewInst =
359 Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args);
360 Bitcast->replaceAllUsesWith(V: NewInst);
361 } else {
362 // %2 = bitcast x86_amx %src to <256 x i32>
363 // -->
364 // %addr = alloca <256 x i32>, align 64
365 // %addr2 = bitcast <256 x i32>* to i8*
366 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
367 // i8* %addr2, i64 %stride)
368 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
369 auto *II = dyn_cast<IntrinsicInst>(Val: Src);
370 if (!II)
371 return false; // May be bitcast from <256 x i32> to x86amx.
372 Prepare(Bitcast->getType());
373 Value *Row = II->getOperand(i_nocapture: 0);
374 Value *Col = II->getOperand(i_nocapture: 1);
375 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
376 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args);
377 Value *NewInst = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: AllocaAddr);
378 Bitcast->replaceAllUsesWith(V: NewInst);
379 }
380
381 return true;
382}
383
384bool X86LowerAMXType::visit() {
385 SmallVector<Instruction *, 8> DeadInsts;
386 Col2Row.clear();
387
388 for (BasicBlock *BB : post_order(G: &Func)) {
389 for (Instruction &Inst : llvm::make_early_inc_range(Range: llvm::reverse(C&: *BB))) {
390 auto *Bitcast = dyn_cast<BitCastInst>(Val: &Inst);
391 if (!Bitcast)
392 continue;
393
394 Value *Src = Bitcast->getOperand(i_nocapture: 0);
395 if (Bitcast->getType()->isX86_AMXTy()) {
396 if (Bitcast->user_empty()) {
397 DeadInsts.push_back(Elt: Bitcast);
398 continue;
399 }
400 LoadInst *LD = dyn_cast<LoadInst>(Val: Src);
401 if (!LD) {
402 if (transformBitcast(Bitcast))
403 DeadInsts.push_back(Elt: Bitcast);
404 continue;
405 }
406 // If load has multi-user, duplicate a vector load.
407 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
408 // %2 = bitcast <256 x i32> %src to x86_amx
409 // %add = add <256 x i32> %src, <256 x i32> %src2
410 // -->
411 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
412 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
413 // i8* %addr, i64 %stride64)
414 // %add = add <256 x i32> %src, <256 x i32> %src2
415
416 // If load has one user, the load will be eliminated in DAG ISel.
417 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
418 // %2 = bitcast <256 x i32> %src to x86_amx
419 // -->
420 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
421 // i8* %addr, i64 %stride64)
422 combineLoadBitcast(LD, Bitcast);
423 DeadInsts.push_back(Elt: Bitcast);
424 if (LD->hasOneUse())
425 DeadInsts.push_back(Elt: LD);
426 } else if (Src->getType()->isX86_AMXTy()) {
427 if (Bitcast->user_empty()) {
428 DeadInsts.push_back(Elt: Bitcast);
429 continue;
430 }
431 StoreInst *ST = nullptr;
432 for (Use &U : Bitcast->uses()) {
433 ST = dyn_cast<StoreInst>(Val: U.getUser());
434 if (ST)
435 break;
436 }
437 if (!ST) {
438 if (transformBitcast(Bitcast))
439 DeadInsts.push_back(Elt: Bitcast);
440 continue;
441 }
442 // If bitcast (%13) has one use, combine bitcast and store to amx store.
443 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
444 // %stride);
445 // %13 = bitcast x86_amx %src to <256 x i32>
446 // store <256 x i32> %13, <256 x i32>* %addr, align 64
447 // -->
448 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
449 // %stride64, %13)
450 //
451 // If bitcast (%13) has multi-use, transform as below.
452 // %13 = bitcast x86_amx %src to <256 x i32>
453 // store <256 x i32> %13, <256 x i32>* %addr, align 64
454 // %add = <256 x i32> %13, <256 x i32> %src2
455 // -->
456 // %13 = bitcast x86_amx %src to <256 x i32>
457 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
458 // %stride64, %13)
459 // %14 = load <256 x i32>, %addr
460 // %add = <256 x i32> %14, <256 x i32> %src2
461 //
462 combineBitcastStore(Bitcast, ST);
463 // Delete user first.
464 DeadInsts.push_back(Elt: ST);
465 DeadInsts.push_back(Elt: Bitcast);
466 }
467 }
468 }
469
470 bool C = !DeadInsts.empty();
471
472 for (auto *Inst : DeadInsts)
473 Inst->eraseFromParent();
474
475 return C;
476}
477} // anonymous namespace
478
479static Value *getAllocaPos(BasicBlock *BB) {
480 Function *F = BB->getParent();
481 IRBuilder<> Builder(&F->getEntryBlock().front());
482 const DataLayout &DL = F->getDataLayout();
483 unsigned AllocaAS = DL.getAllocaAddrSpace();
484 Type *V256I32Ty = VectorType::get(ElementType: Builder.getInt32Ty(), NumElements: 256, Scalable: false);
485 AllocaInst *AllocaRes =
486 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
487 BasicBlock::iterator Iter = AllocaRes->getIterator();
488 ++Iter;
489 Builder.SetInsertPoint(&*Iter);
490 Value *I8Ptr = Builder.CreateBitCast(V: AllocaRes, DestTy: Builder.getPtrTy());
491 return I8Ptr;
492}
493
494static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
495 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
496 auto *II = cast<IntrinsicInst>(Val: TileDef);
497
498 assert(II && "Not tile intrinsic!");
499 Value *Row = II->getOperand(i_nocapture: 0);
500 Value *Col = II->getOperand(i_nocapture: 1);
501
502 BasicBlock *BB = TileDef->getParent();
503 BasicBlock::iterator Iter = TileDef->getIterator();
504 IRBuilder<> Builder(BB, ++Iter);
505 Value *Stride = Builder.getInt64(C: 64);
506 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
507
508 Instruction *TileStore = Builder.CreateIntrinsicWithoutFolding(
509 ID: Intrinsic::x86_tilestored64_internal, Args);
510 return TileStore;
511}
512
513static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
514 Value *V = U.get();
515 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
516
517 // Get tile shape.
518 IntrinsicInst *II = nullptr;
519 if (IsPHI) {
520 Value *PhiOp = cast<PHINode>(Val: V)->getIncomingValue(i: 0);
521 II = cast<IntrinsicInst>(Val: PhiOp);
522 } else {
523 II = cast<IntrinsicInst>(Val: V);
524 }
525 Value *Row = II->getOperand(i_nocapture: 0);
526 Value *Col = II->getOperand(i_nocapture: 1);
527
528 Instruction *UserI = cast<Instruction>(Val: U.getUser());
529 IRBuilder<> Builder(UserI);
530 Value *Stride = Builder.getInt64(C: 64);
531 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
532
533 Value *TileLoad =
534 Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args);
535 UserI->replaceUsesOfWith(From: V, To: TileLoad);
536}
537
538static bool isIncomingOfPHI(Instruction *I) {
539 for (Use &U : I->uses()) {
540 User *V = U.getUser();
541 if (isa<PHINode>(Val: V))
542 return true;
543 }
544 return false;
545}
546
// Make all AMX tile data volatile, shortening the live range of each tile
// register before fast register allocation.
549namespace {
550class X86VolatileTileData {
551 Function &F;
552
553public:
554 X86VolatileTileData(Function &Func) : F(Func) {}
555 Value *updatePhiIncomings(BasicBlock *BB,
556 SmallVector<Instruction *, 2> &Incomings);
557 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
558 bool volatileTileData();
559 void volatileTilePHI(PHINode *PHI);
560 void volatileTileNonPHI(Instruction *I);
561};
562
563Value *X86VolatileTileData::updatePhiIncomings(
564 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
565 Value *I8Ptr = getAllocaPos(BB);
566
567 for (auto *I : Incomings) {
568 User *Store = createTileStore(TileDef: I, Ptr: I8Ptr);
569
570 // All its uses (except phi) should load from stored mem.
571 for (Use &U : I->uses()) {
572 User *V = U.getUser();
573 if (isa<PHINode>(Val: V) || V == Store)
574 continue;
575 replaceWithTileLoad(U, Ptr: I8Ptr);
576 }
577 }
578 return I8Ptr;
579}
580
581void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
582 Value *StorePtr) {
583 for (Use &U : PHI->uses())
584 replaceWithTileLoad(U, Ptr: StorePtr, IsPHI: true);
585 PHI->eraseFromParent();
586}
587
588// Smilar with volatileTileNonPHI, this function only handle PHI Nodes
589// and their related AMX intrinsics.
590// 1) PHI Def should change to tileload.
591// 2) PHI Incoming Values should tilestored in just after their def.
592// 3) The mem of these tileload and tilestores should be same.
593// e.g.
594// ------------------------------------------------------
595// bb_dom:
596// ...
597// br i1 %bool.cond, label %if.else, label %if.then
598//
599// if.then:
600// def %t0 = ...
601// ...
602// use %t0
603// ...
604// br label %if.end
605//
606// if.else:
607// def %t1 = ...
608// br label %if.end
609//
610// if.end:
611// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
612// ...
613// use %td
614// ------------------------------------------------------
615// -->
616// ------------------------------------------------------
617// bb_entry:
618// %mem = alloca <256 x i32>, align 1024 *
619// ...
620// bb_dom:
621// ...
622// br i1 %bool.cond, label %if.else, label %if.then
623//
624// if.then:
625// def %t0 = ...
626// call void @llvm.x86.tilestored64.internal(mem, %t0) *
627// ...
628// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
629// use %t0` *
630// ...
631// br label %if.end
632//
633// if.else:
634// def %t1 = ...
635// call void @llvm.x86.tilestored64.internal(mem, %t1) *
636// br label %if.end
637//
638// if.end:
639// ...
640// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
641// use %td
642// ------------------------------------------------------
643void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
644 BasicBlock *BB = PHI->getParent();
645 SmallVector<Instruction *, 2> Incomings;
646
647 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
648 Value *Op = PHI->getIncomingValue(i: I);
649 Instruction *Inst = dyn_cast<Instruction>(Val: Op);
650 assert(Inst && "We shouldn't fold AMX instrution!");
651 Incomings.push_back(Elt: Inst);
652 }
653
654 Value *StorePtr = updatePhiIncomings(BB, Incomings);
655 replacePhiDefWithLoad(PHI, StorePtr);
656}
657
658// Store the defined tile and load it before use.
659// All its users are not PHI.
660// e.g.
661// ------------------------------------------------------
662// def %td = ...
663// ...
664// "use %td"
665// ------------------------------------------------------
666// -->
667// ------------------------------------------------------
668// def %td = ...
669// call void @llvm.x86.tilestored64.internal(mem, %td)
670// ...
671// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
672// "use %td2"
673// ------------------------------------------------------
674void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
675 BasicBlock *BB = I->getParent();
676 Value *I8Ptr = getAllocaPos(BB);
677 User *Store = createTileStore(TileDef: I, Ptr: I8Ptr);
678
679 // All its uses should load from stored mem.
680 for (Use &U : I->uses()) {
681 User *V = U.getUser();
682 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
683 if (V != Store)
684 replaceWithTileLoad(U, Ptr: I8Ptr);
685 }
686}
687
688// Volatile Tile Model:
689// 1) All the uses of tile data comes from tileload in time.
690// 2) All the defs of tile data tilestore into mem immediately.
691// For example:
692// --------------------------------------------------------------------------
693// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
694// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
695// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
696// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
697// call void @llvm.x86.tilestored64.internal(... td) area
698// --------------------------------------------------------------------------
699// 3) No terminator, call or other amx instructions in the key amx area.
700bool X86VolatileTileData::volatileTileData() {
701 bool Changed = false;
702 for (BasicBlock &BB : F) {
703 SmallVector<Instruction *, 2> PHIInsts;
704 SmallVector<Instruction *, 8> AMXDefInsts;
705
706 for (Instruction &I : BB) {
707 if (!I.getType()->isX86_AMXTy())
708 continue;
709 if (isa<PHINode>(Val: &I))
710 PHIInsts.push_back(Elt: &I);
711 else
712 AMXDefInsts.push_back(Elt: &I);
713 }
714
715 // First we "volatile" the non-phi related amx intrinsics.
716 for (Instruction *I : AMXDefInsts) {
717 if (isIncomingOfPHI(I))
718 continue;
719 volatileTileNonPHI(I);
720 Changed = true;
721 }
722
723 for (Instruction *I : PHIInsts) {
724 volatileTilePHI(PHI: dyn_cast<PHINode>(Val: I));
725 Changed = true;
726 }
727 }
728 return Changed;
729}
730
731} // anonymous namespace
732
733namespace {
734
735class X86LowerAMXCast {
736 Function &Func;
737 std::unique_ptr<DominatorTree> DT;
738
739public:
740 X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
741 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
742 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
743 bool combineTilezero(IntrinsicInst *Cast);
744 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
745 bool combineAMXcast(TargetLibraryInfo *TLI);
746 bool transformAMXCast(IntrinsicInst *AMXCast);
747 bool transformAllAMXCast();
748 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
749 SmallSetVector<Instruction *, 16> &DeadInst);
750};
751
752static bool DCEInstruction(Instruction *I,
753 SmallSetVector<Instruction *, 16> &WorkList,
754 const TargetLibraryInfo *TLI) {
755 if (isInstructionTriviallyDead(I, TLI)) {
756 salvageDebugInfo(I&: *I);
757 salvageKnowledge(I);
758
759 // Null out all of the instruction's operands to see if any operand becomes
760 // dead as we go.
761 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
762 Value *OpV = I->getOperand(i);
763 I->setOperand(i, Val: nullptr);
764
765 if (!OpV->use_empty() || I == OpV)
766 continue;
767
768 // If the operand is an instruction that became dead as we nulled out the
769 // operand, and if it is 'trivially' dead, delete it in a future loop
770 // iteration.
771 if (Instruction *OpI = dyn_cast<Instruction>(Val: OpV)) {
772 if (isInstructionTriviallyDead(I: OpI, TLI)) {
773 WorkList.insert(X: OpI);
774 }
775 }
776 }
777 I->eraseFromParent();
778 return true;
779 }
780 return false;
781}
782
783/// This function handles following case
784///
785/// A -> B amxcast
786/// PHI
787/// B -> A amxcast
788///
789/// All the related PHI nodes can be replaced by new PHI nodes with type A.
790/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
791bool X86LowerAMXCast::optimizeAMXCastFromPhi(
792 IntrinsicInst *CI, PHINode *PN,
793 SmallSetVector<Instruction *, 16> &DeadInst) {
794 IRBuilder<> Builder(CI);
795 Value *Src = CI->getOperand(i_nocapture: 0);
796 Type *SrcTy = Src->getType(); // Type B
797 Type *DestTy = CI->getType(); // Type A
798
799 SmallVector<PHINode *, 4> PhiWorklist;
800 SmallSetVector<PHINode *, 4> OldPhiNodes;
801
802 // Find all of the A->B casts and PHI nodes.
803 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
804 // OldPhiNodes is used to track all known PHI nodes, before adding a new
805 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
806 PhiWorklist.push_back(Elt: PN);
807 OldPhiNodes.insert(X: PN);
808 while (!PhiWorklist.empty()) {
809 auto *OldPN = PhiWorklist.pop_back_val();
810 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
811 Value *IncValue = OldPN->getIncomingValue(i: I);
812 // TODO: currently, We ignore cases where it is a const. In the future, we
813 // might support const.
814 if (isa<Constant>(Val: IncValue)) {
815 auto *IncConst = dyn_cast<Constant>(Val: IncValue);
816 if (!isa<UndefValue>(Val: IncValue) && !IncConst->isNullValue())
817 return false;
818 Value *Row = nullptr, *Col = nullptr;
819 std::tie(args&: Row, args&: Col) = getShape(Phi: OldPN);
820 // TODO: If it is not constant the Row and Col must domoniate tilezero
821 // that we are going to create.
822 if (!Row || !Col || !isa<Constant>(Val: Row) || !isa<Constant>(Val: Col))
823 return false;
824 // Create tilezero at the end of incoming block.
825 auto *Block = OldPN->getIncomingBlock(i: I);
826 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
827 Instruction *NewInst = Builder.CreateIntrinsicWithoutFolding(
828 ID: Intrinsic::x86_tilezero_internal, OverloadTypes: {}, Args: {Row, Col});
829 NewInst->moveBefore(InsertPos: Iter);
830 NewInst = Builder.CreateIntrinsicWithoutFolding(
831 ID: Intrinsic::x86_cast_tile_to_vector, OverloadTypes: {IncValue->getType()},
832 Args: {NewInst});
833 NewInst->moveBefore(InsertPos: Iter);
834 // Replace InValue with new Value.
835 OldPN->setIncomingValue(i: I, V: NewInst);
836 IncValue = NewInst;
837 }
838
839 if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) {
840 if (OldPhiNodes.insert(X: PNode))
841 PhiWorklist.push_back(Elt: PNode);
842 continue;
843 }
844 Instruction *ACI = dyn_cast<Instruction>(Val: IncValue);
845 if (ACI && isAMXCast(II: ACI)) {
846 // Verify it's a A->B cast.
847 Type *TyA = ACI->getOperand(i: 0)->getType();
848 Type *TyB = ACI->getType();
849 if (TyA != DestTy || TyB != SrcTy)
850 return false;
851 continue;
852 }
853 return false;
854 }
855 }
856
857 // Check that each user of each old PHI node is something that we can
858 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
859 for (auto *OldPN : OldPhiNodes) {
860 for (User *V : OldPN->users()) {
861 Instruction *ACI = dyn_cast<Instruction>(Val: V);
862 if (ACI && isAMXCast(II: ACI)) {
863 // Verify it's a B->A cast.
864 Type *TyB = ACI->getOperand(i: 0)->getType();
865 Type *TyA = ACI->getType();
866 if (TyA != DestTy || TyB != SrcTy)
867 return false;
868 } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
869 // As long as the user is another old PHI node, then even if we don't
870 // rewrite it, the PHI web we're considering won't have any users
871 // outside itself, so it'll be dead.
872 // example:
873 // bb.0:
874 // %0 = amxcast ...
875 // bb.1:
876 // %1 = amxcast ...
877 // bb.2:
878 // %goodphi = phi %0, %1
879 // %3 = amxcast %goodphi
880 // bb.3:
881 // %goodphi2 = phi %0, %goodphi
882 // %4 = amxcast %goodphi2
883 // When optimizeAMXCastFromPhi process %3 and %goodphi, %goodphi2 is
884 // outside the phi-web, so the combination stop When
885 // optimizeAMXCastFromPhi process %4 and %goodphi2, the optimization
886 // will be done.
887 if (OldPhiNodes.count(key: PHI) == 0)
888 return false;
889 } else
890 return false;
891 }
892 }
893
894 // For each old PHI node, create a corresponding new PHI node with a type A.
895 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
896 for (auto *OldPN : OldPhiNodes) {
897 Builder.SetInsertPoint(OldPN);
898 PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands());
899 NewPNodes[OldPN] = NewPN;
900 }
901
902 // Fill in the operands of new PHI nodes.
903 for (auto *OldPN : OldPhiNodes) {
904 PHINode *NewPN = NewPNodes[OldPN];
905 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
906 Value *V = OldPN->getOperand(i_nocapture: j);
907 Value *NewV = nullptr;
908 Instruction *ACI = dyn_cast<Instruction>(Val: V);
909 // There should not be a AMXcast from a const.
910 if (ACI && isAMXCast(II: ACI))
911 NewV = ACI->getOperand(i: 0);
912 else if (auto *PrevPN = dyn_cast<PHINode>(Val: V))
913 NewV = NewPNodes[PrevPN];
914 assert(NewV);
915 NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j));
916 }
917 }
918
919 // Traverse all accumulated PHI nodes and process its users,
920 // which are Stores and BitcCasts. Without this processing
921 // NewPHI nodes could be replicated and could lead to extra
922 // moves generated after DeSSA.
923 // If there is a store with type B, change it to type A.
924
925 // Replace users of BitCast B->A with NewPHI. These will help
926 // later to get rid of a closure formed by OldPHI nodes.
927 for (auto *OldPN : OldPhiNodes) {
928 PHINode *NewPN = NewPNodes[OldPN];
929 for (User *V : make_early_inc_range(Range: OldPN->users())) {
930 Instruction *ACI = dyn_cast<Instruction>(Val: V);
931 if (ACI && isAMXCast(II: ACI)) {
932 Type *TyB = ACI->getOperand(i: 0)->getType();
933 Type *TyA = ACI->getType();
934 assert(TyA == DestTy && TyB == SrcTy);
935 (void)TyA;
936 (void)TyB;
937 ACI->replaceAllUsesWith(V: NewPN);
938 DeadInst.insert(X: ACI);
939 } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
940 // We don't need to push PHINode into DeadInst since they are operands
941 // of rootPN DCE can safely delete rootPN's operands if rootPN is dead.
942 assert(OldPhiNodes.contains(PHI));
943 (void)PHI;
944 } else
945 llvm_unreachable("all uses should be handled");
946 }
947 }
948 return true;
949}
950
951// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
952// store <256 x i32> %43, <256 x i32>* %p, align 64
953// -->
954// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
955// i64 64, x86_amx %42)
956bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
957 Value *Tile = Cast->getOperand(i_nocapture: 0);
958
959 assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
960
961 // TODO: Specially handle the multi-use case.
962 if (!Tile->hasOneUse())
963 return false;
964
965 auto *II = cast<IntrinsicInst>(Val: Tile);
966 // Tile is output from AMX intrinsic. The first operand of the
967 // intrinsic is row, the second operand of the intrinsic is column.
968 Value *Row = II->getOperand(i_nocapture: 0);
969 Value *Col = II->getOperand(i_nocapture: 1);
970
971 IRBuilder<> Builder(ST);
972
973 // Stride should be equal to col(measured by bytes)
974 Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty());
975 Value *I8Ptr = Builder.CreateBitCast(V: ST->getOperand(i_nocapture: 1), DestTy: Builder.getPtrTy());
976 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
977 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args);
978 return true;
979}
980
981// %65 = load <256 x i32>, <256 x i32>* %p, align 64
982// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
983// -->
984// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
985// i8* %p, i64 64)
986bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
987 bool EraseLoad = true;
988 Value *Row = nullptr, *Col = nullptr;
989 Use &U = *(Cast->use_begin());
990 unsigned OpNo = U.getOperandNo();
991 auto *II = cast<IntrinsicInst>(Val: U.getUser());
992 // TODO: If it is cast intrinsic or phi node, we can propagate the
993 // shape information through def-use chain.
994 if (!isAMXIntrinsic(I: II))
995 return false;
996 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
997 IRBuilder<> Builder(LD);
998 Value *I8Ptr;
999
1000 // To save compiling time, we create dominator tree when it is really needed.
1001 if (!DT)
1002 DT.reset(p: new DominatorTree(Func));
1003 if (!DT->dominates(Def: Row, User: LD) || !DT->dominates(Def: Col, User: LD)) {
1004 // store the value to stack and reload it from stack before cast.
1005 auto *AllocaAddr =
1006 createAllocaInstAtEntry(Builder, BB: Cast->getParent(), Ty: LD->getType());
1007 Builder.SetInsertPoint(&*std::next(x: LD->getIterator()));
1008 Builder.CreateStore(Val: LD, Ptr: AllocaAddr);
1009
1010 Builder.SetInsertPoint(Cast);
1011 I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy());
1012 EraseLoad = false;
1013 } else {
1014 I8Ptr = Builder.CreateBitCast(V: LD->getOperand(i_nocapture: 0), DestTy: Builder.getPtrTy());
1015 }
1016 // Stride should be equal to col(measured by bytes)
1017 Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty());
1018 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1019
1020 Value *NewInst =
1021 Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args);
1022 Cast->replaceAllUsesWith(V: NewInst);
1023
1024 return EraseLoad;
1025}
1026
1027// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1028// -->
1029// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1030bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
1031 Value *Row = nullptr, *Col = nullptr;
1032 Use &U = *(Cast->use_begin());
1033 unsigned OpNo = U.getOperandNo();
1034 auto *II = cast<IntrinsicInst>(Val: U.getUser());
1035 if (!isAMXIntrinsic(I: II))
1036 return false;
1037
1038 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
1039
1040 IRBuilder<> Builder(Cast);
1041 Value *NewInst =
1042 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilezero_internal, OverloadTypes: {}, Args: {Row, Col});
1043 Cast->replaceAllUsesWith(V: NewInst);
1044 return true;
1045}
1046
1047bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1048 bool Change = false;
1049 for (auto *Cast : Casts) {
1050 auto *II = cast<IntrinsicInst>(Val: Cast);
1051 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1052 // store <256 x i32> %43, <256 x i32>* %p, align 64
1053 // -->
1054 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1055 // i64 64, x86_amx %42)
1056 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1057 SmallVector<Instruction *, 2> DeadStores;
1058 for (User *U : Cast->users()) {
1059 StoreInst *Store = dyn_cast<StoreInst>(Val: U);
1060 if (!Store)
1061 continue;
1062 if (combineCastStore(Cast: cast<IntrinsicInst>(Val: Cast), ST: Store)) {
1063 DeadStores.push_back(Elt: Store);
1064 Change = true;
1065 }
1066 }
1067 for (auto *Store : DeadStores)
1068 Store->eraseFromParent();
1069 } else { // x86_cast_vector_to_tile
1070 // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1071 // -->
1072 // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1073 if (isa<ConstantAggregateZero>(Val: Cast->getOperand(i: 0))) {
1074 Change |= combineTilezero(Cast: cast<IntrinsicInst>(Val: Cast));
1075 continue;
1076 }
1077
1078 auto *Load = dyn_cast<LoadInst>(Val: Cast->getOperand(i: 0));
1079 if (!Load || !Load->hasOneUse())
1080 continue;
1081 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1082 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1083 // -->
1084 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1085 // i8* %p, i64 64)
1086 if (combineLoadCast(Cast: cast<IntrinsicInst>(Val: Cast), LD: Load)) {
1087 // Set the operand is null so that load instruction can be erased.
1088 Cast->setOperand(i: 0, Val: nullptr);
1089 Load->eraseFromParent();
1090 Change = true;
1091 }
1092 }
1093 }
1094 return Change;
1095}
1096
1097bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1098 bool Change = false;
1099 // Collect tile cast instruction.
1100 SmallVector<Instruction *, 8> Vec2TileInsts;
1101 SmallVector<Instruction *, 8> Tile2VecInsts;
1102 SmallVector<Instruction *, 8> PhiCastWorkList;
1103 SmallSetVector<Instruction *, 16> DeadInst;
1104 for (BasicBlock &BB : Func) {
1105 for (Instruction &I : BB) {
1106 Value *Vec;
1107 if (match(V: &I,
1108 P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value(V&: Vec))))
1109 Vec2TileInsts.push_back(Elt: &I);
1110 else if (match(V: &I, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1111 Op0: m_Value(V&: Vec))))
1112 Tile2VecInsts.push_back(Elt: &I);
1113 }
1114 }
1115
1116 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1117 for (auto *Inst : Insts) {
1118 for (User *U : Inst->users()) {
1119 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U);
1120 if (!II || II->getIntrinsicID() != IID)
1121 continue;
1122 // T1 = vec2tile V0
1123 // V2 = tile2vec T1
1124 // V3 = OP V2
1125 // -->
1126 // T1 = vec2tile V0
1127 // V2 = tile2vec T1
1128 // V3 = OP V0
1129 II->replaceAllUsesWith(V: Inst->getOperand(i: 0));
1130 Change = true;
1131 }
1132 }
1133 };
1134
1135 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1136 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1137
1138 SmallVector<Instruction *, 8> LiveCasts;
1139 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1140 for (auto *Inst : Insts) {
1141 if (Inst->use_empty()) {
1142 Inst->eraseFromParent();
1143 Change = true;
1144 } else {
1145 LiveCasts.push_back(Elt: Inst);
1146 }
1147 }
1148 };
1149
1150 EraseInst(Vec2TileInsts);
1151 EraseInst(Tile2VecInsts);
1152 LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine "
1153 "Vec2Tile and Tile2Vec:\n";
1154 Func.dump());
1155 Change |= combineLdSt(Casts&: LiveCasts);
1156 EraseInst(LiveCasts);
1157 LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine "
1158 "AMXCast and load/store:\n";
1159 Func.dump());
1160
1161 // Handle the A->B->A cast, and there is an intervening PHI node.
1162 for (BasicBlock &BB : Func) {
1163 for (Instruction &I : BB) {
1164 if (isAMXCast(II: &I)) {
1165 if (isa<PHINode>(Val: I.getOperand(i: 0)))
1166 PhiCastWorkList.push_back(Elt: &I);
1167 }
1168 }
1169 }
1170 for (auto *I : PhiCastWorkList) {
1171 // We skip the dead Amxcast.
1172 if (DeadInst.contains(key: I))
1173 continue;
1174 PHINode *PN = cast<PHINode>(Val: I->getOperand(i: 0));
1175 if (optimizeAMXCastFromPhi(CI: cast<IntrinsicInst>(Val: I), PN, DeadInst)) {
1176 DeadInst.insert(X: PN);
1177 Change = true;
1178 }
1179 }
1180
1181 // Since we create new phi and merge AMXCast, some old phis and AMXCast might
1182 // have no uses. We do some DeadCodeElimination for them.
1183 while (!DeadInst.empty()) {
1184 Instruction *I = DeadInst.pop_back_val();
1185 Change |= DCEInstruction(I, WorkList&: DeadInst, TLI);
1186 }
1187 LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after "
1188 "optimizeAMXCastFromPhi:\n";
1189 Func.dump());
1190 return Change;
1191}
1192
1193// There might be remaining AMXcast after combineAMXcast and they should be
1194// handled elegantly.
1195bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1196 IRBuilder<> Builder(AMXCast);
1197 AllocaInst *AllocaAddr;
1198 Value *I8Ptr, *Stride;
1199 auto *Src = AMXCast->getOperand(i_nocapture: 0);
1200
1201 auto Prepare = [&](Type *MemTy) {
1202 AllocaAddr = createAllocaInstAtEntry(Builder, BB: AMXCast->getParent(), Ty: MemTy);
1203 I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy());
1204 Stride = Builder.getInt64(C: 64);
1205 };
1206
1207 if (AMXCast->getType()->isX86_AMXTy()) {
1208 // %2 = amxcast <225 x i32> %src to x86_amx
1209 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1210 // i8* %addr3, i64 60, x86_amx %2)
1211 // -->
1212 // %addr = alloca <225 x i32>, align 64
1213 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1214 // %addr2 = bitcast <225 x i32>* %addr to i8*
1215 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1216 // i8* %addr2,
1217 // i64 60)
1218 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1219 // i8* %addr3, i64 60, x86_amx %2)
1220 if (AMXCast->use_empty()) {
1221 AMXCast->eraseFromParent();
1222 return true;
1223 }
1224 Use &U = *(AMXCast->use_begin());
1225 unsigned OpNo = U.getOperandNo();
1226 auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser());
1227 if (!II)
1228 return false; // May be bitcast from x86amx to <256 x i32>.
1229 Prepare(AMXCast->getOperand(i_nocapture: 0)->getType());
1230 Builder.CreateStore(Val: Src, Ptr: AllocaAddr);
1231 // TODO we can pick an constant operand for the shape.
1232 Value *Row = nullptr, *Col = nullptr;
1233 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
1234 std::array<Value *, 4> Args = {
1235 Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty())};
1236 Value *NewInst =
1237 Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args);
1238 AMXCast->replaceAllUsesWith(V: NewInst);
1239 AMXCast->eraseFromParent();
1240 } else {
1241 // %2 = amxcast x86_amx %src to <225 x i32>
1242 // -->
1243 // %addr = alloca <225 x i32>, align 64
1244 // %addr2 = bitcast <225 x i32>* to i8*
1245 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1246 // i8* %addr2, i64 %stride)
1247 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1248 auto *II = dyn_cast<IntrinsicInst>(Val: Src);
1249 if (!II)
1250 return false; // May be bitcast from <256 x i32> to x86amx.
1251 Prepare(AMXCast->getType());
1252 Value *Row = II->getOperand(i_nocapture: 0);
1253 Value *Col = II->getOperand(i_nocapture: 1);
1254 std::array<Value *, 5> Args = {
1255 Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()), Src};
1256 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args);
1257 Value *NewInst = Builder.CreateLoad(Ty: AMXCast->getType(), Ptr: AllocaAddr);
1258 AMXCast->replaceAllUsesWith(V: NewInst);
1259 AMXCast->eraseFromParent();
1260 }
1261
1262 return true;
1263}
1264
1265bool X86LowerAMXCast::transformAllAMXCast() {
1266 bool Change = false;
1267 // Collect tile cast instruction.
1268 SmallVector<Instruction *, 8> WorkLists;
1269 for (BasicBlock &BB : Func) {
1270 for (Instruction &I : BB) {
1271 if (isAMXCast(II: &I))
1272 WorkLists.push_back(Elt: &I);
1273 }
1274 }
1275
1276 for (auto *Inst : WorkLists) {
1277 Change |= transformAMXCast(AMXCast: cast<IntrinsicInst>(Val: Inst));
1278 }
1279
1280 return Change;
1281}
1282
1283bool lowerAmxType(Function &F, const TargetMachine *TM,
1284 TargetLibraryInfo *TLI) {
1285 // Performance optimization: most code doesn't use AMX, so return early if
1286 // there are no instructions that produce AMX values. This is sufficient, as
1287 // AMX arguments and constants are not allowed -- so any producer of an AMX
1288 // value must be an instruction.
1289 // TODO: find a cheaper way for this, without looking at all instructions.
1290 if (!containsAMXCode(F))
1291 return false;
1292
1293 bool C = false;
1294 X86LowerAMXCast LAC(F);
1295 C |= LAC.combineAMXcast(TLI);
1296 // There might be remaining AMXcast after combineAMXcast and they should be
1297 // handled elegantly.
1298 C |= LAC.transformAllAMXCast();
1299
1300 X86LowerAMXType LAT(F);
1301 C |= LAT.visit();
1302
1303 // Prepare for fast register allocation at O0.
1304 // Todo: May better check the volatile model of AMX code, not just
1305 // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1306 if (TM->getOptLevel() == CodeGenOptLevel::None) {
1307 // If Front End not use O0 but the Mid/Back end use O0, (e.g.
1308 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll") we should make
1309 // sure the amx data is volatile, that is necessary for AMX fast
1310 // register allocation.
1311 if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone)) {
1312 X86VolatileTileData VTD(F);
1313 C = VTD.volatileTileData() || C;
1314 }
1315 }
1316
1317 return C;
1318}
1319
1320} // anonymous namespace
1321
1322PreservedAnalyses X86LowerAMXTypePass::run(Function &F,
1323 FunctionAnalysisManager &FAM) {
1324 TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(IR&: F);
1325 bool Changed = lowerAmxType(F, TM, TLI: &TLI);
1326 if (!Changed)
1327 return PreservedAnalyses::all();
1328
1329 PreservedAnalyses PA = PreservedAnalyses::none();
1330 PA.preserveSet<CFGAnalyses>();
1331 return PA;
1332}
1333
1334namespace {
1335
1336class X86LowerAMXTypeLegacyPass : public FunctionPass {
1337public:
1338 static char ID;
1339
1340 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1341
1342 bool runOnFunction(Function &F) override {
1343 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1344 TargetLibraryInfo *TLI =
1345 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1346 return lowerAmxType(F, TM, TLI);
1347 }
1348
1349 void getAnalysisUsage(AnalysisUsage &AU) const override {
1350 AU.setPreservesCFG();
1351 AU.addRequired<TargetPassConfig>();
1352 AU.addRequired<TargetLibraryInfoWrapperPass>();
1353 }
1354};
1355
1356} // anonymous namespace
1357
// Human-readable pass name passed to the INITIALIZE_PASS_* macros below.
static const char PassName[] = "Lower AMX type for load/store";
// Out-of-line definition of the legacy pass's static ID member.
char X86LowerAMXTypeLegacyPass::ID = 0;
// Legacy pass registration, declaring TargetPassConfig and
// TargetLibraryInfoWrapperPass (both requested in getAnalysisUsage) as
// dependencies. NOTE(review): the trailing `false, false` are presumably the
// CFG-only / is-analysis flags of INITIALIZE_PASS; confirm against PassSupport.h.
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)
1366
1367FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() {
1368 return new X86LowerAMXTypeLegacyPass();
1369}
1370