//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are not
/// supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and only
/// AMX intrinsics can operate on the type, we need to transform <256 x i32>
/// load/store instructions into AMX load/store. If the bitcast cannot be
/// combined with its load/store, we transform the bitcast into an AMX
/// load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in a spill.)
/// volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// is transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallSet.h"
45#include "llvm/Analysis/OptimizationRemarkEmitter.h"
46#include "llvm/Analysis/TargetLibraryInfo.h"
47#include "llvm/Analysis/TargetTransformInfo.h"
48#include "llvm/CodeGen/Passes.h"
49#include "llvm/CodeGen/TargetPassConfig.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/IR/DataLayout.h"
52#include "llvm/IR/Function.h"
53#include "llvm/IR/IRBuilder.h"
54#include "llvm/IR/Instructions.h"
55#include "llvm/IR/IntrinsicInst.h"
56#include "llvm/IR/IntrinsicsX86.h"
57#include "llvm/IR/PatternMatch.h"
58#include "llvm/InitializePasses.h"
59#include "llvm/Pass.h"
60#include "llvm/Target/TargetMachine.h"
61#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62#include "llvm/Transforms/Utils/Local.h"
63
64#include <map>
65
66using namespace llvm;
67using namespace PatternMatch;
68
69#define DEBUG_TYPE "lower-amx-type"
70
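// Return true if \p II is one of the AMX cast intrinsics,
// i.e. llvm.x86.cast.vector.to.tile or llvm.x86.cast.tile.to.vector.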
71static bool isAMXCast(Instruction *II) {
72 return match(V: II,
73 P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value())) ||
74 match(V: II, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(Op0: m_Value()));
75}
76
77static bool isAMXIntrinsic(Value *I) {
78 auto *II = dyn_cast<IntrinsicInst>(Val: I);
79 if (!II)
80 return false;
81 if (isAMXCast(II))
82 return false;
  // Check whether the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
85 if (II->getType()->isX86_AMXTy())
86 return true;
87 for (Value *V : II->args()) {
88 if (V->getType()->isX86_AMXTy())
89 return true;
90 }
91
92 return false;
93}
94
95static bool containsAMXCode(Function &F) {
96 for (BasicBlock &BB : F)
97 for (Instruction &I : BB)
98 if (I.getType()->isX86_AMXTy())
99 return true;
100 return false;
101}
102
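// Create an alloca of type \p Ty in the entry block of the function that
// contains \p BB, aligned to the preferred alignment of x86_amx.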
103static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
104 Type *Ty) {
105 Function &F = *BB->getParent();
106 const DataLayout &DL = F.getDataLayout();
107
108 LLVMContext &Ctx = Builder.getContext();
109 auto AllocaAlignment = DL.getPrefTypeAlign(Ty: Type::getX86_AMXTy(C&: Ctx));
110 unsigned AllocaAS = DL.getAllocaAddrSpace();
111 AllocaInst *AllocaRes =
112 new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
113 AllocaRes->setAlignment(AllocaAlignment);
114 return AllocaRes;
115}
116
117static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
118 for (Instruction &I : F.getEntryBlock())
119 if (!isa<AllocaInst>(Val: &I))
120 return &I;
121 llvm_unreachable("No terminator in the entry block!");
122}
123
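// Return the {Row, Col} shape expected for the tile operand \p OpNo of the
// AMX intrinsic \p II. For the tile load/store intrinsics the shape is always
// given by the first two arguments; for the dot-product style intrinsics it
// depends on which tile operand is queried.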
124static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
125 IRBuilder<> Builder(II);
126 Value *Row = nullptr, *Col = nullptr;
127 switch (II->getIntrinsicID()) {
128 default:
129 llvm_unreachable("Expect amx intrinsics");
130 case Intrinsic::x86_tileloadd64_internal:
131 case Intrinsic::x86_tileloaddt164_internal:
132 case Intrinsic::x86_tilestored64_internal: {
133 Row = II->getArgOperand(i: 0);
134 Col = II->getArgOperand(i: 1);
135 break;
136 }
  // a * b + c
  // The shape of each tile operand depends on which operand it is.
139 case Intrinsic::x86_tcmmimfp16ps_internal:
140 case Intrinsic::x86_tcmmrlfp16ps_internal:
141 case Intrinsic::x86_tdpbssd_internal:
142 case Intrinsic::x86_tdpbsud_internal:
143 case Intrinsic::x86_tdpbusd_internal:
144 case Intrinsic::x86_tdpbuud_internal:
145 case Intrinsic::x86_tdpbf16ps_internal:
146 case Intrinsic::x86_tdpfp16ps_internal: {
147 switch (OpNo) {
148 case 3:
149 Row = II->getArgOperand(i: 0);
150 Col = II->getArgOperand(i: 1);
151 break;
152 case 4:
153 Row = II->getArgOperand(i: 0);
154 Col = II->getArgOperand(i: 2);
155 break;
156 case 5:
157 if (isa<ConstantInt>(Val: II->getArgOperand(i: 2)))
158 Row = Builder.getInt16(
159 C: (cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getSExtValue()) / 4);
160 else if (isa<Instruction>(Val: II->getArgOperand(i: 2))) {
        // When it is not a constant value and not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118 and we try to get the shape for
        // %117:
        // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        // i32> %115).
        // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        // %117).
        // If we create %row = udiv i16 %106, 4 before %118 (aka II), then its
        // definition would be after its user (the new tileload for %117).
        // So the best choice is to create %row right after the definition of
        // %106.
173 Builder.SetInsertPoint(cast<Instruction>(Val: II->getOperand(i_nocapture: 2)));
174 Row = Builder.CreateUDiv(LHS: II->getOperand(i_nocapture: 2), RHS: Builder.getInt16(C: 4));
175 cast<Instruction>(Val: Row)->moveAfter(MovePos: cast<Instruction>(Val: II->getOperand(i_nocapture: 2)));
176 } else {
177 // When it is not a const value and it is a function argument, we create
178 // Row at the entry bb.
179 IRBuilder<> NewBuilder(
180 getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction()));
181 Row = NewBuilder.CreateUDiv(LHS: II->getOperand(i_nocapture: 2), RHS: NewBuilder.getInt16(C: 4));
182 }
183 Col = II->getArgOperand(i: 1);
184 break;
185 }
186 break;
187 }
188 }
189
190 return std::make_pair(x&: Row, y&: Col);
191}
192
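// Deduce the shape of the tile flowing through \p Phi by following its user
// chain until an AMX intrinsic with a known shape is reached. Returns
// {nullptr, nullptr} if no shape can be found.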
193static std::pair<Value *, Value *> getShape(PHINode *Phi) {
194 Use &U = *(Phi->use_begin());
195 unsigned OpNo = U.getOperandNo();
196 User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we just
  // follow the first user. If a shape can be found, return it; otherwise
  // return nullptr and the optimization for undef/zero is abandoned.
201 while (V) {
202 if (isAMXCast(II: dyn_cast<Instruction>(Val: V))) {
203 if (V->use_empty())
204 break;
205 Use &U = *(V->use_begin());
206 OpNo = U.getOperandNo();
207 V = U.getUser();
208 } else if (isAMXIntrinsic(I: V)) {
209 return getShape(II: cast<IntrinsicInst>(Val: V), OpNo);
210 } else if (isa<PHINode>(Val: V)) {
211 if (V->use_empty())
212 break;
213 Use &U = *(V->use_begin());
214 V = U.getUser();
215 } else {
216 break;
217 }
218 }
219
220 return std::make_pair(x: nullptr, y: nullptr);
221}
222
223namespace {
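// Lower bitcasts between <256 x i32> and x86_amx by combining them with the
// adjacent load/store, or, when that is not possible, by spilling the value
// through a stack slot.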
224class X86LowerAMXType {
225 Function &Func;
226
  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
230 std::map<Value *, Value *> Col2Row;
231
232public:
233 X86LowerAMXType(Function &F) : Func(F) {}
234 bool visit();
235 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
236 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
237 bool transformBitcast(BitCastInst *Bitcast);
238};
239
240// %src = load <256 x i32>, <256 x i32>* %addr, align 64
241// %2 = bitcast <256 x i32> %src to x86_amx
242// -->
243// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
244// i8* %addr, i64 %stride64)
245void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
246 Value *Row = nullptr, *Col = nullptr;
247 Use &U = *(Bitcast->use_begin());
248 unsigned OpNo = U.getOperandNo();
249 auto *II = cast<IntrinsicInst>(Val: U.getUser());
250 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
251 IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
253 Value *Stride = Builder.getInt64(C: 64);
254 Value *I8Ptr = LD->getOperand(i_nocapture: 0);
255 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
256
257 Value *NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal,
258 Types: std::nullopt, Args);
259 Bitcast->replaceAllUsesWith(V: NewInst);
260}
261
262// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
263// %stride);
264// %13 = bitcast x86_amx %src to <256 x i32>
265// store <256 x i32> %13, <256 x i32>* %addr, align 64
266// -->
267// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
268// %stride64, %13)
269void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
270
271 Value *Tile = Bitcast->getOperand(i_nocapture: 0);
272 auto *II = cast<IntrinsicInst>(Val: Tile);
273 // Tile is output from AMX intrinsic. The first operand of the
274 // intrinsic is row, the second operand of the intrinsic is column.
275 Value *Row = II->getOperand(i_nocapture: 0);
276 Value *Col = II->getOperand(i_nocapture: 1);
277 IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
280 Value *Stride = Builder.getInt64(C: 64);
281 Value *I8Ptr = ST->getOperand(i_nocapture: 1);
282 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
283 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt,
284 Args);
285 if (Bitcast->hasOneUse())
286 return;
287 // %13 = bitcast x86_amx %src to <256 x i32>
288 // store <256 x i32> %13, <256 x i32>* %addr, align 64
289 // %add = <256 x i32> %13, <256 x i32> %src2
290 // -->
291 // %13 = bitcast x86_amx %src to <256 x i32>
292 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
293 // %stride64, %13)
294 // %14 = load <256 x i32>, %addr
295 // %add = <256 x i32> %14, <256 x i32> %src2
296 Value *Vec = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: ST->getOperand(i_nocapture: 1));
297 Bitcast->replaceAllUsesWith(V: Vec);
298}
299
// Transform a bitcast into <store, load> instructions.
301bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
302 IRBuilder<> Builder(Bitcast);
303 AllocaInst *AllocaAddr;
304 Value *I8Ptr, *Stride;
305 auto *Src = Bitcast->getOperand(i_nocapture: 0);
306
307 auto Prepare = [&](Type *MemTy) {
308 AllocaAddr = createAllocaInstAtEntry(Builder, BB: Bitcast->getParent(), Ty: MemTy);
309 I8Ptr = AllocaAddr;
310 Stride = Builder.getInt64(C: 64);
311 };
312
313 if (Bitcast->getType()->isX86_AMXTy()) {
314 // %2 = bitcast <256 x i32> %src to x86_amx
315 // -->
316 // %addr = alloca <256 x i32>, align 64
317 // store <256 x i32> %src, <256 x i32>* %addr, align 64
318 // %addr2 = bitcast <256 x i32>* to i8*
319 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
320 // i8* %addr2,
321 // i64 64)
322 Use &U = *(Bitcast->use_begin());
323 unsigned OpNo = U.getOperandNo();
324 auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser());
325 if (!II)
326 return false; // May be bitcast from x86amx to <256 x i32>.
327 Prepare(Bitcast->getOperand(i_nocapture: 0)->getType());
328 Builder.CreateStore(Val: Src, Ptr: AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
330 Value *Row = nullptr, *Col = nullptr;
331 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
332 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
333 Value *NewInst = Builder.CreateIntrinsic(
334 ID: Intrinsic::x86_tileloadd64_internal, Types: std::nullopt, Args);
335 Bitcast->replaceAllUsesWith(V: NewInst);
336 } else {
337 // %2 = bitcast x86_amx %src to <256 x i32>
338 // -->
339 // %addr = alloca <256 x i32>, align 64
340 // %addr2 = bitcast <256 x i32>* to i8*
341 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
342 // i8* %addr2, i64 %stride)
343 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
344 auto *II = dyn_cast<IntrinsicInst>(Val: Src);
345 if (!II)
346 return false; // May be bitcast from <256 x i32> to x86amx.
347 Prepare(Bitcast->getType());
348 Value *Row = II->getOperand(i_nocapture: 0);
349 Value *Col = II->getOperand(i_nocapture: 1);
350 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
351 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt,
352 Args);
353 Value *NewInst = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: AllocaAddr);
354 Bitcast->replaceAllUsesWith(V: NewInst);
355 }
356
357 return true;
358}
359
360bool X86LowerAMXType::visit() {
361 SmallVector<Instruction *, 8> DeadInsts;
362 Col2Row.clear();
363
364 for (BasicBlock *BB : post_order(G: &Func)) {
365 for (Instruction &Inst : llvm::make_early_inc_range(Range: llvm::reverse(C&: *BB))) {
366 auto *Bitcast = dyn_cast<BitCastInst>(Val: &Inst);
367 if (!Bitcast)
368 continue;
369
370 Value *Src = Bitcast->getOperand(i_nocapture: 0);
371 if (Bitcast->getType()->isX86_AMXTy()) {
372 if (Bitcast->user_empty()) {
373 DeadInsts.push_back(Elt: Bitcast);
374 continue;
375 }
376 LoadInst *LD = dyn_cast<LoadInst>(Val: Src);
377 if (!LD) {
378 if (transformBitcast(Bitcast))
379 DeadInsts.push_back(Elt: Bitcast);
380 continue;
381 }
        // If the load has multiple users, duplicate it as an AMX tile load.
383 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
384 // %2 = bitcast <256 x i32> %src to x86_amx
385 // %add = add <256 x i32> %src, <256 x i32> %src2
386 // -->
387 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
388 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
389 // i8* %addr, i64 %stride64)
390 // %add = add <256 x i32> %src, <256 x i32> %src2
391
        // If the load has one user, the load will be eliminated in DAG ISel.
393 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
394 // %2 = bitcast <256 x i32> %src to x86_amx
395 // -->
396 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
397 // i8* %addr, i64 %stride64)
398 combineLoadBitcast(LD, Bitcast);
399 DeadInsts.push_back(Elt: Bitcast);
400 if (LD->hasOneUse())
401 DeadInsts.push_back(Elt: LD);
402 } else if (Src->getType()->isX86_AMXTy()) {
403 if (Bitcast->user_empty()) {
404 DeadInsts.push_back(Elt: Bitcast);
405 continue;
406 }
407 StoreInst *ST = nullptr;
408 for (Use &U : Bitcast->uses()) {
409 ST = dyn_cast<StoreInst>(Val: U.getUser());
410 if (ST)
411 break;
412 }
413 if (!ST) {
414 if (transformBitcast(Bitcast))
415 DeadInsts.push_back(Elt: Bitcast);
416 continue;
417 }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
419 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
420 // %stride);
421 // %13 = bitcast x86_amx %src to <256 x i32>
422 // store <256 x i32> %13, <256 x i32>* %addr, align 64
423 // -->
424 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
425 // %stride64, %13)
426 //
        // If the bitcast (%13) has multiple uses, transform it as below.
428 // %13 = bitcast x86_amx %src to <256 x i32>
429 // store <256 x i32> %13, <256 x i32>* %addr, align 64
430 // %add = <256 x i32> %13, <256 x i32> %src2
431 // -->
432 // %13 = bitcast x86_amx %src to <256 x i32>
433 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
434 // %stride64, %13)
435 // %14 = load <256 x i32>, %addr
436 // %add = <256 x i32> %14, <256 x i32> %src2
437 //
438 combineBitcastStore(Bitcast, ST);
439 // Delete user first.
440 DeadInsts.push_back(Elt: ST);
441 DeadInsts.push_back(Elt: Bitcast);
442 }
443 }
444 }
445
446 bool C = !DeadInsts.empty();
447
448 for (auto *Inst : DeadInsts)
449 Inst->eraseFromParent();
450
451 return C;
452}
453} // anonymous namespace
454
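// Allocate a <256 x i32> stack slot in the entry block of \p BB's function and
// return a pointer to it. This is the memory that volatile tile data is
// spilled to and reloaded from.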
455static Value *getAllocaPos(BasicBlock *BB) {
456 Function *F = BB->getParent();
457 IRBuilder<> Builder(&F->getEntryBlock().front());
458 const DataLayout &DL = F->getDataLayout();
459 unsigned AllocaAS = DL.getAllocaAddrSpace();
460 Type *V256I32Ty = VectorType::get(ElementType: Builder.getInt32Ty(), NumElements: 256, Scalable: false);
461 AllocaInst *AllocaRes =
462 new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
463 BasicBlock::iterator Iter = AllocaRes->getIterator();
464 ++Iter;
465 Builder.SetInsertPoint(&*Iter);
466 Value *I8Ptr = Builder.CreateBitCast(V: AllocaRes, DestTy: Builder.getPtrTy());
467 return I8Ptr;
468}
469
470static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
471 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
472 auto *II = cast<IntrinsicInst>(Val: TileDef);
473 assert(II && "Not tile intrinsic!");
474 Value *Row = II->getOperand(i_nocapture: 0);
475 Value *Col = II->getOperand(i_nocapture: 1);
476
477 BasicBlock *BB = TileDef->getParent();
478 BasicBlock::iterator Iter = TileDef->getIterator();
479 IRBuilder<> Builder(BB, ++Iter);
480 Value *Stride = Builder.getInt64(C: 64);
481 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
482
483 Instruction *TileStore = Builder.CreateIntrinsic(
484 ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt, Args);
485 return TileStore;
486}
487
488static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
489 Value *V = U.get();
490 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
491
492 // Get tile shape.
493 IntrinsicInst *II = nullptr;
494 if (IsPHI) {
495 Value *PhiOp = cast<PHINode>(Val: V)->getIncomingValue(i: 0);
496 II = cast<IntrinsicInst>(Val: PhiOp);
497 } else {
498 II = cast<IntrinsicInst>(Val: V);
499 }
500 Value *Row = II->getOperand(i_nocapture: 0);
501 Value *Col = II->getOperand(i_nocapture: 1);
502
503 Instruction *UserI = cast<Instruction>(Val: U.getUser());
504 IRBuilder<> Builder(UserI);
505 Value *Stride = Builder.getInt64(C: 64);
506 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
507
508 Value *TileLoad = Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal,
509 Types: std::nullopt, Args);
510 UserI->replaceUsesOfWith(From: V, To: TileLoad);
511}
512
513static bool isIncomingOfPHI(Instruction *I) {
514 for (Use &U : I->uses()) {
515 User *V = U.getUser();
516 if (isa<PHINode>(Val: V))
517 return true;
518 }
519 return false;
520}
521
// Make all AMX tile data volatile, shortening the live range of each tile
// register before fast register allocation.
524namespace {
525class X86VolatileTileData {
526 Function &F;
527
528public:
529 X86VolatileTileData(Function &Func) : F(Func) {}
530 Value *updatePhiIncomings(BasicBlock *BB,
531 SmallVector<Instruction *, 2> &Incomings);
532 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
533 bool volatileTileData();
534 void volatileTilePHI(PHINode *PHI);
535 void volatileTileNonPHI(Instruction *I);
536};
537
538Value *X86VolatileTileData::updatePhiIncomings(
539 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
540 Value *I8Ptr = getAllocaPos(BB);
541
542 for (auto *I : Incomings) {
543 User *Store = createTileStore(TileDef: I, Ptr: I8Ptr);
544
    // All its uses (except the phi) should load from the stored memory.
546 for (Use &U : I->uses()) {
547 User *V = U.getUser();
548 if (isa<PHINode>(Val: V) || V == Store)
549 continue;
550 replaceWithTileLoad(U, Ptr: I8Ptr);
551 }
552 }
553 return I8Ptr;
554}
555
556void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
557 Value *StorePtr) {
558 for (Use &U : PHI->uses())
559 replaceWithTileLoad(U, Ptr: StorePtr, IsPHI: true);
560 PHI->eraseFromParent();
561}
562
// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def is changed to a tileload.
// 2) Each PHI incoming value is tilestored right after its def.
// 3) These tileloads and tilestores must use the same memory.
568// e.g.
569// ------------------------------------------------------
570// bb_dom:
571// ...
572// br i1 %bool.cond, label %if.else, label %if.then
573//
574// if.then:
575// def %t0 = ...
576// ...
577// use %t0
578// ...
579// br label %if.end
580//
581// if.else:
582// def %t1 = ...
583// br label %if.end
584//
585// if.end:
586// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
587// ...
588// use %td
589// ------------------------------------------------------
590// -->
591// ------------------------------------------------------
592// bb_entry:
593// %mem = alloca <256 x i32>, align 1024 *
594// ...
595// bb_dom:
596// ...
597// br i1 %bool.cond, label %if.else, label %if.then
598//
599// if.then:
600// def %t0 = ...
601// call void @llvm.x86.tilestored64.internal(mem, %t0) *
602// ...
603// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
604// use %t0` *
605// ...
606// br label %if.end
607//
608// if.else:
609// def %t1 = ...
610// call void @llvm.x86.tilestored64.internal(mem, %t1) *
611// br label %if.end
612//
613// if.end:
614// ...
615// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
616// use %td
617// ------------------------------------------------------
618void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
619 BasicBlock *BB = PHI->getParent();
620 SmallVector<Instruction *, 2> Incomings;
621
622 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
623 Value *Op = PHI->getIncomingValue(i: I);
624 Instruction *Inst = dyn_cast<Instruction>(Val: Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
626 Incomings.push_back(Elt: Inst);
627 }
628
629 Value *StorePtr = updatePhiIncomings(BB, Incomings);
630 replacePhiDefWithLoad(PHI, StorePtr);
631}
632
// Store the defined tile and load it before each use.
// None of its users is a PHI node.
635// e.g.
636// ------------------------------------------------------
637// def %td = ...
638// ...
639// "use %td"
640// ------------------------------------------------------
641// -->
642// ------------------------------------------------------
643// def %td = ...
644// call void @llvm.x86.tilestored64.internal(mem, %td)
645// ...
646// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
647// "use %td2"
648// ------------------------------------------------------
649void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
650 BasicBlock *BB = I->getParent();
651 Value *I8Ptr = getAllocaPos(BB);
652 User *Store = createTileStore(TileDef: I, Ptr: I8Ptr);
653
654 // All its uses should load from stored mem.
655 for (Use &U : I->uses()) {
656 User *V = U.getUser();
657 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
658 if (V != Store)
659 replaceWithTileLoad(U, Ptr: I8Ptr);
660 }
661}
662
// Volatile Tile Model:
// 1) All uses of tile data come from a tileload right before the use.
// 2) All defs of tile data are tilestored into memory immediately after.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other AMX instructions in the key amx area.
675bool X86VolatileTileData::volatileTileData() {
676 bool Changed = false;
677 for (BasicBlock &BB : F) {
678 SmallVector<Instruction *, 2> PHIInsts;
679 SmallVector<Instruction *, 8> AMXDefInsts;
680
681 for (Instruction &I : BB) {
682 if (!I.getType()->isX86_AMXTy())
683 continue;
684 if (isa<PHINode>(Val: &I))
685 PHIInsts.push_back(Elt: &I);
686 else
687 AMXDefInsts.push_back(Elt: &I);
688 }
689
    // First we "volatile" the non-PHI-related AMX intrinsics.
691 for (Instruction *I : AMXDefInsts) {
692 if (isIncomingOfPHI(I))
693 continue;
694 volatileTileNonPHI(I);
695 Changed = true;
696 }
697
698 for (Instruction *I : PHIInsts) {
699 volatileTilePHI(PHI: dyn_cast<PHINode>(Val: I));
700 Changed = true;
701 }
702 }
703 return Changed;
704}
705
706} // anonymous namespace
707
708namespace {
709
710class X86LowerAMXCast {
711 Function &Func;
712 std::unique_ptr<DominatorTree> DT;
713
714public:
715 X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
716 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
717 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
718 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
719 bool combineAMXcast(TargetLibraryInfo *TLI);
720 bool transformAMXCast(IntrinsicInst *AMXCast);
721 bool transformAllAMXCast();
722 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
723 SmallSetVector<Instruction *, 16> &DeadInst);
724};
725
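// Erase \p I if it is trivially dead. Operands that become trivially dead as a
// result are added to \p WorkList so the caller can process them as well.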
726static bool DCEInstruction(Instruction *I,
727 SmallSetVector<Instruction *, 16> &WorkList,
728 const TargetLibraryInfo *TLI) {
729 if (isInstructionTriviallyDead(I, TLI)) {
730 salvageDebugInfo(I&: *I);
731 salvageKnowledge(I);
732
733 // Null out all of the instruction's operands to see if any operand becomes
734 // dead as we go.
735 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
736 Value *OpV = I->getOperand(i);
737 I->setOperand(i, Val: nullptr);
738
739 if (!OpV->use_empty() || I == OpV)
740 continue;
741
742 // If the operand is an instruction that became dead as we nulled out the
743 // operand, and if it is 'trivially' dead, delete it in a future loop
744 // iteration.
745 if (Instruction *OpI = dyn_cast<Instruction>(Val: OpV)) {
746 if (isInstructionTriviallyDead(I: OpI, TLI)) {
747 WorkList.insert(X: OpI);
748 }
749 }
750 }
751 I->eraseFromParent();
752 return true;
753 }
754 return false;
755}
756
/// This function handles the following case:
758///
759/// A -> B amxcast
760/// PHI
761/// B -> A amxcast
762///
763/// All the related PHI nodes can be replaced by new PHI nodes with type A.
764/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
765bool X86LowerAMXCast::optimizeAMXCastFromPhi(
766 IntrinsicInst *CI, PHINode *PN,
767 SmallSetVector<Instruction *, 16> &DeadInst) {
768 IRBuilder<> Builder(CI);
769 Value *Src = CI->getOperand(i_nocapture: 0);
770 Type *SrcTy = Src->getType(); // Type B
771 Type *DestTy = CI->getType(); // Type A
772
773 SmallVector<PHINode *, 4> PhiWorklist;
774 SmallSetVector<PHINode *, 4> OldPhiNodes;
775
776 // Find all of the A->B casts and PHI nodes.
777 // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
778 // OldPhiNodes is used to track all known PHI nodes, before adding a new
779 // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
780 PhiWorklist.push_back(Elt: PN);
781 OldPhiNodes.insert(X: PN);
782 while (!PhiWorklist.empty()) {
783 auto *OldPN = PhiWorklist.pop_back_val();
784 for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
785 Value *IncValue = OldPN->getIncomingValue(i: I);
      // TODO: Currently we ignore the case where it is a constant. In the
      // future, we might support constants.
788 if (isa<Constant>(Val: IncValue)) {
789 auto *IncConst = dyn_cast<Constant>(Val: IncValue);
790 if (!isa<UndefValue>(Val: IncValue) && !IncConst->isZeroValue())
791 return false;
792 Value *Row = nullptr, *Col = nullptr;
793 std::tie(args&: Row, args&: Col) = getShape(Phi: OldPN);
        // TODO: If it is not a constant, Row and Col must dominate the
        // tilezero that we are going to create.
796 if (!Row || !Col || !isa<Constant>(Val: Row) || !isa<Constant>(Val: Col))
797 return false;
798 // Create tilezero at the end of incoming block.
799 auto *Block = OldPN->getIncomingBlock(i: I);
800 BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
801 Instruction *NewInst = Builder.CreateIntrinsic(
802 ID: Intrinsic::x86_tilezero_internal, Types: std::nullopt, Args: {Row, Col});
803 NewInst->moveBefore(MovePos: &*Iter);
804 NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_cast_tile_to_vector,
805 Types: {IncValue->getType()}, Args: {NewInst});
806 NewInst->moveBefore(MovePos: &*Iter);
807 // Replace InValue with new Value.
808 OldPN->setIncomingValue(i: I, V: NewInst);
809 IncValue = NewInst;
810 }
811
812 if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) {
813 if (OldPhiNodes.insert(X: PNode))
814 PhiWorklist.push_back(Elt: PNode);
815 continue;
816 }
817 Instruction *ACI = dyn_cast<Instruction>(Val: IncValue);
818 if (ACI && isAMXCast(II: ACI)) {
819 // Verify it's a A->B cast.
820 Type *TyA = ACI->getOperand(i: 0)->getType();
821 Type *TyB = ACI->getType();
822 if (TyA != DestTy || TyB != SrcTy)
823 return false;
824 continue;
825 }
826 return false;
827 }
828 }
829
830 // Check that each user of each old PHI node is something that we can
831 // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
832 for (auto *OldPN : OldPhiNodes) {
833 for (User *V : OldPN->users()) {
834 Instruction *ACI = dyn_cast<Instruction>(Val: V);
835 if (ACI && isAMXCast(II: ACI)) {
836 // Verify it's a B->A cast.
837 Type *TyB = ACI->getOperand(i: 0)->getType();
838 Type *TyA = ACI->getType();
839 if (TyA != DestTy || TyB != SrcTy)
840 return false;
841 } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
842 // As long as the user is another old PHI node, then even if we don't
843 // rewrite it, the PHI web we're considering won't have any users
844 // outside itself, so it'll be dead.
845 // example:
846 // bb.0:
847 // %0 = amxcast ...
848 // bb.1:
849 // %1 = amxcast ...
850 // bb.2:
851 // %goodphi = phi %0, %1
852 // %3 = amxcast %goodphi
853 // bb.3:
854 // %goodphi2 = phi %0, %goodphi
855 // %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
860 if (OldPhiNodes.count(key: PHI) == 0)
861 return false;
862 } else
863 return false;
864 }
865 }
866
867 // For each old PHI node, create a corresponding new PHI node with a type A.
868 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
869 for (auto *OldPN : OldPhiNodes) {
870 Builder.SetInsertPoint(OldPN);
871 PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands());
872 NewPNodes[OldPN] = NewPN;
873 }
874
875 // Fill in the operands of new PHI nodes.
876 for (auto *OldPN : OldPhiNodes) {
877 PHINode *NewPN = NewPNodes[OldPN];
878 for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
879 Value *V = OldPN->getOperand(i_nocapture: j);
880 Value *NewV = nullptr;
881 Instruction *ACI = dyn_cast<Instruction>(Val: V);
      // There should not be an AMXCast from a constant.
883 if (ACI && isAMXCast(II: ACI))
884 NewV = ACI->getOperand(i: 0);
885 else if (auto *PrevPN = dyn_cast<PHINode>(Val: V))
886 NewV = NewPNodes[PrevPN];
887 assert(NewV);
888 NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j));
889 }
890 }
891
  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing,
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.
897
898 // Replace users of BitCast B->A with NewPHI. These will help
899 // later to get rid of a closure formed by OldPHI nodes.
900 for (auto *OldPN : OldPhiNodes) {
901 PHINode *NewPN = NewPNodes[OldPN];
902 for (User *V : make_early_inc_range(Range: OldPN->users())) {
903 Instruction *ACI = dyn_cast<Instruction>(Val: V);
904 if (ACI && isAMXCast(II: ACI)) {
905 Type *TyB = ACI->getOperand(i: 0)->getType();
906 Type *TyA = ACI->getType();
907 assert(TyA == DestTy && TyB == SrcTy);
908 (void)TyA;
909 (void)TyB;
910 ACI->replaceAllUsesWith(V: NewPN);
911 DeadInst.insert(X: ACI);
912 } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN, and DCE can safely delete rootPN's operands if
        // rootPN is dead.
915 assert(OldPhiNodes.contains(PHI));
916 (void)PHI;
917 } else
918 llvm_unreachable("all uses should be handled");
919 }
920 }
921 return true;
922}
923
924// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
925// store <256 x i32> %43, <256 x i32>* %p, align 64
926// -->
927// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
928// i64 64, x86_amx %42)
929bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
930 Value *Tile = Cast->getOperand(i_nocapture: 0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
933 if (!isAMXIntrinsic(I: Tile))
934 return false;
935 auto *II = cast<IntrinsicInst>(Val: Tile);
936 // Tile is output from AMX intrinsic. The first operand of the
937 // intrinsic is row, the second operand of the intrinsic is column.
938 Value *Row = II->getOperand(i_nocapture: 0);
939 Value *Col = II->getOperand(i_nocapture: 1);
940 IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
942 Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty());
943 Value *I8Ptr = Builder.CreateBitCast(V: ST->getOperand(i_nocapture: 1), DestTy: Builder.getPtrTy());
944 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
945 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt,
946 Args);
947 return true;
948}
949
950// %65 = load <256 x i32>, <256 x i32>* %p, align 64
951// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
952// -->
953// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
954// i8* %p, i64 64)
955bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
956 bool EraseLoad = true;
957 Value *Row = nullptr, *Col = nullptr;
958 Use &U = *(Cast->use_begin());
959 unsigned OpNo = U.getOperandNo();
960 auto *II = cast<IntrinsicInst>(Val: U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
963 if (!isAMXIntrinsic(I: II))
964 return false;
965 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
966 IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
968 Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty());
969 Value *I8Ptr;
970
  // To save compile time, we create the dominator tree only when it is
  // really needed.
973 if (!DT)
974 DT.reset(p: new DominatorTree(Func));
975 if (!DT->dominates(Def: Row, User: LD) || !DT->dominates(Def: Col, User: LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
977 auto *AllocaAddr =
978 createAllocaInstAtEntry(Builder, BB: Cast->getParent(), Ty: LD->getType());
979 Builder.SetInsertPoint(&*std::next(x: LD->getIterator()));
980 Builder.CreateStore(Val: LD, Ptr: AllocaAddr);
981
982 Builder.SetInsertPoint(Cast);
983 I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy());
984 EraseLoad = false;
985 } else {
986 I8Ptr = Builder.CreateBitCast(V: LD->getOperand(i_nocapture: 0), DestTy: Builder.getPtrTy());
987 }
988 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
989
990 Value *NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal,
991 Types: std::nullopt, Args);
992 Cast->replaceAllUsesWith(V: NewInst);
993
994 return EraseLoad;
995}
996
997bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
998 bool Change = false;
999 for (auto *Cast : Casts) {
1000 auto *II = cast<IntrinsicInst>(Val: Cast);
1001 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1002 // store <256 x i32> %43, <256 x i32>* %p, align 64
1003 // -->
1004 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1005 // i64 64, x86_amx %42)
1006 if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1007 SmallVector<Instruction *, 2> DeadStores;
1008 for (User *U : Cast->users()) {
1009 StoreInst *Store = dyn_cast<StoreInst>(Val: U);
1010 if (!Store)
1011 continue;
1012 if (combineCastStore(Cast: cast<IntrinsicInst>(Val: Cast), ST: Store)) {
1013 DeadStores.push_back(Elt: Store);
1014 Change = true;
1015 }
1016 }
1017 for (auto *Store : DeadStores)
1018 Store->eraseFromParent();
1019 } else { // x86_cast_vector_to_tile
1020 SmallVector<Instruction *, 2> DeadLoads;
1021 auto *Load = dyn_cast<LoadInst>(Val: Cast->getOperand(i: 0));
1022 if (!Load || !Load->hasOneUse())
1023 continue;
1024 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1025 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1026 // -->
1027 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1028 // i8* %p, i64 64)
1029 if (combineLoadCast(Cast: cast<IntrinsicInst>(Val: Cast), LD: Load)) {
        // Set the operand to null so that the load instruction can be erased.
1031 Cast->setOperand(i: 0, Val: nullptr);
1032 Load->eraseFromParent();
1033 }
1034 }
1035 }
1036 return Change;
1037}
1038
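// Combine AMX cast intrinsics with the surrounding code: first eliminate
// vec2tile/tile2vec round trips, then fold casts into adjacent load/store
// instructions, and finally optimize casts whose operand is a PHI node.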
1039bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1040 bool Change = false;
  // Collect tile cast instructions.
1042 SmallVector<Instruction *, 8> Vec2TileInsts;
1043 SmallVector<Instruction *, 8> Tile2VecInsts;
1044 SmallVector<Instruction *, 8> PhiCastWorkList;
1045 SmallSetVector<Instruction *, 16> DeadInst;
1046 for (BasicBlock &BB : Func) {
1047 for (Instruction &I : BB) {
1048 Value *Vec;
1049 if (match(V: &I,
1050 P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value(V&: Vec))))
1051 Vec2TileInsts.push_back(Elt: &I);
1052 else if (match(V: &I, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1053 Op0: m_Value(V&: Vec))))
1054 Tile2VecInsts.push_back(Elt: &I);
1055 }
1056 }
1057
1058 auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1059 for (auto *Inst : Insts) {
1060 for (User *U : Inst->users()) {
1061 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U);
1062 if (!II || II->getIntrinsicID() != IID)
1063 continue;
1064 // T1 = vec2tile V0
1065 // V2 = tile2vec T1
1066 // V3 = OP V2
1067 // -->
1068 // T1 = vec2tile V0
1069 // V2 = tile2vec T1
1070 // V3 = OP V0
1071 II->replaceAllUsesWith(V: Inst->getOperand(i: 0));
1072 Change = true;
1073 }
1074 }
1075 };
1076
1077 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1078 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1079
1080 SmallVector<Instruction *, 8> LiveCasts;
1081 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1082 for (auto *Inst : Insts) {
1083 if (Inst->use_empty()) {
1084 Inst->eraseFromParent();
1085 Change = true;
1086 } else {
1087 LiveCasts.push_back(Elt: Inst);
1088 }
1089 }
1090 };
1091
1092 EraseInst(Vec2TileInsts);
1093 EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
1097 Change |= combineLdSt(Casts&: LiveCasts);
1098 EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());
1102
1103 // Handle the A->B->A cast, and there is an intervening PHI node.
1104 for (BasicBlock &BB : Func) {
1105 for (Instruction &I : BB) {
1106 if (isAMXCast(II: &I)) {
1107 if (isa<PHINode>(Val: I.getOperand(i: 0)))
1108 PhiCastWorkList.push_back(Elt: &I);
1109 }
1110 }
1111 }
1112 for (auto *I : PhiCastWorkList) {
    // Skip dead AMXCasts.
1114 if (DeadInst.contains(key: I))
1115 continue;
1116 PHINode *PN = cast<PHINode>(Val: I->getOperand(i: 0));
1117 if (optimizeAMXCastFromPhi(CI: cast<IntrinsicInst>(Val: I), PN, DeadInst)) {
1118 DeadInst.insert(X: PN);
1119 Change = true;
1120 }
1121 }
1122
  // Since we create new PHIs and merge AMXCasts, some old PHIs and AMXCasts
  // might have no uses. Run dead code elimination on them.
1125 while (!DeadInst.empty()) {
1126 Instruction *I = DeadInst.pop_back_val();
1127 Change |= DCEInstruction(I, WorkList&: DeadInst, TLI);
1128 }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
1132 return Change;
1133}
1134
// There might be remaining AMXCasts after combineAMXcast, and they should be
// handled elegantly.
1137bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1138 IRBuilder<> Builder(AMXCast);
1139 AllocaInst *AllocaAddr;
1140 Value *I8Ptr, *Stride;
1141 auto *Src = AMXCast->getOperand(i_nocapture: 0);
1142
1143 auto Prepare = [&](Type *MemTy) {
1144 AllocaAddr = createAllocaInstAtEntry(Builder, BB: AMXCast->getParent(), Ty: MemTy);
1145 I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy());
1146 Stride = Builder.getInt64(C: 64);
1147 };
1148
1149 if (AMXCast->getType()->isX86_AMXTy()) {
1150 // %2 = amxcast <225 x i32> %src to x86_amx
1151 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1152 // i8* %addr3, i64 60, x86_amx %2)
1153 // -->
1154 // %addr = alloca <225 x i32>, align 64
1155 // store <225 x i32> %src, <225 x i32>* %addr, align 64
1156 // %addr2 = bitcast <225 x i32>* %addr to i8*
1157 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1158 // i8* %addr2,
1159 // i64 60)
1160 // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1161 // i8* %addr3, i64 60, x86_amx %2)
1162 if (AMXCast->use_empty()) {
1163 AMXCast->eraseFromParent();
1164 return true;
1165 }
1166 Use &U = *(AMXCast->use_begin());
1167 unsigned OpNo = U.getOperandNo();
1168 auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser());
1169 if (!II)
1170 return false; // May be bitcast from x86amx to <256 x i32>.
1171 Prepare(AMXCast->getOperand(i_nocapture: 0)->getType());
1172 Builder.CreateStore(Val: Src, Ptr: AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
1174 Value *Row = nullptr, *Col = nullptr;
1175 std::tie(args&: Row, args&: Col) = getShape(II, OpNo);
1176 std::array<Value *, 4> Args = {
1177 Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty())};
1178 Value *NewInst = Builder.CreateIntrinsic(
1179 ID: Intrinsic::x86_tileloadd64_internal, Types: std::nullopt, Args);
1180 AMXCast->replaceAllUsesWith(V: NewInst);
1181 AMXCast->eraseFromParent();
1182 } else {
1183 // %2 = amxcast x86_amx %src to <225 x i32>
1184 // -->
1185 // %addr = alloca <225 x i32>, align 64
1186 // %addr2 = bitcast <225 x i32>* to i8*
1187 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1188 // i8* %addr2, i64 %stride)
1189 // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1190 auto *II = dyn_cast<IntrinsicInst>(Val: Src);
1191 if (!II)
1192 return false; // May be bitcast from <256 x i32> to x86amx.
1193 Prepare(AMXCast->getType());
1194 Value *Row = II->getOperand(i_nocapture: 0);
1195 Value *Col = II->getOperand(i_nocapture: 1);
1196 std::array<Value *, 5> Args = {
1197 Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()), Src};
1198 Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt,
1199 Args);
1200 Value *NewInst = Builder.CreateLoad(Ty: AMXCast->getType(), Ptr: AllocaAddr);
1201 AMXCast->replaceAllUsesWith(V: NewInst);
1202 AMXCast->eraseFromParent();
1203 }
1204
1205 return true;
1206}
1207
1208bool X86LowerAMXCast::transformAllAMXCast() {
1209 bool Change = false;
  // Collect tile cast instructions.
1211 SmallVector<Instruction *, 8> WorkLists;
1212 for (BasicBlock &BB : Func) {
1213 for (Instruction &I : BB) {
1214 if (isAMXCast(II: &I))
1215 WorkLists.push_back(Elt: &I);
1216 }
1217 }
1218
1219 for (auto *Inst : WorkLists) {
1220 Change |= transformAMXCast(AMXCast: cast<IntrinsicInst>(Val: Inst));
1221 }
1222
1223 return Change;
1224}
1225
1226} // anonymous namespace
1227
1228namespace {
1229
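// Legacy pass wrapper that runs the AMX type lowering on each function.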
1230class X86LowerAMXTypeLegacyPass : public FunctionPass {
1231public:
1232 static char ID;
1233
1234 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1235 initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1236 }
1237
1238 bool runOnFunction(Function &F) override {
1239 // Performance optimization: most code doesn't use AMX, so return early if
1240 // there are no instructions that produce AMX values. This is sufficient, as
1241 // AMX arguments and constants are not allowed -- so any producer of an AMX
1242 // value must be an instruction.
1243 // TODO: find a cheaper way for this, without looking at all instructions.
1244 if (!containsAMXCode(F))
1245 return false;
1246
1247 bool C = false;
1248 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1249 TargetLibraryInfo *TLI =
1250 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1251
1252 X86LowerAMXCast LAC(F);
1253 C |= LAC.combineAMXcast(TLI);
1254 // There might be remaining AMXcast after combineAMXcast and they should be
1255 // handled elegantly.
1256 C |= LAC.transformAllAMXCast();
1257
1258 X86LowerAMXType LAT(F);
1259 C |= LAT.visit();
1260
1261 // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code
    // itself, not just Attribute::OptimizeNone and CodeGenOptLevel::None.
1264 if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
1269 if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone)) {
1270 X86VolatileTileData VTD(F);
1271 C = VTD.volatileTileData() || C;
1272 }
1273 }
1274
1275 return C;
1276 }
1277
1278 void getAnalysisUsage(AnalysisUsage &AU) const override {
1279 AU.setPreservesCFG();
1280 AU.addRequired<TargetPassConfig>();
1281 AU.addRequired<TargetLibraryInfoWrapperPass>();
1282 }
1283};
1284
1285} // anonymous namespace
1286
1287static const char PassName[] = "Lower AMX type for load/store";
1288char X86LowerAMXTypeLegacyPass::ID = 0;
1289INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1290 false)
1291INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1292INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1293INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1294 false)
1295
1296FunctionPass *llvm::createX86LowerAMXTypePass() {
1297 return new X86LowerAMXTypeLegacyPass();
1298}
1299