1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are not
/// supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and only
/// AMX intrinsics can operate on the type, we need to transform load/store
/// <256 x i32> instructions into AMX load/store. If the bitcast cannot be
/// combined with its load/store, we transform the bitcast into an AMX
/// load/store plus a <256 x i32> store/load.
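/// For example (illustrative IR; the actual row/col come from the consuming
/// AMX intrinsic), a bitcast that cannot be combined
///   %amx = bitcast <256 x i32> %src to x86_amx
/// is lowered through a stack slot:
///   %buf = alloca <256 x i32>, align 64
///   store <256 x i32> %src, ptr %buf, align 64
///   %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                      ptr %buf, i64 64)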
17 | /// |
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// The volatileTileData() function handles this case.
24 | /// e.g. |
25 | /// ---------------------------------------------------------- |
26 | /// | def %td = ... | |
27 | /// | ... | |
28 | /// | "use %td" | |
29 | /// ---------------------------------------------------------- |
/// is transformed to -->
31 | /// ---------------------------------------------------------- |
32 | /// | def %td = ... | |
33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | |
34 | /// | ... | |
35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| |
36 | /// | "use %td2" | |
37 | /// ---------------------------------------------------------- |
38 | // |
39 | //===----------------------------------------------------------------------===// |
40 | // |
41 | #include "X86.h" |
42 | #include "llvm/ADT/PostOrderIterator.h" |
43 | #include "llvm/ADT/SetVector.h" |
44 | #include "llvm/Analysis/TargetLibraryInfo.h" |
45 | #include "llvm/Analysis/TargetTransformInfo.h" |
46 | #include "llvm/CodeGen/Passes.h" |
47 | #include "llvm/CodeGen/TargetPassConfig.h" |
48 | #include "llvm/CodeGen/ValueTypes.h" |
49 | #include "llvm/IR/DataLayout.h" |
50 | #include "llvm/IR/Function.h" |
51 | #include "llvm/IR/IRBuilder.h" |
52 | #include "llvm/IR/Instructions.h" |
53 | #include "llvm/IR/IntrinsicInst.h" |
54 | #include "llvm/IR/IntrinsicsX86.h" |
55 | #include "llvm/IR/PatternMatch.h" |
56 | #include "llvm/InitializePasses.h" |
57 | #include "llvm/Pass.h" |
58 | #include "llvm/Target/TargetMachine.h" |
59 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" |
60 | #include "llvm/Transforms/Utils/Local.h" |
61 | |
62 | #include <map> |
63 | |
64 | using namespace llvm; |
65 | using namespace PatternMatch; |
66 | |
67 | #define DEBUG_TYPE "lower-amx-type" |
68 | |
69 | static bool isAMXCast(Instruction *II) { |
70 | return match(V: II, |
71 | P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value())) || |
72 | match(V: II, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(Op0: m_Value())); |
73 | } |
74 | |
// Some instructions may return more than one tile.
// e.g.: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal
77 | static unsigned getNumDefTiles(IntrinsicInst *II) { |
78 | Type *Ty = II->getType(); |
79 | if (Ty->isX86_AMXTy()) |
80 | return 1; |
81 | |
82 | unsigned Num = 0; |
83 | for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) { |
84 | Type *STy = Ty->getContainedType(i); |
85 | if (STy->isX86_AMXTy()) |
86 | Num++; |
87 | } |
88 | return Num; |
89 | } |
90 | |
91 | static bool isAMXIntrinsic(Value *I) { |
92 | auto *II = dyn_cast<IntrinsicInst>(Val: I); |
93 | if (!II) |
94 | return false; |
95 | if (isAMXCast(II)) |
96 | return false; |
// Check if the return type or any parameter is x86_amx. If it is x86_amx,
// the intrinsic must be an x86 AMX intrinsic.
99 | if (getNumDefTiles(II) > 0) |
100 | return true; |
101 | for (Value *V : II->args()) { |
102 | if (V->getType()->isX86_AMXTy()) |
103 | return true; |
104 | } |
105 | |
106 | return false; |
107 | } |
108 | |
109 | static bool containsAMXCode(Function &F) { |
110 | for (BasicBlock &BB : F) |
111 | for (Instruction &I : BB) |
112 | if (I.getType()->isX86_AMXTy()) |
113 | return true; |
114 | return false; |
115 | } |
116 | |
117 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, |
118 | Type *Ty) { |
119 | Function &F = *BB->getParent(); |
120 | const DataLayout &DL = F.getDataLayout(); |
121 | |
122 | LLVMContext &Ctx = Builder.getContext(); |
123 | auto AllocaAlignment = DL.getPrefTypeAlign(Ty: Type::getX86_AMXTy(C&: Ctx)); |
124 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
125 | AllocaInst *AllocaRes = |
126 | new AllocaInst(Ty, AllocaAS, "" , F.getEntryBlock().begin()); |
127 | AllocaRes->setAlignment(AllocaAlignment); |
128 | return AllocaRes; |
129 | } |
130 | |
131 | static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { |
132 | for (Instruction &I : F.getEntryBlock()) |
133 | if (!isa<AllocaInst>(Val: &I)) |
134 | return &I; |
135 | llvm_unreachable("No terminator in the entry block!" ); |
136 | } |
137 | |
138 | class ShapeCalculator { |
139 | private: |
140 | TargetMachine *TM = nullptr; |
141 | |
// In AMX intrinsics we let Shape = {Row, Col}, but the
// RealCol = Col / ElementSize. We may use the RealCol
// as a new Row for other newly created AMX intrinsics.
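// For example (illustrative numbers only): with 4-byte elements, a shape of
// {Row = 16, Col = 64} gives RealCol = 64 / 4 = 16, which may be reused as
// the Row of a derived tile (see getRowFromCol/getColFromRow below).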
145 | std::map<Value *, Value *> Col2Row, Row2Col; |
146 | |
147 | public: |
148 | ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {} |
149 | std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); |
150 | std::pair<Value *, Value *> getShape(PHINode *Phi); |
151 | Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); |
152 | Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity); |
153 | }; |
154 | |
155 | Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, |
156 | unsigned Granularity) { |
157 | if (auto It = Col2Row.find(x: V); It != Col2Row.end()) |
158 | return It->second; |
159 | IRBuilder<> Builder(II); |
160 | Value *RealRow = nullptr; |
161 | if (isa<ConstantInt>(Val: V)) |
162 | RealRow = |
163 | Builder.getInt16(C: (cast<ConstantInt>(Val: V)->getSExtValue()) / Granularity); |
164 | else if (isa<Instruction>(Val: V)) { |
// When it is not a constant value and it is not a function argument, we
// create Row after the definition of V instead of before II. For example,
// II is %118 and we try to get the shape for %117:
//   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %115)
//   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %104, i16 %105,
//          i16 %106, x86_amx %110, x86_amx %114, x86_amx %117)
// If we created %row = udiv i16 %106, 4 before %118 (aka II), its definition
// would come after its user (the new tileload for %117). So the best choice
// is to create %row right after the definition of %106.
177 | Builder.SetInsertPoint(cast<Instruction>(Val: V)); |
178 | RealRow = Builder.CreateUDiv(LHS: V, RHS: Builder.getInt16(C: 4)); |
179 | cast<Instruction>(Val: RealRow)->moveAfter(MovePos: cast<Instruction>(Val: V)); |
180 | } else { |
// When it is not a constant value and it is a function argument, we create
// Row at the entry bb.
183 | IRBuilder<> NewBuilder( |
184 | getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction())); |
185 | RealRow = NewBuilder.CreateUDiv(LHS: V, RHS: NewBuilder.getInt16(C: Granularity)); |
186 | } |
187 | Col2Row[V] = RealRow; |
188 | return RealRow; |
189 | } |
190 | |
191 | Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V, |
192 | unsigned Granularity) { |
193 | if (auto It = Row2Col.find(x: V); It != Row2Col.end()) |
194 | return It->second; |
195 | IRBuilder<> Builder(II); |
196 | Value *RealCol = nullptr; |
197 | if (isa<ConstantInt>(Val: V)) |
198 | RealCol = |
199 | Builder.getInt16(C: (cast<ConstantInt>(Val: V)->getSExtValue()) * Granularity); |
200 | else if (isa<Instruction>(Val: V)) { |
201 | Builder.SetInsertPoint(cast<Instruction>(Val: V)); |
202 | RealCol = Builder.CreateNUWMul(LHS: V, RHS: Builder.getInt16(C: Granularity)); |
203 | cast<Instruction>(Val: RealCol)->moveAfter(MovePos: cast<Instruction>(Val: V)); |
204 | } else { |
// When it is not a constant value and it is a function argument, we create
// Col at the entry bb.
207 | IRBuilder<> NewBuilder( |
208 | getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction())); |
209 | RealCol = NewBuilder.CreateNUWMul(LHS: V, RHS: NewBuilder.getInt16(C: Granularity)); |
210 | } |
211 | Row2Col[V] = RealCol; |
212 | return RealCol; |
213 | } |
214 | |
215 | // TODO: Refine the row and col-in-bytes of tile to row and col of matrix. |
216 | std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, |
217 | unsigned OpNo) { |
218 | (void)TM; |
219 | IRBuilder<> Builder(II); |
220 | Value *Row = nullptr, *Col = nullptr; |
221 | switch (II->getIntrinsicID()) { |
222 | default: |
223 | llvm_unreachable("Expect amx intrinsics" ); |
224 | case Intrinsic::x86_t2rpntlvwz0_internal: |
225 | case Intrinsic::x86_t2rpntlvwz0t1_internal: |
226 | case Intrinsic::x86_t2rpntlvwz1_internal: |
227 | case Intrinsic::x86_t2rpntlvwz1t1_internal: |
228 | case Intrinsic::x86_tileloadd64_internal: |
229 | case Intrinsic::x86_tileloaddt164_internal: |
230 | case Intrinsic::x86_tilestored64_internal: |
231 | case Intrinsic::x86_t2rpntlvwz0rs_internal: |
232 | case Intrinsic::x86_t2rpntlvwz0rst1_internal: |
233 | case Intrinsic::x86_t2rpntlvwz1rs_internal: |
234 | case Intrinsic::x86_t2rpntlvwz1rst1_internal: |
235 | case Intrinsic::x86_tileloaddrs64_internal: |
236 | case Intrinsic::x86_tileloaddrst164_internal: { |
237 | Row = II->getArgOperand(i: 0); |
238 | Col = II->getArgOperand(i: 1); |
239 | break; |
240 | } |
241 | // a * b + c |
242 | // The shape depends on which operand. |
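// For example, for tdpbssd(i16 %m, i16 %n, i16 %k, %c, %a, %b) the operand
// shapes are: operand 3 (%c) is {%m, %n}, operand 4 (%a) is {%m, %k}, and
// operand 5 (%b) is {%k / 4, %n}, where the row of %b is derived from %k via
// getRowFromCol.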
243 | case Intrinsic::x86_tcmmimfp16ps_internal: |
244 | case Intrinsic::x86_tcmmrlfp16ps_internal: |
245 | case Intrinsic::x86_tdpbssd_internal: |
246 | case Intrinsic::x86_tdpbsud_internal: |
247 | case Intrinsic::x86_tdpbusd_internal: |
248 | case Intrinsic::x86_tdpbuud_internal: |
249 | case Intrinsic::x86_tdpbf16ps_internal: |
250 | case Intrinsic::x86_tdpfp16ps_internal: |
251 | case Intrinsic::x86_tmmultf32ps_internal: |
252 | case Intrinsic::x86_tdpbf8ps_internal: |
253 | case Intrinsic::x86_tdpbhf8ps_internal: |
254 | case Intrinsic::x86_tdphbf8ps_internal: |
255 | case Intrinsic::x86_tdphf8ps_internal: { |
256 | switch (OpNo) { |
257 | case 3: |
258 | Row = II->getArgOperand(i: 0); |
259 | Col = II->getArgOperand(i: 1); |
260 | break; |
261 | case 4: |
262 | Row = II->getArgOperand(i: 0); |
263 | Col = II->getArgOperand(i: 2); |
264 | break; |
265 | case 5: |
266 | Row = getRowFromCol(II, V: II->getArgOperand(i: 2), Granularity: 4); |
267 | Col = II->getArgOperand(i: 1); |
268 | break; |
269 | } |
270 | break; |
271 | } |
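// For the transpose-like intrinsics below, operand 2 is the source tile whose
// shape is the destination shape swapped: its Row is the destination Col
// divided by 4 (bytes to elements) and its Col is the destination Row
// multiplied by 4 (elements to bytes).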
272 | case Intrinsic::x86_ttransposed_internal: |
273 | case Intrinsic::x86_tconjtfp16_internal: { |
274 | assert((OpNo == 2) && "Illegal Operand Number." ); |
275 | Row = getRowFromCol(II, V: II->getArgOperand(i: 1), Granularity: 4); |
276 | Col = getColFromRow(II, V: II->getArgOperand(i: 0), Granularity: 4); |
277 | break; |
278 | } |
279 | case Intrinsic::x86_tcvtrowd2ps_internal: |
280 | case Intrinsic::x86_tcvtrowps2bf16h_internal: |
281 | case Intrinsic::x86_tcvtrowps2bf16l_internal: |
282 | case Intrinsic::x86_tcvtrowps2phh_internal: |
283 | case Intrinsic::x86_tcvtrowps2phl_internal: |
284 | case Intrinsic::x86_tilemovrow_internal: { |
285 | assert(OpNo == 2 && "Illegal Operand Number." ); |
286 | Row = II->getArgOperand(i: 0); |
287 | Col = II->getArgOperand(i: 1); |
288 | break; |
289 | } |
290 | case Intrinsic::x86_ttdpbf16ps_internal: |
291 | case Intrinsic::x86_ttdpfp16ps_internal: |
292 | case Intrinsic::x86_ttcmmimfp16ps_internal: |
293 | case Intrinsic::x86_ttcmmrlfp16ps_internal: |
294 | case Intrinsic::x86_tconjtcmmimfp16ps_internal: |
295 | case Intrinsic::x86_ttmmultf32ps_internal: { |
296 | switch (OpNo) { |
297 | case 3: |
298 | Row = II->getArgOperand(i: 0); |
299 | Col = II->getArgOperand(i: 1); |
300 | break; |
301 | case 4: |
302 | Row = getRowFromCol(II, V: II->getArgOperand(i: 2), Granularity: 4); |
303 | Col = getColFromRow(II, V: II->getArgOperand(i: 0), Granularity: 4); |
304 | break; |
305 | case 5: |
306 | Row = getRowFromCol(II, V: II->getArgOperand(i: 2), Granularity: 4); |
307 | Col = II->getArgOperand(i: 1); |
308 | break; |
309 | } |
310 | break; |
311 | } |
312 | } |
313 | |
314 | return std::make_pair(x&: Row, y&: Col); |
315 | } |
316 | |
317 | std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) { |
318 | Use &U = *(Phi->use_begin()); |
319 | unsigned OpNo = U.getOperandNo(); |
320 | User *V = U.getUser(); |
// TODO: We don't traverse all users. To keep the algorithm simple, we just
// follow the first user. If we can find the shape, return it; otherwise
// return nullptr and the optimization for undef/zero will be abandoned.
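// For example, if the first-user chain is phi -> amxcast -> tdpbssd, the
// shape is taken from the tdpbssd operand that the chain reaches.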
325 | while (V) { |
326 | if (isAMXCast(II: dyn_cast<Instruction>(Val: V))) { |
327 | if (V->use_empty()) |
328 | break; |
329 | Use &U = *(V->use_begin()); |
330 | OpNo = U.getOperandNo(); |
331 | V = U.getUser(); |
332 | } else if (isAMXIntrinsic(I: V)) { |
333 | return getShape(II: cast<IntrinsicInst>(Val: V), OpNo); |
334 | } else if (isa<PHINode>(Val: V)) { |
335 | if (V->use_empty()) |
336 | break; |
337 | Use &U = *(V->use_begin()); |
338 | V = U.getUser(); |
339 | } else { |
340 | break; |
341 | } |
342 | } |
343 | |
344 | return std::make_pair(x: nullptr, y: nullptr); |
345 | } |
346 | |
347 | namespace { |
348 | class X86LowerAMXType { |
349 | Function &Func; |
350 | ShapeCalculator *SC; |
351 | |
352 | // In AMX intrinsics we let Shape = {Row, Col}, but the |
353 | // RealCol = Col / ElementSize. We may use the RealCol |
354 | // as a new Row for other new created AMX intrinsics. |
355 | std::map<Value *, Value *> Col2Row, Row2Col; |
356 | |
357 | public: |
358 | X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {} |
359 | bool visit(); |
360 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); |
361 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); |
362 | bool transformBitcast(BitCastInst *Bitcast); |
363 | }; |
364 | |
365 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
366 | // %2 = bitcast <256 x i32> %src to x86_amx |
367 | // --> |
368 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
369 | // i8* %addr, i64 %stride64) |
370 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { |
371 | Value *Row = nullptr, *Col = nullptr; |
372 | Use &U = *(Bitcast->use_begin()); |
373 | unsigned OpNo = U.getOperandNo(); |
374 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
375 | std::tie(args&: Row, args&: Col) = SC->getShape(II, OpNo); |
376 | IRBuilder<> Builder(Bitcast); |
// Use the maximum column as stride.
378 | Value *Stride = Builder.getInt64(C: 64); |
379 | Value *I8Ptr = LD->getOperand(i_nocapture: 0); |
380 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
381 | |
382 | Value *NewInst = |
383 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
384 | Bitcast->replaceAllUsesWith(V: NewInst); |
385 | } |
386 | |
387 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
388 | // %stride); |
389 | // %13 = bitcast x86_amx %src to <256 x i32> |
390 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
391 | // --> |
392 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
393 | // %stride64, %13) |
394 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { |
395 | |
396 | Value *Tile = Bitcast->getOperand(i_nocapture: 0); |
397 | auto *II = cast<IntrinsicInst>(Val: Tile); |
398 | // Tile is output from AMX intrinsic. The first operand of the |
399 | // intrinsic is row, the second operand of the intrinsic is column. |
400 | Value *Row = II->getOperand(i_nocapture: 0); |
401 | Value *Col = II->getOperand(i_nocapture: 1); |
402 | IRBuilder<> Builder(ST); |
// Use the maximum column as stride. It must be the same as the load
// stride.
405 | Value *Stride = Builder.getInt64(C: 64); |
406 | Value *I8Ptr = ST->getOperand(i_nocapture: 1); |
407 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
408 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
409 | if (Bitcast->hasOneUse()) |
410 | return; |
411 | // %13 = bitcast x86_amx %src to <256 x i32> |
412 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
413 | // %add = <256 x i32> %13, <256 x i32> %src2 |
414 | // --> |
415 | // %13 = bitcast x86_amx %src to <256 x i32> |
416 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
417 | // %stride64, %13) |
418 | // %14 = load <256 x i32>, %addr |
419 | // %add = <256 x i32> %14, <256 x i32> %src2 |
420 | Value *Vec = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: ST->getOperand(i_nocapture: 1)); |
421 | Bitcast->replaceAllUsesWith(V: Vec); |
422 | } |
423 | |
// Transform a bitcast into <store, load> instructions.
425 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { |
426 | IRBuilder<> Builder(Bitcast); |
427 | AllocaInst *AllocaAddr; |
428 | Value *I8Ptr, *Stride; |
429 | auto *Src = Bitcast->getOperand(i_nocapture: 0); |
430 | |
431 | auto Prepare = [&](Type *MemTy) { |
432 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: Bitcast->getParent(), Ty: MemTy); |
433 | I8Ptr = AllocaAddr; |
434 | Stride = Builder.getInt64(C: 64); |
435 | }; |
436 | |
437 | if (Bitcast->getType()->isX86_AMXTy()) { |
438 | // %2 = bitcast <256 x i32> %src to x86_amx |
439 | // --> |
440 | // %addr = alloca <256 x i32>, align 64 |
441 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 |
442 | // %addr2 = bitcast <256 x i32>* to i8* |
443 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
444 | // i8* %addr2, |
445 | // i64 64) |
446 | Use &U = *(Bitcast->use_begin()); |
447 | unsigned OpNo = U.getOperandNo(); |
448 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
449 | if (!II) |
450 | return false; // May be bitcast from x86amx to <256 x i32>. |
451 | Prepare(Bitcast->getOperand(i_nocapture: 0)->getType()); |
452 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
// TODO: we can pick a constant operand for the shape.
454 | Value *Row = nullptr, *Col = nullptr; |
455 | std::tie(args&: Row, args&: Col) = SC->getShape(II, OpNo); |
456 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
457 | Value *NewInst = |
458 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
459 | Bitcast->replaceAllUsesWith(V: NewInst); |
460 | } else { |
461 | // %2 = bitcast x86_amx %src to <256 x i32> |
462 | // --> |
463 | // %addr = alloca <256 x i32>, align 64 |
464 | // %addr2 = bitcast <256 x i32>* to i8* |
465 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
466 | // i8* %addr2, i64 %stride) |
467 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 |
468 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
469 | if (!II) |
470 | return false; // May be bitcast from <256 x i32> to x86amx. |
471 | Prepare(Bitcast->getType()); |
472 | Value *Row = II->getOperand(i_nocapture: 0); |
473 | Value *Col = II->getOperand(i_nocapture: 1); |
474 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; |
475 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
476 | Value *NewInst = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: AllocaAddr); |
477 | Bitcast->replaceAllUsesWith(V: NewInst); |
478 | } |
479 | |
480 | return true; |
481 | } |
482 | |
483 | bool X86LowerAMXType::visit() { |
484 | SmallVector<Instruction *, 8> DeadInsts; |
485 | Col2Row.clear(); |
486 | |
487 | for (BasicBlock *BB : post_order(G: &Func)) { |
488 | for (Instruction &Inst : llvm::make_early_inc_range(Range: llvm::reverse(C&: *BB))) { |
489 | auto *Bitcast = dyn_cast<BitCastInst>(Val: &Inst); |
490 | if (!Bitcast) |
491 | continue; |
492 | |
493 | Value *Src = Bitcast->getOperand(i_nocapture: 0); |
494 | if (Bitcast->getType()->isX86_AMXTy()) { |
495 | if (Bitcast->user_empty()) { |
496 | DeadInsts.push_back(Elt: Bitcast); |
497 | continue; |
498 | } |
499 | LoadInst *LD = dyn_cast<LoadInst>(Val: Src); |
500 | if (!LD) { |
501 | if (transformBitcast(Bitcast)) |
502 | DeadInsts.push_back(Elt: Bitcast); |
503 | continue; |
504 | } |
// If the load has multiple users, keep the vector load and create an
// additional tile load from the same address.
506 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
507 | // %2 = bitcast <256 x i32> %src to x86_amx |
508 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
509 | // --> |
510 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
511 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
512 | // i8* %addr, i64 %stride64) |
513 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
514 | |
// If the load has a single user, the load will be eliminated in DAG ISel.
516 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
517 | // %2 = bitcast <256 x i32> %src to x86_amx |
518 | // --> |
519 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
520 | // i8* %addr, i64 %stride64) |
521 | combineLoadBitcast(LD, Bitcast); |
522 | DeadInsts.push_back(Elt: Bitcast); |
523 | if (LD->hasOneUse()) |
524 | DeadInsts.push_back(Elt: LD); |
525 | } else if (Src->getType()->isX86_AMXTy()) { |
526 | if (Bitcast->user_empty()) { |
527 | DeadInsts.push_back(Elt: Bitcast); |
528 | continue; |
529 | } |
530 | StoreInst *ST = nullptr; |
531 | for (Use &U : Bitcast->uses()) { |
532 | ST = dyn_cast<StoreInst>(Val: U.getUser()); |
533 | if (ST) |
534 | break; |
535 | } |
536 | if (!ST) { |
537 | if (transformBitcast(Bitcast)) |
538 | DeadInsts.push_back(Elt: Bitcast); |
539 | continue; |
540 | } |
// If the bitcast (%13) has one use, combine the bitcast and store into an
// AMX store.
542 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
543 | // %stride); |
544 | // %13 = bitcast x86_amx %src to <256 x i32> |
545 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
546 | // --> |
547 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
548 | // %stride64, %13) |
549 | // |
// If the bitcast (%13) has multiple uses, transform as below.
551 | // %13 = bitcast x86_amx %src to <256 x i32> |
552 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
553 | // %add = <256 x i32> %13, <256 x i32> %src2 |
554 | // --> |
555 | // %13 = bitcast x86_amx %src to <256 x i32> |
556 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
557 | // %stride64, %13) |
558 | // %14 = load <256 x i32>, %addr |
559 | // %add = <256 x i32> %14, <256 x i32> %src2 |
560 | // |
561 | combineBitcastStore(Bitcast, ST); |
562 | // Delete user first. |
563 | DeadInsts.push_back(Elt: ST); |
564 | DeadInsts.push_back(Elt: Bitcast); |
565 | } |
566 | } |
567 | } |
568 | |
569 | bool C = !DeadInsts.empty(); |
570 | |
571 | for (auto *Inst : DeadInsts) |
572 | Inst->eraseFromParent(); |
573 | |
574 | return C; |
575 | } |
576 | } // anonymous namespace |
577 | |
578 | static Value *getAllocaPos(BasicBlock *BB) { |
579 | Function *F = BB->getParent(); |
580 | IRBuilder<> Builder(&F->getEntryBlock().front()); |
581 | const DataLayout &DL = F->getDataLayout(); |
582 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
583 | Type *V256I32Ty = VectorType::get(ElementType: Builder.getInt32Ty(), NumElements: 256, Scalable: false); |
584 | AllocaInst *AllocaRes = |
585 | new AllocaInst(V256I32Ty, AllocaAS, "" , F->getEntryBlock().begin()); |
586 | BasicBlock::iterator Iter = AllocaRes->getIterator(); |
587 | ++Iter; |
588 | Builder.SetInsertPoint(&*Iter); |
589 | Value *I8Ptr = Builder.CreateBitCast(V: AllocaRes, DestTy: Builder.getPtrTy()); |
590 | return I8Ptr; |
591 | } |
592 | |
593 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { |
594 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!" ); |
595 | auto *II = dyn_cast<IntrinsicInst>(Val: TileDef); |
596 | unsigned Idx = 0; |
597 | // Extract tile from multiple tiles' def. |
598 | if (auto *Extr = dyn_cast<ExtractValueInst>(Val: TileDef)) { |
599 | assert(Extr->hasIndices() && "Tile extract miss index!" ); |
600 | Idx = Extr->getIndices()[0]; |
601 | II = cast<IntrinsicInst>(Val: Extr->getOperand(i_nocapture: 0)); |
602 | } |
603 | |
604 | assert(II && "Not tile intrinsic!" ); |
605 | Value *Row = II->getOperand(i_nocapture: Idx); |
606 | Value *Col = II->getOperand(i_nocapture: Idx + 1); |
607 | |
608 | BasicBlock *BB = TileDef->getParent(); |
609 | BasicBlock::iterator Iter = TileDef->getIterator(); |
610 | IRBuilder<> Builder(BB, ++Iter); |
611 | Value *Stride = Builder.getInt64(C: 64); |
612 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; |
613 | |
614 | Instruction *TileStore = |
615 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
616 | return TileStore; |
617 | } |
618 | |
619 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { |
620 | Value *V = U.get(); |
621 | assert(V->getType()->isX86_AMXTy() && "Not define tile!" ); |
622 | |
623 | // Get tile shape. |
624 | IntrinsicInst *II = nullptr; |
625 | unsigned Idx = 0; |
626 | if (IsPHI) { |
627 | Value *PhiOp = cast<PHINode>(Val: V)->getIncomingValue(i: 0); |
628 | II = cast<IntrinsicInst>(Val: PhiOp); |
629 | } else if (auto *Extr = dyn_cast<ExtractValueInst>(Val: V)) { |
630 | // Extract tile from multiple tiles' def. |
631 | assert(Extr->hasIndices() && "Tile extract miss index!" ); |
632 | Idx = Extr->getIndices()[0]; |
633 | II = cast<IntrinsicInst>(Val: Extr->getOperand(i_nocapture: 0)); |
634 | } else { |
635 | II = cast<IntrinsicInst>(Val: V); |
636 | } |
637 | Value *Row = II->getOperand(i_nocapture: Idx); |
638 | Value *Col = II->getOperand(i_nocapture: Idx + 1); |
639 | |
640 | Instruction *UserI = cast<Instruction>(Val: U.getUser()); |
641 | IRBuilder<> Builder(UserI); |
642 | Value *Stride = Builder.getInt64(C: 64); |
643 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; |
644 | |
645 | Value *TileLoad = |
646 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
647 | UserI->replaceUsesOfWith(From: V, To: TileLoad); |
648 | } |
649 | |
650 | static bool isIncomingOfPHI(Instruction *I) { |
651 | for (Use &U : I->uses()) { |
652 | User *V = U.getUser(); |
653 | if (isa<PHINode>(Val: V)) |
654 | return true; |
655 | } |
656 | return false; |
657 | } |
658 | |
// Let all AMX tile data become volatile data, shortening the live range
// of each tile register before fast register allocation.
661 | namespace { |
662 | class X86VolatileTileData { |
663 | Function &F; |
664 | |
665 | public: |
666 | X86VolatileTileData(Function &Func) : F(Func) {} |
667 | Value *updatePhiIncomings(BasicBlock *BB, |
668 | SmallVector<Instruction *, 2> &Incomings); |
669 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); |
670 | bool volatileTileData(); |
671 | void volatileTilePHI(PHINode *PHI); |
672 | void volatileTileNonPHI(Instruction *I); |
673 | }; |
674 | |
675 | Value *X86VolatileTileData::updatePhiIncomings( |
676 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { |
677 | Value *I8Ptr = getAllocaPos(BB); |
678 | |
679 | for (auto *I : Incomings) { |
680 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
681 | |
682 | // All its uses (except phi) should load from stored mem. |
683 | for (Use &U : I->uses()) { |
684 | User *V = U.getUser(); |
685 | if (isa<PHINode>(Val: V) || V == Store) |
686 | continue; |
687 | replaceWithTileLoad(U, Ptr: I8Ptr); |
688 | } |
689 | } |
690 | return I8Ptr; |
691 | } |
692 | |
693 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, |
694 | Value *StorePtr) { |
695 | for (Use &U : PHI->uses()) |
696 | replaceWithTileLoad(U, Ptr: StorePtr, IsPHI: true); |
697 | PHI->eraseFromParent(); |
698 | } |
699 | |
// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored right after their defs.
// 3) The memory used by these tileloads and tilestores should be the same.
705 | // e.g. |
706 | // ------------------------------------------------------ |
707 | // bb_dom: |
708 | // ... |
709 | // br i1 %bool.cond, label %if.else, label %if.then |
710 | // |
711 | // if.then: |
712 | // def %t0 = ... |
713 | // ... |
714 | // use %t0 |
715 | // ... |
716 | // br label %if.end |
717 | // |
718 | // if.else: |
719 | // def %t1 = ... |
720 | // br label %if.end |
721 | // |
722 | // if.end: |
723 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] |
724 | // ... |
725 | // use %td |
726 | // ------------------------------------------------------ |
727 | // --> |
728 | // ------------------------------------------------------ |
729 | // bb_entry: |
730 | // %mem = alloca <256 x i32>, align 1024 * |
731 | // ... |
732 | // bb_dom: |
733 | // ... |
734 | // br i1 %bool.cond, label %if.else, label %if.then |
735 | // |
736 | // if.then: |
737 | // def %t0 = ... |
738 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * |
739 | // ... |
740 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* |
741 | // use %t0` * |
742 | // ... |
743 | // br label %if.end |
744 | // |
745 | // if.else: |
746 | // def %t1 = ... |
747 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * |
748 | // br label %if.end |
749 | // |
750 | // if.end: |
751 | // ... |
752 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * |
753 | // use %td |
754 | // ------------------------------------------------------ |
755 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { |
756 | BasicBlock *BB = PHI->getParent(); |
757 | SmallVector<Instruction *, 2> Incomings; |
758 | |
759 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { |
760 | Value *Op = PHI->getIncomingValue(i: I); |
761 | Instruction *Inst = dyn_cast<Instruction>(Val: Op); |
762 | assert(Inst && "We shouldn't fold AMX instrution!" ); |
763 | Incomings.push_back(Elt: Inst); |
764 | } |
765 | |
766 | Value *StorePtr = updatePhiIncomings(BB, Incomings); |
767 | replacePhiDefWithLoad(PHI, StorePtr); |
768 | } |
769 | |
// Store the defined tile and load it before each use.
// None of its users is a PHI node.
772 | // e.g. |
773 | // ------------------------------------------------------ |
774 | // def %td = ... |
775 | // ... |
776 | // "use %td" |
777 | // ------------------------------------------------------ |
778 | // --> |
779 | // ------------------------------------------------------ |
780 | // def %td = ... |
781 | // call void @llvm.x86.tilestored64.internal(mem, %td) |
782 | // ... |
783 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) |
784 | // "use %td2" |
785 | // ------------------------------------------------------ |
786 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { |
787 | BasicBlock *BB = I->getParent(); |
788 | Value *I8Ptr = getAllocaPos(BB); |
789 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
790 | |
791 | // All its uses should load from stored mem. |
792 | for (Use &U : I->uses()) { |
793 | User *V = U.getUser(); |
794 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!" ); |
795 | if (V != Store) |
796 | replaceWithTileLoad(U, Ptr: I8Ptr); |
797 | } |
798 | } |
799 | |
800 | // Volatile Tile Model: |
// 1) All uses of tile data come from a tileload right before the use.
// 2) All defs of tile data are tilestored to memory immediately after the def.
803 | // For example: |
804 | // -------------------------------------------------------------------------- |
805 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key |
806 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) |
807 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx |
808 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) |
809 | // call void @llvm.x86.tilestored64.internal(... td) area |
810 | // -------------------------------------------------------------------------- |
811 | // 3) No terminator, call or other amx instructions in the key amx area. |
812 | bool X86VolatileTileData::volatileTileData() { |
813 | bool Changed = false; |
814 | for (BasicBlock &BB : F) { |
815 | SmallVector<Instruction *, 2> PHIInsts; |
816 | SmallVector<Instruction *, 8> AMXDefInsts; |
817 | |
818 | for (Instruction &I : BB) { |
819 | if (!I.getType()->isX86_AMXTy()) |
820 | continue; |
821 | if (isa<PHINode>(Val: &I)) |
822 | PHIInsts.push_back(Elt: &I); |
823 | else |
824 | AMXDefInsts.push_back(Elt: &I); |
825 | } |
826 | |
827 | // First we "volatile" the non-phi related amx intrinsics. |
828 | for (Instruction *I : AMXDefInsts) { |
829 | if (isIncomingOfPHI(I)) |
830 | continue; |
831 | volatileTileNonPHI(I); |
832 | Changed = true; |
833 | } |
834 | |
835 | for (Instruction *I : PHIInsts) { |
836 | volatileTilePHI(PHI: dyn_cast<PHINode>(Val: I)); |
837 | Changed = true; |
838 | } |
839 | } |
840 | return Changed; |
841 | } |
842 | |
843 | } // anonymous namespace |
844 | |
845 | namespace { |
846 | |
847 | class X86LowerAMXCast { |
848 | Function &Func; |
849 | ShapeCalculator *SC; |
850 | std::unique_ptr<DominatorTree> DT; |
851 | |
852 | public: |
853 | X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC) |
854 | : Func(F), SC(ShapeC), DT(nullptr) {} |
855 | bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); |
856 | bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); |
857 | bool combineLdSt(SmallVectorImpl<Instruction *> &Casts); |
858 | bool combineAMXcast(TargetLibraryInfo *TLI); |
859 | bool transformAMXCast(IntrinsicInst *AMXCast); |
860 | bool transformAllAMXCast(); |
861 | bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN, |
862 | SmallSetVector<Instruction *, 16> &DeadInst); |
863 | }; |
864 | |
865 | static bool DCEInstruction(Instruction *I, |
866 | SmallSetVector<Instruction *, 16> &WorkList, |
867 | const TargetLibraryInfo *TLI) { |
868 | if (isInstructionTriviallyDead(I, TLI)) { |
869 | salvageDebugInfo(I&: *I); |
870 | salvageKnowledge(I); |
871 | |
872 | // Null out all of the instruction's operands to see if any operand becomes |
873 | // dead as we go. |
874 | for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { |
875 | Value *OpV = I->getOperand(i); |
876 | I->setOperand(i, Val: nullptr); |
877 | |
878 | if (!OpV->use_empty() || I == OpV) |
879 | continue; |
880 | |
881 | // If the operand is an instruction that became dead as we nulled out the |
882 | // operand, and if it is 'trivially' dead, delete it in a future loop |
883 | // iteration. |
884 | if (Instruction *OpI = dyn_cast<Instruction>(Val: OpV)) { |
885 | if (isInstructionTriviallyDead(I: OpI, TLI)) { |
886 | WorkList.insert(X: OpI); |
887 | } |
888 | } |
889 | } |
890 | I->eraseFromParent(); |
891 | return true; |
892 | } |
893 | return false; |
894 | } |
895 | |
/// This function handles the following case:
897 | /// |
898 | /// A -> B amxcast |
899 | /// PHI |
900 | /// B -> A amxcast |
901 | /// |
902 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. |
903 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. |
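///
/// For instance (illustrative only, with A = x86_amx and B = <256 x i32>):
///   %v0 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %t0)
///   ...
///   %v = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
///   %t = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %v)
/// is rewritten so the PHI operates on x86_amx directly and both casts become
/// dead.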
904 | bool X86LowerAMXCast::optimizeAMXCastFromPhi( |
905 | IntrinsicInst *CI, PHINode *PN, |
906 | SmallSetVector<Instruction *, 16> &DeadInst) { |
907 | IRBuilder<> Builder(CI); |
908 | Value *Src = CI->getOperand(i_nocapture: 0); |
909 | Type *SrcTy = Src->getType(); // Type B |
910 | Type *DestTy = CI->getType(); // Type A |
911 | |
912 | SmallVector<PHINode *, 4> PhiWorklist; |
913 | SmallSetVector<PHINode *, 4> OldPhiNodes; |
914 | |
// Find all of the A->B casts and PHI nodes.
// We need to inspect all related PHI nodes, but PHIs can be cyclic, so
// OldPhiNodes is used to track all known PHI nodes; before adding a new
// PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
919 | PhiWorklist.push_back(Elt: PN); |
920 | OldPhiNodes.insert(X: PN); |
921 | while (!PhiWorklist.empty()) { |
922 | auto *OldPN = PhiWorklist.pop_back_val(); |
923 | for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) { |
924 | Value *IncValue = OldPN->getIncomingValue(i: I); |
// TODO: Currently we ignore the case where the incoming value is a
// constant. In the future, we might support constants.
927 | if (isa<Constant>(Val: IncValue)) { |
928 | auto *IncConst = dyn_cast<Constant>(Val: IncValue); |
929 | if (!isa<UndefValue>(Val: IncValue) && !IncConst->isZeroValue()) |
930 | return false; |
931 | Value *Row = nullptr, *Col = nullptr; |
932 | std::tie(args&: Row, args&: Col) = SC->getShape(Phi: OldPN); |
// TODO: If it is not a constant, Row and Col must dominate the tilezero
// that we are going to create.
935 | if (!Row || !Col || !isa<Constant>(Val: Row) || !isa<Constant>(Val: Col)) |
936 | return false; |
937 | // Create tilezero at the end of incoming block. |
938 | auto *Block = OldPN->getIncomingBlock(i: I); |
939 | BasicBlock::iterator Iter = Block->getTerminator()->getIterator(); |
940 | Instruction *NewInst = Builder.CreateIntrinsic( |
941 | ID: Intrinsic::x86_tilezero_internal, Types: {}, Args: {Row, Col}); |
942 | NewInst->moveBefore(InsertPos: Iter); |
943 | NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_cast_tile_to_vector, |
944 | Types: {IncValue->getType()}, Args: {NewInst}); |
945 | NewInst->moveBefore(InsertPos: Iter); |
// Replace IncValue with the new value.
947 | OldPN->setIncomingValue(i: I, V: NewInst); |
948 | IncValue = NewInst; |
949 | } |
950 | |
951 | if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) { |
952 | if (OldPhiNodes.insert(X: PNode)) |
953 | PhiWorklist.push_back(Elt: PNode); |
954 | continue; |
955 | } |
956 | Instruction *ACI = dyn_cast<Instruction>(Val: IncValue); |
957 | if (ACI && isAMXCast(II: ACI)) { |
// Verify it's an A->B cast.
959 | Type *TyA = ACI->getOperand(i: 0)->getType(); |
960 | Type *TyB = ACI->getType(); |
961 | if (TyA != DestTy || TyB != SrcTy) |
962 | return false; |
963 | continue; |
964 | } |
965 | return false; |
966 | } |
967 | } |
968 | |
969 | // Check that each user of each old PHI node is something that we can |
970 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. |
971 | for (auto *OldPN : OldPhiNodes) { |
972 | for (User *V : OldPN->users()) { |
973 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
974 | if (ACI && isAMXCast(II: ACI)) { |
975 | // Verify it's a B->A cast. |
976 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
977 | Type *TyA = ACI->getType(); |
978 | if (TyA != DestTy || TyB != SrcTy) |
979 | return false; |
980 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
// As long as the user is another old PHI node, even if we don't rewrite
// it, the PHI web we're considering won't have any users outside itself,
// so it'll be dead.
984 | // example: |
985 | // bb.0: |
986 | // %0 = amxcast ... |
987 | // bb.1: |
988 | // %1 = amxcast ... |
989 | // bb.2: |
990 | // %goodphi = phi %0, %1 |
991 | // %3 = amxcast %goodphi |
992 | // bb.3: |
993 | // %goodphi2 = phi %0, %goodphi |
994 | // %4 = amxcast %goodphi2 |
// When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
// outside the phi-web, so the combination stops. When
// optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
// will be done.
999 | if (OldPhiNodes.count(key: PHI) == 0) |
1000 | return false; |
1001 | } else |
1002 | return false; |
1003 | } |
1004 | } |
1005 | |
// For each old PHI node, create a corresponding new PHI node with type A.
1007 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; |
1008 | for (auto *OldPN : OldPhiNodes) { |
1009 | Builder.SetInsertPoint(OldPN); |
1010 | PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands()); |
1011 | NewPNodes[OldPN] = NewPN; |
1012 | } |
1013 | |
1014 | // Fill in the operands of new PHI nodes. |
1015 | for (auto *OldPN : OldPhiNodes) { |
1016 | PHINode *NewPN = NewPNodes[OldPN]; |
1017 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { |
1018 | Value *V = OldPN->getOperand(i_nocapture: j); |
1019 | Value *NewV = nullptr; |
1020 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
// There should not be an AMXcast from a constant.
1022 | if (ACI && isAMXCast(II: ACI)) |
1023 | NewV = ACI->getOperand(i: 0); |
1024 | else if (auto *PrevPN = dyn_cast<PHINode>(Val: V)) |
1025 | NewV = NewPNodes[PrevPN]; |
1026 | assert(NewV); |
1027 | NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j)); |
1028 | } |
1029 | } |
1030 | |
// Traverse all accumulated PHI nodes and process their users, which are
// stores and bitcasts. Without this processing, new PHI nodes could be
// replicated and could lead to extra moves generated after DeSSA.
// If there is a store with type B, change it to type A.

// Replace users of the B->A bitcast with NewPHI. This will help later to
// get rid of a closure formed by the old PHI nodes.
1039 | for (auto *OldPN : OldPhiNodes) { |
1040 | PHINode *NewPN = NewPNodes[OldPN]; |
1041 | for (User *V : make_early_inc_range(Range: OldPN->users())) { |
1042 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
1043 | if (ACI && isAMXCast(II: ACI)) { |
1044 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
1045 | Type *TyA = ACI->getType(); |
1046 | assert(TyA == DestTy && TyB == SrcTy); |
1047 | (void)TyA; |
1048 | (void)TyB; |
1049 | ACI->replaceAllUsesWith(V: NewPN); |
1050 | DeadInst.insert(X: ACI); |
1051 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
// We don't need to push PHINodes into DeadInst since they are operands
// of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
1054 | assert(OldPhiNodes.contains(PHI)); |
1055 | (void)PHI; |
1056 | } else |
1057 | llvm_unreachable("all uses should be handled" ); |
1058 | } |
1059 | } |
1060 | return true; |
1061 | } |
1062 | |
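// Fetch the Row or Col shape operand for the tile at index ShapeIdx of a
// multi-tile defining AMX intrinsic: the defined tiles share the Row operand
// (operand 0), and the Col of tile ShapeIdx is operand ShapeIdx + 1.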
1063 | static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx, |
1064 | bool IsRow) { |
1065 | if (!isAMXIntrinsic(I: Inst)) |
1066 | return nullptr; |
1067 | |
1068 | auto *II = cast<IntrinsicInst>(Val: Inst); |
1069 | if (IsRow) |
1070 | return II->getOperand(i_nocapture: 0); |
1071 | |
1072 | assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!" ); |
1073 | return II->getOperand(i_nocapture: ShapeIdx + 1); |
1074 | } |
1075 | |
1076 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42) |
1077 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
1078 | // --> |
1079 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
1080 | // i64 64, x86_amx %42) |
1081 | bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) { |
1082 | Value *Tile = Cast->getOperand(i_nocapture: 0); |
1083 | |
1084 | assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!" ); |
1085 | |
1086 | // TODO: Specially handle the multi-use case. |
1087 | if (!Tile->hasOneUse()) |
1088 | return false; |
1089 | |
// We don't fetch the shape from the tilestore; we only get the shape from
// the tiledef, so we can set the max tile shape to the tilestore for
// special cases.
1092 | IRBuilder<> Builder(ST); |
1093 | Value *Row = nullptr; |
1094 | Value *Col = nullptr; |
1095 | |
1096 | if (isAMXIntrinsic(I: Tile)) { |
1097 | auto *II = cast<IntrinsicInst>(Val: Tile); |
1098 | // Tile is output from AMX intrinsic. The first operand of the |
1099 | // intrinsic is row, the second operand of the intrinsic is column. |
1100 | Row = II->getOperand(i_nocapture: 0); |
1101 | Col = II->getOperand(i_nocapture: 1); |
1102 | } else { |
// Now we support multi-tile values in a structure, so we may get a tile
// by extracting from a multi-tile structure.
1105 | // For example: |
1106 | // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1, |
1107 | // i16 %2, i16 %3, i8* %4, i64 %5) |
1108 | // %7 = extractvalue { x86_amx, x86_amx } %6, 0 |
1109 | // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7) |
1110 | // store <256 x i32> %8, <256 x i32>* %0, align 1024 |
1111 | // |
// TODO: Currently we only handle the extractvalue case; enhance this for
// other cases if possible.
1114 | auto *II = cast<ExtractValueInst>(Val: Tile); |
1115 | assert(II && "We meet unhandle source in fetching tile value!" ); |
1116 | unsigned ShapeIdx = II->getIndices()[0]; |
1117 | Value *Tiles = II->getOperand(i_nocapture: 0); |
1118 | Row = getShapeFromAMXIntrinsic(Inst: Tiles, ShapeIdx, IsRow: true); |
1119 | Col = getShapeFromAMXIntrinsic(Inst: Tiles, ShapeIdx, IsRow: false); |
1120 | } |
1121 | assert(Row && Col && "Shape got failed!" ); |
1122 | |
// Stride should be equal to Col (measured in bytes).
1124 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
1125 | Value *I8Ptr = Builder.CreateBitCast(V: ST->getOperand(i_nocapture: 1), DestTy: Builder.getPtrTy()); |
1126 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
1127 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
1128 | return true; |
1129 | } |
1130 | |
1131 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
1132 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
1133 | // --> |
1134 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
1135 | // i8* %p, i64 64) |
1136 | bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { |
1137 | bool EraseLoad = true; |
1138 | Value *Row = nullptr, *Col = nullptr; |
1139 | Use &U = *(Cast->use_begin()); |
1140 | unsigned OpNo = U.getOperandNo(); |
1141 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
// TODO: If it is a cast intrinsic or a phi node, we can propagate the
// shape information through the def-use chain.
1144 | if (!isAMXIntrinsic(I: II)) |
1145 | return false; |
1146 | std::tie(args&: Row, args&: Col) = SC->getShape(II, OpNo); |
1147 | IRBuilder<> Builder(LD); |
// Stride should be equal to Col (measured in bytes).
1149 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
1150 | Value *I8Ptr; |
1151 | |
// To save compile time, we create the dominator tree only when it is
// really needed.
1154 | if (!DT) |
1155 | DT.reset(p: new DominatorTree(Func)); |
1156 | if (!DT->dominates(Def: Row, User: LD) || !DT->dominates(Def: Col, User: LD)) { |
// Store the value to the stack and reload it from the stack before the cast.
1158 | auto *AllocaAddr = |
1159 | createAllocaInstAtEntry(Builder, BB: Cast->getParent(), Ty: LD->getType()); |
1160 | Builder.SetInsertPoint(&*std::next(x: LD->getIterator())); |
1161 | Builder.CreateStore(Val: LD, Ptr: AllocaAddr); |
1162 | |
1163 | Builder.SetInsertPoint(Cast); |
1164 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
1165 | EraseLoad = false; |
1166 | } else { |
1167 | I8Ptr = Builder.CreateBitCast(V: LD->getOperand(i_nocapture: 0), DestTy: Builder.getPtrTy()); |
1168 | } |
1169 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
1170 | |
1171 | Value *NewInst = |
1172 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
1173 | Cast->replaceAllUsesWith(V: NewInst); |
1174 | |
1175 | return EraseLoad; |
1176 | } |
1177 | |
1178 | bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) { |
1179 | bool Change = false; |
1180 | for (auto *Cast : Casts) { |
1181 | auto *II = cast<IntrinsicInst>(Val: Cast); |
1182 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42) |
1183 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
1184 | // --> |
1185 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
1186 | // i64 64, x86_amx %42) |
1187 | if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) { |
1188 | SmallVector<Instruction *, 2> DeadStores; |
1189 | for (User *U : Cast->users()) { |
1190 | StoreInst *Store = dyn_cast<StoreInst>(Val: U); |
1191 | if (!Store) |
1192 | continue; |
1193 | if (combineCastStore(Cast: cast<IntrinsicInst>(Val: Cast), ST: Store)) { |
1194 | DeadStores.push_back(Elt: Store); |
1195 | Change = true; |
1196 | } |
1197 | } |
1198 | for (auto *Store : DeadStores) |
1199 | Store->eraseFromParent(); |
1200 | } else { // x86_cast_vector_to_tile |
1201 | auto *Load = dyn_cast<LoadInst>(Val: Cast->getOperand(i: 0)); |
1202 | if (!Load || !Load->hasOneUse()) |
1203 | continue; |
1204 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
1205 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
1206 | // --> |
1207 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
1208 | // i8* %p, i64 64) |
1209 | if (combineLoadCast(Cast: cast<IntrinsicInst>(Val: Cast), LD: Load)) { |
// Set the operand to null so that the load instruction can be erased.
1211 | Cast->setOperand(i: 0, Val: nullptr); |
1212 | Load->eraseFromParent(); |
1213 | } |
1214 | } |
1215 | } |
1216 | return Change; |
1217 | } |
1218 | |
1219 | bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { |
1220 | bool Change = false; |
// Collect tile cast instructions.
1222 | SmallVector<Instruction *, 8> Vec2TileInsts; |
1223 | SmallVector<Instruction *, 8> Tile2VecInsts; |
1224 | SmallVector<Instruction *, 8> PhiCastWorkList; |
1225 | SmallSetVector<Instruction *, 16> DeadInst; |
1226 | for (BasicBlock &BB : Func) { |
1227 | for (Instruction &I : BB) { |
1228 | Value *Vec; |
1229 | if (match(V: &I, |
1230 | P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value(V&: Vec)))) |
1231 | Vec2TileInsts.push_back(Elt: &I); |
1232 | else if (match(V: &I, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>( |
1233 | Op0: m_Value(V&: Vec)))) |
1234 | Tile2VecInsts.push_back(Elt: &I); |
1235 | } |
1236 | } |
1237 | |
1238 | auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) { |
1239 | for (auto *Inst : Insts) { |
1240 | for (User *U : Inst->users()) { |
1241 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U); |
1242 | if (!II || II->getIntrinsicID() != IID) |
1243 | continue; |
1244 | // T1 = vec2tile V0 |
1245 | // V2 = tile2vec T1 |
1246 | // V3 = OP V2 |
1247 | // --> |
1248 | // T1 = vec2tile V0 |
1249 | // V2 = tile2vec T1 |
1250 | // V3 = OP V0 |
1251 | II->replaceAllUsesWith(V: Inst->getOperand(i: 0)); |
1252 | Change = true; |
1253 | } |
1254 | } |
1255 | }; |
1256 | |
1257 | Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector); |
1258 | Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile); |
1259 | |
1260 | SmallVector<Instruction *, 8> LiveCasts; |
1261 | auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) { |
1262 | for (auto *Inst : Insts) { |
1263 | if (Inst->use_empty()) { |
1264 | Inst->eraseFromParent(); |
1265 | Change = true; |
1266 | } else { |
1267 | LiveCasts.push_back(Elt: Inst); |
1268 | } |
1269 | } |
1270 | }; |
1271 | |
1272 | EraseInst(Vec2TileInsts); |
1273 | EraseInst(Tile2VecInsts); |
1274 | LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine " |
1275 | "Vec2Tile and Tile2Vec:\n" ; |
1276 | Func.dump()); |
1277 | Change |= combineLdSt(Casts&: LiveCasts); |
1278 | EraseInst(LiveCasts); |
1279 | LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine " |
1280 | "AMXCast and load/store:\n" ; |
1281 | Func.dump()); |
1282 | |
// Handle the A->B->A cast where there is an intervening PHI node.
1284 | for (BasicBlock &BB : Func) { |
1285 | for (Instruction &I : BB) { |
1286 | if (isAMXCast(II: &I)) { |
1287 | if (isa<PHINode>(Val: I.getOperand(i: 0))) |
1288 | PhiCastWorkList.push_back(Elt: &I); |
1289 | } |
1290 | } |
1291 | } |
1292 | for (auto *I : PhiCastWorkList) { |
1293 | // We skip the dead Amxcast. |
1294 | if (DeadInst.contains(key: I)) |
1295 | continue; |
1296 | PHINode *PN = cast<PHINode>(Val: I->getOperand(i: 0)); |
1297 | if (optimizeAMXCastFromPhi(CI: cast<IntrinsicInst>(Val: I), PN, DeadInst)) { |
1298 | DeadInst.insert(X: PN); |
1299 | Change = true; |
1300 | } |
1301 | } |
1302 | |
// Since we create new PHIs and merge AMXcasts, some old PHIs and AMXcasts
// might have no uses. We do some dead code elimination for them.
1305 | while (!DeadInst.empty()) { |
1306 | Instruction *I = DeadInst.pop_back_val(); |
1307 | Change |= DCEInstruction(I, WorkList&: DeadInst, TLI); |
1308 | } |
1309 | LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after " |
1310 | "optimizeAMXCastFromPhi:\n" ; |
1311 | Func.dump()); |
1312 | return Change; |
1313 | } |
1314 | |
// There might be remaining AMXcasts after combineAMXcast, and they should
// be handled elegantly.
1317 | bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { |
1318 | IRBuilder<> Builder(AMXCast); |
1319 | AllocaInst *AllocaAddr; |
1320 | Value *I8Ptr, *Stride; |
1321 | auto *Src = AMXCast->getOperand(i_nocapture: 0); |
1322 | |
1323 | auto Prepare = [&](Type *MemTy) { |
1324 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: AMXCast->getParent(), Ty: MemTy); |
1325 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
1326 | Stride = Builder.getInt64(C: 64); |
1327 | }; |
1328 | |
1329 | if (AMXCast->getType()->isX86_AMXTy()) { |
1330 | // %2 = amxcast <225 x i32> %src to x86_amx |
1331 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
1332 | // i8* %addr3, i64 60, x86_amx %2) |
1333 | // --> |
1334 | // %addr = alloca <225 x i32>, align 64 |
1335 | // store <225 x i32> %src, <225 x i32>* %addr, align 64 |
1336 | // %addr2 = bitcast <225 x i32>* %addr to i8* |
1337 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60, |
1338 | // i8* %addr2, |
1339 | // i64 60) |
1340 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
1341 | // i8* %addr3, i64 60, x86_amx %2) |
1342 | if (AMXCast->use_empty()) { |
1343 | AMXCast->eraseFromParent(); |
1344 | return true; |
1345 | } |
1346 | Use &U = *(AMXCast->use_begin()); |
1347 | unsigned OpNo = U.getOperandNo(); |
1348 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
1349 | if (!II) |
1350 | return false; // May be bitcast from x86amx to <256 x i32>. |
1351 | Prepare(AMXCast->getOperand(i_nocapture: 0)->getType()); |
1352 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
1353 | // TODO we can pick an constant operand for the shape. |
1354 | Value *Row = nullptr, *Col = nullptr; |
1355 | std::tie(args&: Row, args&: Col) = SC->getShape(II, OpNo); |
1356 | std::array<Value *, 4> Args = { |
1357 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty())}; |
1358 | Value *NewInst = |
1359 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
1360 | AMXCast->replaceAllUsesWith(V: NewInst); |
1361 | AMXCast->eraseFromParent(); |
1362 | } else { |
1363 | // %2 = amxcast x86_amx %src to <225 x i32> |
1364 | // --> |
1365 | // %addr = alloca <225 x i32>, align 64 |
1366 | // %addr2 = bitcast <225 x i32>* to i8* |
1367 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
1368 | // i8* %addr2, i64 %stride) |
1369 | // %2 = load <225 x i32>, <225 x i32>* %addr, align 64 |
1370 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
1371 | if (!II) |
1372 | return false; // May be bitcast from <256 x i32> to x86amx. |
1373 | Prepare(AMXCast->getType()); |
1374 | Value *Row = II->getOperand(i_nocapture: 0); |
1375 | Value *Col = II->getOperand(i_nocapture: 1); |
1376 | std::array<Value *, 5> Args = { |
1377 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()), Src}; |
1378 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
1379 | Value *NewInst = Builder.CreateLoad(Ty: AMXCast->getType(), Ptr: AllocaAddr); |
1380 | AMXCast->replaceAllUsesWith(V: NewInst); |
1381 | AMXCast->eraseFromParent(); |
1382 | } |
1383 | |
1384 | return true; |
1385 | } |
1386 | |
1387 | bool X86LowerAMXCast::transformAllAMXCast() { |
1388 | bool Change = false; |
// Collect tile cast instructions.
1390 | SmallVector<Instruction *, 8> WorkLists; |
1391 | for (BasicBlock &BB : Func) { |
1392 | for (Instruction &I : BB) { |
1393 | if (isAMXCast(II: &I)) |
1394 | WorkLists.push_back(Elt: &I); |
1395 | } |
1396 | } |
1397 | |
1398 | for (auto *Inst : WorkLists) { |
1399 | Change |= transformAMXCast(AMXCast: cast<IntrinsicInst>(Val: Inst)); |
1400 | } |
1401 | |
1402 | return Change; |
1403 | } |
1404 | |
1405 | } // anonymous namespace |
1406 | |
1407 | namespace { |
1408 | |
1409 | class X86LowerAMXTypeLegacyPass : public FunctionPass { |
1410 | public: |
1411 | static char ID; |
1412 | |
1413 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {} |
1414 | |
1415 | bool runOnFunction(Function &F) override { |
1416 | // Performance optimization: most code doesn't use AMX, so return early if |
1417 | // there are no instructions that produce AMX values. This is sufficient, as |
1418 | // AMX arguments and constants are not allowed -- so any producer of an AMX |
1419 | // value must be an instruction. |
1420 | // TODO: find a cheaper way for this, without looking at all instructions. |
1421 | if (!containsAMXCode(F)) |
1422 | return false; |
1423 | |
1424 | bool C = false; |
1425 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
1426 | TargetLibraryInfo *TLI = |
1427 | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
1428 | |
1429 | ShapeCalculator SC(TM); |
1430 | X86LowerAMXCast LAC(F, &SC); |
1431 | C |= LAC.combineAMXcast(TLI); |
// There might be remaining AMXcasts after combineAMXcast, and they should
// be handled elegantly.
1434 | C |= LAC.transformAllAMXCast(); |
1435 | |
1436 | X86LowerAMXType LAT(F, &SC); |
1437 | C |= LAT.visit(); |
1438 | |
1439 | // Prepare for fast register allocation at O0. |
// TODO: It may be better to check the volatile model of the AMX code, not
// just Attribute::OptimizeNone and CodeGenOptLevel::None.
1442 | if (TM->getOptLevel() == CodeGenOptLevel::None) { |
// If the front end does not use O0 but the mid/back end uses O0 (e.g.
// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
// AMX data is volatile; that is necessary for AMX fast register
// allocation.
1447 | if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone)) { |
1448 | X86VolatileTileData VTD(F); |
1449 | C = VTD.volatileTileData() || C; |
1450 | } |
1451 | } |
1452 | |
1453 | return C; |
1454 | } |
1455 | |
1456 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
1457 | AU.setPreservesCFG(); |
1458 | AU.addRequired<TargetPassConfig>(); |
1459 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
1460 | } |
1461 | }; |
1462 | |
1463 | } // anonymous namespace |
1464 | |
1465 | static const char PassName[] = "Lower AMX type for load/store" ; |
1466 | char X86LowerAMXTypeLegacyPass::ID = 0; |
1467 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, |
1468 | false) |
1469 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) |
1470 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) |
1471 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, |
1472 | false) |
1473 | |
1474 | FunctionPass *llvm::createX86LowerAMXTypePass() { |
1475 | return new X86LowerAMXTypeLegacyPass(); |
1476 | } |
1477 | |