1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// \file Pass to transform <256 x i32> load/store |
10 | /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set |
11 | /// only provides simple operations on x86_amx. Basic elementwise operations |
12 | /// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32> |
13 | /// and only AMX intrinsics can operate on the type, we need to transform |
14 | /// load/store <256 x i32> instructions to AMX load/store. If the bitcast |
15 | /// cannot be combined with the load/store, we transform the bitcast into an |
16 | /// amx load/store plus a <256 x i32> store/load. |
17 | /// |
18 | /// If the front end does not use O0 but the mid/back end does (e.g. "clang |
19 | /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we must make sure the amx data is |
20 | /// volatile, because that is necessary for AMX fast register allocation. (In |
21 | /// fast register allocation, registers are allocated before spill/reload, so |
22 | /// there is no spare register for amx to identify the step in a spill.) |
23 | /// volatileTileData() handles this case. |
24 | /// e.g. |
25 | /// ---------------------------------------------------------- |
26 | /// | def %td = ... | |
27 | /// | ... | |
28 | /// | "use %td" | |
29 | /// ---------------------------------------------------------- |
30 | /// is transformed to --> |
31 | /// ---------------------------------------------------------- |
32 | /// | def %td = ... | |
33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | |
34 | /// | ... | |
35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| |
36 | /// | "use %td2" | |
37 | /// ---------------------------------------------------------- |
38 | // |
39 | //===----------------------------------------------------------------------===// |
40 | // |
41 | #include "X86.h" |
42 | #include "llvm/ADT/PostOrderIterator.h" |
43 | #include "llvm/ADT/SetVector.h" |
44 | #include "llvm/ADT/SmallSet.h" |
45 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
46 | #include "llvm/Analysis/TargetLibraryInfo.h" |
47 | #include "llvm/Analysis/TargetTransformInfo.h" |
48 | #include "llvm/CodeGen/Passes.h" |
49 | #include "llvm/CodeGen/TargetPassConfig.h" |
50 | #include "llvm/CodeGen/ValueTypes.h" |
51 | #include "llvm/IR/DataLayout.h" |
52 | #include "llvm/IR/Function.h" |
53 | #include "llvm/IR/IRBuilder.h" |
54 | #include "llvm/IR/Instructions.h" |
55 | #include "llvm/IR/IntrinsicInst.h" |
56 | #include "llvm/IR/IntrinsicsX86.h" |
57 | #include "llvm/IR/PatternMatch.h" |
58 | #include "llvm/InitializePasses.h" |
59 | #include "llvm/Pass.h" |
60 | #include "llvm/Target/TargetMachine.h" |
61 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" |
62 | #include "llvm/Transforms/Utils/Local.h" |
63 | |
64 | #include <map> |
65 | |
66 | using namespace llvm; |
67 | using namespace PatternMatch; |
68 | |
69 | #define DEBUG_TYPE "lower-amx-type" |
70 | |
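// Return true if II is one of the AMX cast intrinsics
// (x86_cast_vector_to_tile or x86_cast_tile_to_vector).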
71 | static bool isAMXCast(Instruction *II) { |
72 | return match(V: II, |
73 | P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value())) || |
74 | match(V: II, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(Op0: m_Value())); |
75 | } |
76 | |
77 | static bool isAMXIntrinsic(Value *I) { |
78 | auto *II = dyn_cast<IntrinsicInst>(Val: I); |
79 | if (!II) |
80 | return false; |
81 | if (isAMXCast(II)) |
82 | return false; |
83 | // Check whether the return type or any parameter is x86_amx. If so, the |
84 | // intrinsic must be an x86 AMX intrinsic. |
85 | if (II->getType()->isX86_AMXTy()) |
86 | return true; |
87 | for (Value *V : II->args()) { |
88 | if (V->getType()->isX86_AMXTy()) |
89 | return true; |
90 | } |
91 | |
92 | return false; |
93 | } |
94 | |
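// Return true if the function contains any instruction that produces an
// x86_amx value.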
95 | static bool containsAMXCode(Function &F) { |
96 | for (BasicBlock &BB : F) |
97 | for (Instruction &I : BB) |
98 | if (I.getType()->isX86_AMXTy()) |
99 | return true; |
100 | return false; |
101 | } |
102 | |
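// Create an alloca of type Ty in the entry block of the function containing
// BB, aligned to the preferred alignment of x86_amx.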
103 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, |
104 | Type *Ty) { |
105 | Function &F = *BB->getParent(); |
106 | const DataLayout &DL = F.getDataLayout(); |
107 | |
108 | LLVMContext &Ctx = Builder.getContext(); |
109 | auto AllocaAlignment = DL.getPrefTypeAlign(Ty: Type::getX86_AMXTy(C&: Ctx)); |
110 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
111 | AllocaInst *AllocaRes = |
112 | new AllocaInst(Ty, AllocaAS, "" , F.getEntryBlock().begin()); |
113 | AllocaRes->setAlignment(AllocaAlignment); |
114 | return AllocaRes; |
115 | } |
116 | |
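// Return the first non-alloca instruction in the entry block; this is where
// new instructions computed from function arguments are inserted.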
117 | static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { |
118 | for (Instruction &I : F.getEntryBlock()) |
119 | if (!isa<AllocaInst>(Val: &I)) |
120 | return &I; |
121 | llvm_unreachable("No terminator in the entry block!" ); |
122 | } |
123 | |
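// Return the tile shape {Row, Col} required for operand OpNo of the AMX
// intrinsic II. For the multiply-add intrinsics the shape depends on which
// operand OpNo refers to; for operand 5 the row is operand 2 divided by 4,
// and a udiv may be emitted to compute it.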
124 | static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { |
125 | IRBuilder<> Builder(II); |
126 | Value *Row = nullptr, *Col = nullptr; |
127 | switch (II->getIntrinsicID()) { |
128 | default: |
129 | llvm_unreachable("Expect amx intrinsics" ); |
130 | case Intrinsic::x86_tileloadd64_internal: |
131 | case Intrinsic::x86_tileloaddt164_internal: |
132 | case Intrinsic::x86_tilestored64_internal: { |
133 | Row = II->getArgOperand(i: 0); |
134 | Col = II->getArgOperand(i: 1); |
135 | break; |
136 | } |
137 | // a * b + c |
138 | // The shape depends on which operand OpNo refers to. |
139 | case Intrinsic::x86_tcmmimfp16ps_internal: |
140 | case Intrinsic::x86_tcmmrlfp16ps_internal: |
141 | case Intrinsic::x86_tdpbssd_internal: |
142 | case Intrinsic::x86_tdpbsud_internal: |
143 | case Intrinsic::x86_tdpbusd_internal: |
144 | case Intrinsic::x86_tdpbuud_internal: |
145 | case Intrinsic::x86_tdpbf16ps_internal: |
146 | case Intrinsic::x86_tdpfp16ps_internal: { |
147 | switch (OpNo) { |
148 | case 3: |
149 | Row = II->getArgOperand(i: 0); |
150 | Col = II->getArgOperand(i: 1); |
151 | break; |
152 | case 4: |
153 | Row = II->getArgOperand(i: 0); |
154 | Col = II->getArgOperand(i: 2); |
155 | break; |
156 | case 5: |
157 | if (isa<ConstantInt>(Val: II->getArgOperand(i: 2))) |
158 | Row = Builder.getInt16( |
159 | C: (cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getSExtValue()) / 4); |
160 | else if (isa<Instruction>(Val: II->getArgOperand(i: 2))) { |
161 | // When it is not a constant and it is not a function argument, we |
162 | // create Row after the definition of II->getOperand(2) instead of |
163 | // before II. For example, II is %118 and we try to get the shape for %117: |
164 | // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x |
165 | // i32> %115). |
166 | // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 |
167 | // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx |
168 | // %117). |
169 | // If we created %row = udiv i16 %106, 4 before %118 (i.e. II), its |
170 | // definition would come after its user (the new tileload for %117). |
171 | // So the best choice is to create %row right after the definition of |
172 | // %106. |
173 | Builder.SetInsertPoint(cast<Instruction>(Val: II->getOperand(i_nocapture: 2))); |
174 | Row = Builder.CreateUDiv(LHS: II->getOperand(i_nocapture: 2), RHS: Builder.getInt16(C: 4)); |
175 | cast<Instruction>(Val: Row)->moveAfter(MovePos: cast<Instruction>(Val: II->getOperand(i_nocapture: 2))); |
176 | } else { |
177 | // When it is not a constant but is a function argument, we create |
178 | // Row in the entry block. |
179 | IRBuilder<> NewBuilder( |
180 | getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction())); |
181 | Row = NewBuilder.CreateUDiv(LHS: II->getOperand(i_nocapture: 2), RHS: NewBuilder.getInt16(C: 4)); |
182 | } |
183 | Col = II->getArgOperand(i: 1); |
184 | break; |
185 | } |
186 | break; |
187 | } |
188 | } |
189 | |
190 | return std::make_pair(x&: Row, y&: Col); |
191 | } |
192 | |
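// Try to deduce the tile shape of a PHI node by following the first user of
// each value (through AMX casts and other PHIs) until an AMX intrinsic is
// reached; return {nullptr, nullptr} if no shape can be found.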
193 | static std::pair<Value *, Value *> getShape(PHINode *Phi) { |
194 | Use &U = *(Phi->use_begin()); |
195 | unsigned OpNo = U.getOperandNo(); |
196 | User *V = U.getUser(); |
197 | // TODO: We don't traverse all users. To keep the algorithm simple, we only |
198 | // follow the first user. If we can find a shape we return it; otherwise we |
199 | // return nullptr and the optimization for undef/zero incoming values is |
200 | // abandoned. |
201 | while (V) { |
202 | if (isAMXCast(II: dyn_cast<Instruction>(Val: V))) { |
203 | if (V->use_empty()) |
204 | break; |
205 | Use &U = *(V->use_begin()); |
206 | OpNo = U.getOperandNo(); |
207 | V = U.getUser(); |
208 | } else if (isAMXIntrinsic(I: V)) { |
209 | return getShape(II: cast<IntrinsicInst>(Val: V), OpNo); |
210 | } else if (isa<PHINode>(Val: V)) { |
211 | if (V->use_empty()) |
212 | break; |
213 | Use &U = *(V->use_begin()); |
214 | V = U.getUser(); |
215 | } else { |
216 | break; |
217 | } |
218 | } |
219 | |
220 | return std::make_pair(x: nullptr, y: nullptr); |
221 | } |
222 | |
223 | namespace { |
224 | class X86LowerAMXType { |
225 | Function &Func; |
226 | |
227 | // In AMX intrinsics we let Shape = {Row, Col}, but the |
228 | // real column is Col / ElementSize. We may use the real column |
229 | // as a new Row for other newly created AMX intrinsics. |
230 | std::map<Value *, Value *> Col2Row; |
231 | |
232 | public: |
233 | X86LowerAMXType(Function &F) : Func(F) {} |
234 | bool visit(); |
235 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); |
236 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); |
237 | bool transformBitcast(BitCastInst *Bitcast); |
238 | }; |
239 | |
240 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
241 | // %2 = bitcast <256 x i32> %src to x86_amx |
242 | // --> |
243 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
244 | // i8* %addr, i64 %stride64) |
245 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { |
246 | Value *Row = nullptr, *Col = nullptr; |
247 | Use &U = *(Bitcast->use_begin()); |
248 | unsigned OpNo = U.getOperandNo(); |
249 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
250 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
251 | IRBuilder<> Builder(Bitcast); |
252 | // Use the maximum column as the stride. |
253 | Value *Stride = Builder.getInt64(C: 64); |
254 | Value *I8Ptr = LD->getOperand(i_nocapture: 0); |
255 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
256 | |
257 | Value *NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, |
258 | Types: std::nullopt, Args); |
259 | Bitcast->replaceAllUsesWith(V: NewInst); |
260 | } |
261 | |
262 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
263 | // %stride); |
264 | // %13 = bitcast x86_amx %src to <256 x i32> |
265 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
266 | // --> |
267 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
268 | // %stride64, %13) |
269 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { |
270 | |
271 | Value *Tile = Bitcast->getOperand(i_nocapture: 0); |
272 | auto *II = cast<IntrinsicInst>(Val: Tile); |
273 | // Tile is the output of an AMX intrinsic. The first operand of the |
274 | // intrinsic is the row, the second operand is the column. |
275 | Value *Row = II->getOperand(i_nocapture: 0); |
276 | Value *Col = II->getOperand(i_nocapture: 1); |
277 | IRBuilder<> Builder(ST); |
278 | // Use the maximum column as the stride. It must be the same as the |
279 | // load stride. |
280 | Value *Stride = Builder.getInt64(C: 64); |
281 | Value *I8Ptr = ST->getOperand(i_nocapture: 1); |
282 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
283 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt, |
284 | Args); |
285 | if (Bitcast->hasOneUse()) |
286 | return; |
287 | // %13 = bitcast x86_amx %src to <256 x i32> |
288 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
289 | // %add = <256 x i32> %13, <256 x i32> %src2 |
290 | // --> |
291 | // %13 = bitcast x86_amx %src to <256 x i32> |
292 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
293 | // %stride64, %13) |
294 | // %14 = load <256 x i32>, %addr |
295 | // %add = <256 x i32> %14, <256 x i32> %src2 |
296 | Value *Vec = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: ST->getOperand(i_nocapture: 1)); |
297 | Bitcast->replaceAllUsesWith(V: Vec); |
298 | } |
299 | |
300 | // Transform a bitcast into <store, load> instructions through a stack slot. |
301 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { |
302 | IRBuilder<> Builder(Bitcast); |
303 | AllocaInst *AllocaAddr; |
304 | Value *I8Ptr, *Stride; |
305 | auto *Src = Bitcast->getOperand(i_nocapture: 0); |
306 | |
307 | auto Prepare = [&](Type *MemTy) { |
308 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: Bitcast->getParent(), Ty: MemTy); |
309 | I8Ptr = AllocaAddr; |
310 | Stride = Builder.getInt64(C: 64); |
311 | }; |
312 | |
313 | if (Bitcast->getType()->isX86_AMXTy()) { |
314 | // %2 = bitcast <256 x i32> %src to x86_amx |
315 | // --> |
316 | // %addr = alloca <256 x i32>, align 64 |
317 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 |
318 | // %addr2 = bitcast <256 x i32>* to i8* |
319 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
320 | // i8* %addr2, |
321 | // i64 64) |
322 | Use &U = *(Bitcast->use_begin()); |
323 | unsigned OpNo = U.getOperandNo(); |
324 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
325 | if (!II) |
326 | return false; // May be bitcast from x86amx to <256 x i32>. |
327 | Prepare(Bitcast->getOperand(i_nocapture: 0)->getType()); |
328 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
329 | // TODO: we can pick a constant operand for the shape. |
330 | Value *Row = nullptr, *Col = nullptr; |
331 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
332 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
333 | Value *NewInst = Builder.CreateIntrinsic( |
334 | ID: Intrinsic::x86_tileloadd64_internal, Types: std::nullopt, Args); |
335 | Bitcast->replaceAllUsesWith(V: NewInst); |
336 | } else { |
337 | // %2 = bitcast x86_amx %src to <256 x i32> |
338 | // --> |
339 | // %addr = alloca <256 x i32>, align 64 |
340 | // %addr2 = bitcast <256 x i32>* to i8* |
341 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
342 | // i8* %addr2, i64 %stride) |
343 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 |
344 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
345 | if (!II) |
346 | return false; // May be bitcast from <256 x i32> to x86amx. |
347 | Prepare(Bitcast->getType()); |
348 | Value *Row = II->getOperand(i_nocapture: 0); |
349 | Value *Col = II->getOperand(i_nocapture: 1); |
350 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; |
351 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt, |
352 | Args); |
353 | Value *NewInst = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: AllocaAddr); |
354 | Bitcast->replaceAllUsesWith(V: NewInst); |
355 | } |
356 | |
357 | return true; |
358 | } |
359 | |
360 | bool X86LowerAMXType::visit() { |
361 | SmallVector<Instruction *, 8> DeadInsts; |
362 | Col2Row.clear(); |
363 | |
364 | for (BasicBlock *BB : post_order(G: &Func)) { |
365 | for (Instruction &Inst : llvm::make_early_inc_range(Range: llvm::reverse(C&: *BB))) { |
366 | auto *Bitcast = dyn_cast<BitCastInst>(Val: &Inst); |
367 | if (!Bitcast) |
368 | continue; |
369 | |
370 | Value *Src = Bitcast->getOperand(i_nocapture: 0); |
371 | if (Bitcast->getType()->isX86_AMXTy()) { |
372 | if (Bitcast->user_empty()) { |
373 | DeadInsts.push_back(Elt: Bitcast); |
374 | continue; |
375 | } |
376 | LoadInst *LD = dyn_cast<LoadInst>(Val: Src); |
377 | if (!LD) { |
378 | if (transformBitcast(Bitcast)) |
379 | DeadInsts.push_back(Elt: Bitcast); |
380 | continue; |
381 | } |
382 | // If the load has multiple users, keep the vector load and add a tile load. |
383 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
384 | // %2 = bitcast <256 x i32> %src to x86_amx |
385 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
386 | // --> |
387 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
388 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
389 | // i8* %addr, i64 %stride64) |
390 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
391 | |
392 | // If the load has a single user, the load will be eliminated in DAG ISel. |
393 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
394 | // %2 = bitcast <256 x i32> %src to x86_amx |
395 | // --> |
396 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
397 | // i8* %addr, i64 %stride64) |
398 | combineLoadBitcast(LD, Bitcast); |
399 | DeadInsts.push_back(Elt: Bitcast); |
400 | if (LD->hasOneUse()) |
401 | DeadInsts.push_back(Elt: LD); |
402 | } else if (Src->getType()->isX86_AMXTy()) { |
403 | if (Bitcast->user_empty()) { |
404 | DeadInsts.push_back(Elt: Bitcast); |
405 | continue; |
406 | } |
407 | StoreInst *ST = nullptr; |
408 | for (Use &U : Bitcast->uses()) { |
409 | ST = dyn_cast<StoreInst>(Val: U.getUser()); |
410 | if (ST) |
411 | break; |
412 | } |
413 | if (!ST) { |
414 | if (transformBitcast(Bitcast)) |
415 | DeadInsts.push_back(Elt: Bitcast); |
416 | continue; |
417 | } |
418 | // If the bitcast (%13) has one use, combine the bitcast and store into an amx store. |
419 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
420 | // %stride); |
421 | // %13 = bitcast x86_amx %src to <256 x i32> |
422 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
423 | // --> |
424 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
425 | // %stride64, %13) |
426 | // |
427 | // If the bitcast (%13) has multiple uses, transform as below. |
428 | // %13 = bitcast x86_amx %src to <256 x i32> |
429 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
430 | // %add = <256 x i32> %13, <256 x i32> %src2 |
431 | // --> |
432 | // %13 = bitcast x86_amx %src to <256 x i32> |
433 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
434 | // %stride64, %13) |
435 | // %14 = load <256 x i32>, %addr |
436 | // %add = <256 x i32> %14, <256 x i32> %src2 |
437 | // |
438 | combineBitcastStore(Bitcast, ST); |
439 | // Delete the user (the store) first. |
440 | DeadInsts.push_back(Elt: ST); |
441 | DeadInsts.push_back(Elt: Bitcast); |
442 | } |
443 | } |
444 | } |
445 | |
446 | bool C = !DeadInsts.empty(); |
447 | |
448 | for (auto *Inst : DeadInsts) |
449 | Inst->eraseFromParent(); |
450 | |
451 | return C; |
452 | } |
453 | } // anonymous namespace |
454 | |
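// Allocate a <256 x i32> stack slot in the entry block and return a pointer
// to it, to be used as the backing memory for tile stores and loads.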
455 | static Value *getAllocaPos(BasicBlock *BB) { |
456 | Function *F = BB->getParent(); |
457 | IRBuilder<> Builder(&F->getEntryBlock().front()); |
458 | const DataLayout &DL = F->getDataLayout(); |
459 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
460 | Type *V256I32Ty = VectorType::get(ElementType: Builder.getInt32Ty(), NumElements: 256, Scalable: false); |
461 | AllocaInst *AllocaRes = |
462 | new AllocaInst(V256I32Ty, AllocaAS, "" , F->getEntryBlock().begin()); |
463 | BasicBlock::iterator Iter = AllocaRes->getIterator(); |
464 | ++Iter; |
465 | Builder.SetInsertPoint(&*Iter); |
466 | Value *I8Ptr = Builder.CreateBitCast(V: AllocaRes, DestTy: Builder.getPtrTy()); |
467 | return I8Ptr; |
468 | } |
469 | |
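// Emit a tilestore of TileDef to Ptr (with stride 64) immediately after the
// definition of TileDef.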
470 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { |
471 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!" ); |
472 | auto *II = cast<IntrinsicInst>(Val: TileDef); |
473 | assert(II && "Not tile intrinsic!" ); |
474 | Value *Row = II->getOperand(i_nocapture: 0); |
475 | Value *Col = II->getOperand(i_nocapture: 1); |
476 | |
477 | BasicBlock *BB = TileDef->getParent(); |
478 | BasicBlock::iterator Iter = TileDef->getIterator(); |
479 | IRBuilder<> Builder(BB, ++Iter); |
480 | Value *Stride = Builder.getInt64(C: 64); |
481 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; |
482 | |
483 | Instruction *TileStore = Builder.CreateIntrinsic( |
484 | ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt, Args); |
485 | return TileStore; |
486 | } |
487 | |
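// Replace the tile value referenced by use U with a tile load from Ptr,
// emitted right before the user. When the used value is a PHI (IsPHI), the
// shape is taken from its first incoming value.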
488 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { |
489 | Value *V = U.get(); |
490 | assert(V->getType()->isX86_AMXTy() && "Not define tile!" ); |
491 | |
492 | // Get tile shape. |
493 | IntrinsicInst *II = nullptr; |
494 | if (IsPHI) { |
495 | Value *PhiOp = cast<PHINode>(Val: V)->getIncomingValue(i: 0); |
496 | II = cast<IntrinsicInst>(Val: PhiOp); |
497 | } else { |
498 | II = cast<IntrinsicInst>(Val: V); |
499 | } |
500 | Value *Row = II->getOperand(i_nocapture: 0); |
501 | Value *Col = II->getOperand(i_nocapture: 1); |
502 | |
503 | Instruction *UserI = cast<Instruction>(Val: U.getUser()); |
504 | IRBuilder<> Builder(UserI); |
505 | Value *Stride = Builder.getInt64(C: 64); |
506 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; |
507 | |
508 | Value *TileLoad = Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, |
509 | Types: std::nullopt, Args); |
510 | UserI->replaceUsesOfWith(From: V, To: TileLoad); |
511 | } |
512 | |
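// Return true if I is used as an incoming value of some PHI node.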
513 | static bool isIncomingOfPHI(Instruction *I) { |
514 | for (Use &U : I->uses()) { |
515 | User *V = U.getUser(); |
516 | if (isa<PHINode>(Val: V)) |
517 | return true; |
518 | } |
519 | return false; |
520 | } |
521 | |
522 | // Make all AMX tile data volatile, i.e. shorten the live range of each |
523 | // tile register, before fast register allocation. |
524 | namespace { |
525 | class X86VolatileTileData { |
526 | Function &F; |
527 | |
528 | public: |
529 | X86VolatileTileData(Function &Func) : F(Func) {} |
530 | Value *updatePhiIncomings(BasicBlock *BB, |
531 | SmallVector<Instruction *, 2> &Incomings); |
532 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); |
533 | bool volatileTileData(); |
534 | void volatileTilePHI(PHINode *PHI); |
535 | void volatileTileNonPHI(Instruction *I); |
536 | }; |
537 | |
538 | Value *X86VolatileTileData::updatePhiIncomings( |
539 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { |
540 | Value *I8Ptr = getAllocaPos(BB); |
541 | |
542 | for (auto *I : Incomings) { |
543 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
544 | |
545 | // All its uses (except the phi) should load from the stored memory. |
546 | for (Use &U : I->uses()) { |
547 | User *V = U.getUser(); |
548 | if (isa<PHINode>(Val: V) || V == Store) |
549 | continue; |
550 | replaceWithTileLoad(U, Ptr: I8Ptr); |
551 | } |
552 | } |
553 | return I8Ptr; |
554 | } |
555 | |
556 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, |
557 | Value *StorePtr) { |
558 | for (Use &U : PHI->uses()) |
559 | replaceWithTileLoad(U, Ptr: StorePtr, IsPHI: true); |
560 | PHI->eraseFromParent(); |
561 | } |
562 | |
563 | // Similar to volatileTileNonPHI, but this function only handles PHI nodes |
564 | // and their related AMX intrinsics. |
565 | // 1) The PHI def is changed to a tileload. |
566 | // 2) Each PHI incoming value is tilestored just after its def. |
567 | // 3) These tileloads and tilestores use the same memory slot. |
568 | // e.g. |
569 | // ------------------------------------------------------ |
570 | // bb_dom: |
571 | // ... |
572 | // br i1 %bool.cond, label %if.else, label %if.then |
573 | // |
574 | // if.then: |
575 | // def %t0 = ... |
576 | // ... |
577 | // use %t0 |
578 | // ... |
579 | // br label %if.end |
580 | // |
581 | // if.else: |
582 | // def %t1 = ... |
583 | // br label %if.end |
584 | // |
585 | // if.end: |
586 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] |
587 | // ... |
588 | // use %td |
589 | // ------------------------------------------------------ |
590 | // --> |
591 | // ------------------------------------------------------ |
592 | // bb_entry: |
593 | // %mem = alloca <256 x i32>, align 1024 * |
594 | // ... |
595 | // bb_dom: |
596 | // ... |
597 | // br i1 %bool.cond, label %if.else, label %if.then |
598 | // |
599 | // if.then: |
600 | // def %t0 = ... |
601 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * |
602 | // ... |
603 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* |
604 | // use %t0` * |
605 | // ... |
606 | // br label %if.end |
607 | // |
608 | // if.else: |
609 | // def %t1 = ... |
610 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * |
611 | // br label %if.end |
612 | // |
613 | // if.end: |
614 | // ... |
615 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * |
616 | // use %td |
617 | // ------------------------------------------------------ |
618 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { |
619 | BasicBlock *BB = PHI->getParent(); |
620 | SmallVector<Instruction *, 2> Incomings; |
621 | |
622 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { |
623 | Value *Op = PHI->getIncomingValue(i: I); |
624 | Instruction *Inst = dyn_cast<Instruction>(Val: Op); |
625 | assert(Inst && "We shouldn't fold AMX instruction!" ); |
626 | Incomings.push_back(Elt: Inst); |
627 | } |
628 | |
629 | Value *StorePtr = updatePhiIncomings(BB, Incomings); |
630 | replacePhiDefWithLoad(PHI, StorePtr); |
631 | } |
632 | |
633 | // Store the defined tile and load it back before each use. |
634 | // None of its users are PHI nodes. |
635 | // e.g. |
636 | // ------------------------------------------------------ |
637 | // def %td = ... |
638 | // ... |
639 | // "use %td" |
640 | // ------------------------------------------------------ |
641 | // --> |
642 | // ------------------------------------------------------ |
643 | // def %td = ... |
644 | // call void @llvm.x86.tilestored64.internal(mem, %td) |
645 | // ... |
646 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) |
647 | // "use %td2" |
648 | // ------------------------------------------------------ |
649 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { |
650 | BasicBlock *BB = I->getParent(); |
651 | Value *I8Ptr = getAllocaPos(BB); |
652 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
653 | |
654 | // All its uses should load from the stored memory. |
655 | for (Use &U : I->uses()) { |
656 | User *V = U.getUser(); |
657 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!" ); |
658 | if (V != Store) |
659 | replaceWithTileLoad(U, Ptr: I8Ptr); |
660 | } |
661 | } |
662 | |
663 | // Volatile Tile Model: |
664 | // 1) Every use of tile data is reloaded from memory (tileload) just in time. |
665 | // 2) Every def of tile data is stored to memory (tilestore) immediately. |
666 | // For example: |
667 | // -------------------------------------------------------------------------- |
668 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key |
669 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) |
670 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx |
671 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) |
672 | // call void @llvm.x86.tilestored64.internal(... td) area |
673 | // -------------------------------------------------------------------------- |
674 | // 3) There are no terminators, calls, or other amx instructions in the key amx area. |
675 | bool X86VolatileTileData::volatileTileData() { |
676 | bool Changed = false; |
677 | for (BasicBlock &BB : F) { |
678 | SmallVector<Instruction *, 2> PHIInsts; |
679 | SmallVector<Instruction *, 8> AMXDefInsts; |
680 | |
681 | for (Instruction &I : BB) { |
682 | if (!I.getType()->isX86_AMXTy()) |
683 | continue; |
684 | if (isa<PHINode>(Val: &I)) |
685 | PHIInsts.push_back(Elt: &I); |
686 | else |
687 | AMXDefInsts.push_back(Elt: &I); |
688 | } |
689 | |
690 | // First we "volatile" the non-phi related amx intrinsics. |
691 | for (Instruction *I : AMXDefInsts) { |
692 | if (isIncomingOfPHI(I)) |
693 | continue; |
694 | volatileTileNonPHI(I); |
695 | Changed = true; |
696 | } |
697 | |
698 | for (Instruction *I : PHIInsts) { |
699 | volatileTilePHI(PHI: dyn_cast<PHINode>(Val: I)); |
700 | Changed = true; |
701 | } |
702 | } |
703 | return Changed; |
704 | } |
705 | |
706 | } // anonymous namespace |
707 | |
708 | namespace { |
709 | |
710 | class X86LowerAMXCast { |
711 | Function &Func; |
712 | std::unique_ptr<DominatorTree> DT; |
713 | |
714 | public: |
715 | X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {} |
716 | bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); |
717 | bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); |
718 | bool combineLdSt(SmallVectorImpl<Instruction *> &Casts); |
719 | bool combineAMXcast(TargetLibraryInfo *TLI); |
720 | bool transformAMXCast(IntrinsicInst *AMXCast); |
721 | bool transformAllAMXCast(); |
722 | bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN, |
723 | SmallSetVector<Instruction *, 16> &DeadInst); |
724 | }; |
725 | |
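// Erase I if it is trivially dead, and add any operands that become trivially
// dead as a result to WorkList so they can be removed later.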
726 | static bool DCEInstruction(Instruction *I, |
727 | SmallSetVector<Instruction *, 16> &WorkList, |
728 | const TargetLibraryInfo *TLI) { |
729 | if (isInstructionTriviallyDead(I, TLI)) { |
730 | salvageDebugInfo(I&: *I); |
731 | salvageKnowledge(I); |
732 | |
733 | // Null out all of the instruction's operands to see if any operand becomes |
734 | // dead as we go. |
735 | for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { |
736 | Value *OpV = I->getOperand(i); |
737 | I->setOperand(i, Val: nullptr); |
738 | |
739 | if (!OpV->use_empty() || I == OpV) |
740 | continue; |
741 | |
742 | // If the operand is an instruction that became dead as we nulled out the |
743 | // operand, and if it is 'trivially' dead, delete it in a future loop |
744 | // iteration. |
745 | if (Instruction *OpI = dyn_cast<Instruction>(Val: OpV)) { |
746 | if (isInstructionTriviallyDead(I: OpI, TLI)) { |
747 | WorkList.insert(X: OpI); |
748 | } |
749 | } |
750 | } |
751 | I->eraseFromParent(); |
752 | return true; |
753 | } |
754 | return false; |
755 | } |
756 | |
757 | /// This function handles following case |
758 | /// |
759 | /// A -> B amxcast |
760 | /// PHI |
761 | /// B -> A amxcast |
762 | /// |
763 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. |
764 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. |
765 | bool X86LowerAMXCast::optimizeAMXCastFromPhi( |
766 | IntrinsicInst *CI, PHINode *PN, |
767 | SmallSetVector<Instruction *, 16> &DeadInst) { |
768 | IRBuilder<> Builder(CI); |
769 | Value *Src = CI->getOperand(i_nocapture: 0); |
770 | Type *SrcTy = Src->getType(); // Type B |
771 | Type *DestTy = CI->getType(); // Type A |
772 | |
773 | SmallVector<PHINode *, 4> PhiWorklist; |
774 | SmallSetVector<PHINode *, 4> OldPhiNodes; |
775 | |
776 | // Find all of the A->B casts and PHI nodes. |
777 | // We need to inspect all related PHI nodes, but PHIs can be cyclic, so |
778 | // OldPhiNodes is used to track all known PHI nodes, before adding a new |
779 | // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. |
780 | PhiWorklist.push_back(Elt: PN); |
781 | OldPhiNodes.insert(X: PN); |
782 | while (!PhiWorklist.empty()) { |
783 | auto *OldPN = PhiWorklist.pop_back_val(); |
784 | for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) { |
785 | Value *IncValue = OldPN->getIncomingValue(i: I); |
786 | // TODO: Currently we only handle constants that are undef or zero; in the |
787 | // future, we might support other constants. |
788 | if (isa<Constant>(Val: IncValue)) { |
789 | auto *IncConst = dyn_cast<Constant>(Val: IncValue); |
790 | if (!isa<UndefValue>(Val: IncValue) && !IncConst->isZeroValue()) |
791 | return false; |
792 | Value *Row = nullptr, *Col = nullptr; |
793 | std::tie(args&: Row, args&: Col) = getShape(Phi: OldPN); |
794 | // TODO: If Row and Col are not constants, they must dominate the tilezero |
795 | // that we are going to create. |
796 | if (!Row || !Col || !isa<Constant>(Val: Row) || !isa<Constant>(Val: Col)) |
797 | return false; |
798 | // Create tilezero at the end of incoming block. |
799 | auto *Block = OldPN->getIncomingBlock(i: I); |
800 | BasicBlock::iterator Iter = Block->getTerminator()->getIterator(); |
801 | Instruction *NewInst = Builder.CreateIntrinsic( |
802 | ID: Intrinsic::x86_tilezero_internal, Types: std::nullopt, Args: {Row, Col}); |
803 | NewInst->moveBefore(MovePos: &*Iter); |
804 | NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_cast_tile_to_vector, |
805 | Types: {IncValue->getType()}, Args: {NewInst}); |
806 | NewInst->moveBefore(MovePos: &*Iter); |
807 | // Replace InValue with new Value. |
808 | OldPN->setIncomingValue(i: I, V: NewInst); |
809 | IncValue = NewInst; |
810 | } |
811 | |
812 | if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) { |
813 | if (OldPhiNodes.insert(X: PNode)) |
814 | PhiWorklist.push_back(Elt: PNode); |
815 | continue; |
816 | } |
817 | Instruction *ACI = dyn_cast<Instruction>(Val: IncValue); |
818 | if (ACI && isAMXCast(II: ACI)) { |
819 | // Verify it's a A->B cast. |
820 | Type *TyA = ACI->getOperand(i: 0)->getType(); |
821 | Type *TyB = ACI->getType(); |
822 | if (TyA != DestTy || TyB != SrcTy) |
823 | return false; |
824 | continue; |
825 | } |
826 | return false; |
827 | } |
828 | } |
829 | |
830 | // Check that each user of each old PHI node is something that we can |
831 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. |
832 | for (auto *OldPN : OldPhiNodes) { |
833 | for (User *V : OldPN->users()) { |
834 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
835 | if (ACI && isAMXCast(II: ACI)) { |
836 | // Verify it's a B->A cast. |
837 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
838 | Type *TyA = ACI->getType(); |
839 | if (TyA != DestTy || TyB != SrcTy) |
840 | return false; |
841 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
842 | // As long as the user is another old PHI node, then even if we don't |
843 | // rewrite it, the PHI web we're considering won't have any users |
844 | // outside itself, so it'll be dead. |
845 | // example: |
846 | // bb.0: |
847 | // %0 = amxcast ... |
848 | // bb.1: |
849 | // %1 = amxcast ... |
850 | // bb.2: |
851 | // %goodphi = phi %0, %1 |
852 | // %3 = amxcast %goodphi |
853 | // bb.3: |
854 | // %goodphi2 = phi %0, %goodphi |
855 | // %4 = amxcast %goodphi2 |
856 | // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is |
857 | // outside the phi-web, so the combination stops. When |
858 | // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization |
859 | // is done. |
860 | if (OldPhiNodes.count(key: PHI) == 0) |
861 | return false; |
862 | } else |
863 | return false; |
864 | } |
865 | } |
866 | |
867 | // For each old PHI node, create a corresponding new PHI node with a type A. |
868 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; |
869 | for (auto *OldPN : OldPhiNodes) { |
870 | Builder.SetInsertPoint(OldPN); |
871 | PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands()); |
872 | NewPNodes[OldPN] = NewPN; |
873 | } |
874 | |
875 | // Fill in the operands of new PHI nodes. |
876 | for (auto *OldPN : OldPhiNodes) { |
877 | PHINode *NewPN = NewPNodes[OldPN]; |
878 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { |
879 | Value *V = OldPN->getOperand(i_nocapture: j); |
880 | Value *NewV = nullptr; |
881 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
882 | // There should not be an AMXcast from a constant. |
883 | if (ACI && isAMXCast(II: ACI)) |
884 | NewV = ACI->getOperand(i: 0); |
885 | else if (auto *PrevPN = dyn_cast<PHINode>(Val: V)) |
886 | NewV = NewPNodes[PrevPN]; |
887 | assert(NewV); |
888 | NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j)); |
889 | } |
890 | } |
891 | |
892 | // Traverse all accumulated PHI nodes and process their users, |
893 | // which are Stores and BitCasts. Without this processing, |
894 | // new PHI nodes could be replicated and could lead to extra |
895 | // moves generated after DeSSA. |
896 | // If there is a store with type B, change it to type A. |
897 | |
898 | // Replace users of BitCast B->A with NewPHI. These will help |
899 | // later to get rid of a closure formed by OldPHI nodes. |
900 | for (auto *OldPN : OldPhiNodes) { |
901 | PHINode *NewPN = NewPNodes[OldPN]; |
902 | for (User *V : make_early_inc_range(Range: OldPN->users())) { |
903 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
904 | if (ACI && isAMXCast(II: ACI)) { |
905 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
906 | Type *TyA = ACI->getType(); |
907 | assert(TyA == DestTy && TyB == SrcTy); |
908 | (void)TyA; |
909 | (void)TyB; |
910 | ACI->replaceAllUsesWith(V: NewPN); |
911 | DeadInst.insert(X: ACI); |
912 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
913 | // We don't need to push the PHINode into DeadInst since it is an operand |
914 | // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead. |
915 | assert(OldPhiNodes.contains(PHI)); |
916 | (void)PHI; |
917 | } else |
918 | llvm_unreachable("all uses should be handled" ); |
919 | } |
920 | } |
921 | return true; |
922 | } |
923 | |
924 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42) |
925 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
926 | // --> |
927 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
928 | // i64 64, x86_amx %42) |
929 | bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) { |
930 | Value *Tile = Cast->getOperand(i_nocapture: 0); |
931 | // TODO: If it is a cast intrinsic or a phi node, we can propagate the |
932 | // shape information through the def-use chain. |
933 | if (!isAMXIntrinsic(I: Tile)) |
934 | return false; |
935 | auto *II = cast<IntrinsicInst>(Val: Tile); |
936 | // Tile is the output of an AMX intrinsic. The first operand of the |
937 | // intrinsic is the row, the second operand is the column. |
938 | Value *Row = II->getOperand(i_nocapture: 0); |
939 | Value *Col = II->getOperand(i_nocapture: 1); |
940 | IRBuilder<> Builder(ST); |
941 | // The stride should be equal to the column width measured in bytes. |
942 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
943 | Value *I8Ptr = Builder.CreateBitCast(V: ST->getOperand(i_nocapture: 1), DestTy: Builder.getPtrTy()); |
944 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
945 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt, |
946 | Args); |
947 | return true; |
948 | } |
949 | |
950 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
951 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
952 | // --> |
953 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
954 | // i8* %p, i64 64) |
955 | bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { |
956 | bool EraseLoad = true; |
957 | Value *Row = nullptr, *Col = nullptr; |
958 | Use &U = *(Cast->use_begin()); |
959 | unsigned OpNo = U.getOperandNo(); |
960 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
961 | // TODO: If it is a cast intrinsic or a phi node, we can propagate the |
962 | // shape information through the def-use chain. |
963 | if (!isAMXIntrinsic(I: II)) |
964 | return false; |
965 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
966 | IRBuilder<> Builder(LD); |
967 | // The stride should be equal to the column width measured in bytes. |
968 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
969 | Value *I8Ptr; |
970 | |
971 | // To save compile time, we create the dominator tree only when it is |
972 | // really needed. |
973 | if (!DT) |
974 | DT.reset(p: new DominatorTree(Func)); |
975 | if (!DT->dominates(Def: Row, User: LD) || !DT->dominates(Def: Col, User: LD)) { |
976 | // Store the value to the stack and reload it from the stack before the cast. |
977 | auto *AllocaAddr = |
978 | createAllocaInstAtEntry(Builder, BB: Cast->getParent(), Ty: LD->getType()); |
979 | Builder.SetInsertPoint(&*std::next(x: LD->getIterator())); |
980 | Builder.CreateStore(Val: LD, Ptr: AllocaAddr); |
981 | |
982 | Builder.SetInsertPoint(Cast); |
983 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
984 | EraseLoad = false; |
985 | } else { |
986 | I8Ptr = Builder.CreateBitCast(V: LD->getOperand(i_nocapture: 0), DestTy: Builder.getPtrTy()); |
987 | } |
988 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
989 | |
990 | Value *NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, |
991 | Types: std::nullopt, Args); |
992 | Cast->replaceAllUsesWith(V: NewInst); |
993 | |
994 | return EraseLoad; |
995 | } |
996 | |
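// For each remaining AMX cast, try to fold it together with an adjacent
// vector load/store into a tile load/store, erasing the original load/store
// when it becomes dead.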
997 | bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) { |
998 | bool Change = false; |
999 | for (auto *Cast : Casts) { |
1000 | auto *II = cast<IntrinsicInst>(Val: Cast); |
1001 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42) |
1002 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
1003 | // --> |
1004 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
1005 | // i64 64, x86_amx %42) |
1006 | if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) { |
1007 | SmallVector<Instruction *, 2> DeadStores; |
1008 | for (User *U : Cast->users()) { |
1009 | StoreInst *Store = dyn_cast<StoreInst>(Val: U); |
1010 | if (!Store) |
1011 | continue; |
1012 | if (combineCastStore(Cast: cast<IntrinsicInst>(Val: Cast), ST: Store)) { |
1013 | DeadStores.push_back(Elt: Store); |
1014 | Change = true; |
1015 | } |
1016 | } |
1017 | for (auto *Store : DeadStores) |
1018 | Store->eraseFromParent(); |
1019 | } else { // x86_cast_vector_to_tile |
1020 | SmallVector<Instruction *, 2> DeadLoads; |
1021 | auto *Load = dyn_cast<LoadInst>(Val: Cast->getOperand(i: 0)); |
1022 | if (!Load || !Load->hasOneUse()) |
1023 | continue; |
1024 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
1025 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
1026 | // --> |
1027 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
1028 | // i8* %p, i64 64) |
1029 | if (combineLoadCast(Cast: cast<IntrinsicInst>(Val: Cast), LD: Load)) { |
1030 | // Set the operand to null so that the load instruction can be erased. |
1031 | Cast->setOperand(i: 0, Val: nullptr); |
1032 | Load->eraseFromParent(); |
1033 | } |
1034 | } |
1035 | } |
1036 | return Change; |
1037 | } |
1038 | |
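// Combine AMX cast intrinsics: cancel out vec2tile/tile2vec pairs, fold the
// remaining casts with loads/stores, optimize casts whose operand is a PHI
// node, and finally DCE anything left dead.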
1039 | bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { |
1040 | bool Change = false; |
1041 | // Collect tile cast instructions. |
1042 | SmallVector<Instruction *, 8> Vec2TileInsts; |
1043 | SmallVector<Instruction *, 8> Tile2VecInsts; |
1044 | SmallVector<Instruction *, 8> PhiCastWorkList; |
1045 | SmallSetVector<Instruction *, 16> DeadInst; |
1046 | for (BasicBlock &BB : Func) { |
1047 | for (Instruction &I : BB) { |
1048 | Value *Vec; |
1049 | if (match(V: &I, |
1050 | P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value(V&: Vec)))) |
1051 | Vec2TileInsts.push_back(Elt: &I); |
1052 | else if (match(V: &I, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>( |
1053 | Op0: m_Value(V&: Vec)))) |
1054 | Tile2VecInsts.push_back(Elt: &I); |
1055 | } |
1056 | } |
1057 | |
1058 | auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) { |
1059 | for (auto *Inst : Insts) { |
1060 | for (User *U : Inst->users()) { |
1061 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U); |
1062 | if (!II || II->getIntrinsicID() != IID) |
1063 | continue; |
1064 | // T1 = vec2tile V0 |
1065 | // V2 = tile2vec T1 |
1066 | // V3 = OP V2 |
1067 | // --> |
1068 | // T1 = vec2tile V0 |
1069 | // V2 = tile2vec T1 |
1070 | // V3 = OP V0 |
1071 | II->replaceAllUsesWith(V: Inst->getOperand(i: 0)); |
1072 | Change = true; |
1073 | } |
1074 | } |
1075 | }; |
1076 | |
1077 | Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector); |
1078 | Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile); |
1079 | |
1080 | SmallVector<Instruction *, 8> LiveCasts; |
1081 | auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) { |
1082 | for (auto *Inst : Insts) { |
1083 | if (Inst->use_empty()) { |
1084 | Inst->eraseFromParent(); |
1085 | Change = true; |
1086 | } else { |
1087 | LiveCasts.push_back(Elt: Inst); |
1088 | } |
1089 | } |
1090 | }; |
1091 | |
1092 | EraseInst(Vec2TileInsts); |
1093 | EraseInst(Tile2VecInsts); |
1094 | LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine " |
1095 | "Vec2Tile and Tile2Vec:\n" ; |
1096 | Func.dump()); |
1097 | Change |= combineLdSt(Casts&: LiveCasts); |
1098 | EraseInst(LiveCasts); |
1099 | LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine " |
1100 | "AMXCast and load/store:\n" ; |
1101 | Func.dump()); |
1102 | |
1103 | // Handle the A->B->A cast, and there is an intervening PHI node. |
1104 | for (BasicBlock &BB : Func) { |
1105 | for (Instruction &I : BB) { |
1106 | if (isAMXCast(II: &I)) { |
1107 | if (isa<PHINode>(Val: I.getOperand(i: 0))) |
1108 | PhiCastWorkList.push_back(Elt: &I); |
1109 | } |
1110 | } |
1111 | } |
1112 | for (auto *I : PhiCastWorkList) { |
1113 | // Skip dead AMXcasts. |
1114 | if (DeadInst.contains(key: I)) |
1115 | continue; |
1116 | PHINode *PN = cast<PHINode>(Val: I->getOperand(i: 0)); |
1117 | if (optimizeAMXCastFromPhi(CI: cast<IntrinsicInst>(Val: I), PN, DeadInst)) { |
1118 | DeadInst.insert(X: PN); |
1119 | Change = true; |
1120 | } |
1121 | } |
1122 | |
1123 | // Since we create new phis and merge AMXcasts, some old phis and AMXcasts |
1124 | // might have no uses. Run dead code elimination on them. |
1125 | while (!DeadInst.empty()) { |
1126 | Instruction *I = DeadInst.pop_back_val(); |
1127 | Change |= DCEInstruction(I, WorkList&: DeadInst, TLI); |
1128 | } |
1129 | LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after " |
1130 | "optimizeAMXCastFromPhi:\n" ; |
1131 | Func.dump()); |
1132 | return Change; |
1133 | } |
1134 | |
1135 | // There might be remaining AMXcasts after combineAMXcast, and they should |
1136 | // be handled gracefully. |
1137 | bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { |
1138 | IRBuilder<> Builder(AMXCast); |
1139 | AllocaInst *AllocaAddr; |
1140 | Value *I8Ptr, *Stride; |
1141 | auto *Src = AMXCast->getOperand(i_nocapture: 0); |
1142 | |
1143 | auto Prepare = [&](Type *MemTy) { |
1144 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: AMXCast->getParent(), Ty: MemTy); |
1145 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
1146 | Stride = Builder.getInt64(C: 64); |
1147 | }; |
1148 | |
1149 | if (AMXCast->getType()->isX86_AMXTy()) { |
1150 | // %2 = amxcast <225 x i32> %src to x86_amx |
1151 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
1152 | // i8* %addr3, i64 60, x86_amx %2) |
1153 | // --> |
1154 | // %addr = alloca <225 x i32>, align 64 |
1155 | // store <225 x i32> %src, <225 x i32>* %addr, align 64 |
1156 | // %addr2 = bitcast <225 x i32>* %addr to i8* |
1157 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60, |
1158 | // i8* %addr2, |
1159 | // i64 60) |
1160 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
1161 | // i8* %addr3, i64 60, x86_amx %2) |
1162 | if (AMXCast->use_empty()) { |
1163 | AMXCast->eraseFromParent(); |
1164 | return true; |
1165 | } |
1166 | Use &U = *(AMXCast->use_begin()); |
1167 | unsigned OpNo = U.getOperandNo(); |
1168 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
1169 | if (!II) |
1170 | return false; // May be bitcast from x86amx to <256 x i32>. |
1171 | Prepare(AMXCast->getOperand(i_nocapture: 0)->getType()); |
1172 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
1173 | // TODO: we can pick a constant operand for the shape. |
1174 | Value *Row = nullptr, *Col = nullptr; |
1175 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
1176 | std::array<Value *, 4> Args = { |
1177 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty())}; |
1178 | Value *NewInst = Builder.CreateIntrinsic( |
1179 | ID: Intrinsic::x86_tileloadd64_internal, Types: std::nullopt, Args); |
1180 | AMXCast->replaceAllUsesWith(V: NewInst); |
1181 | AMXCast->eraseFromParent(); |
1182 | } else { |
1183 | // %2 = amxcast x86_amx %src to <225 x i32> |
1184 | // --> |
1185 | // %addr = alloca <225 x i32>, align 64 |
1186 | // %addr2 = bitcast <225 x i32>* to i8* |
1187 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
1188 | // i8* %addr2, i64 %stride) |
1189 | // %2 = load <225 x i32>, <225 x i32>* %addr, align 64 |
1190 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
1191 | if (!II) |
1192 | return false; // May be bitcast from <256 x i32> to x86amx. |
1193 | Prepare(AMXCast->getType()); |
1194 | Value *Row = II->getOperand(i_nocapture: 0); |
1195 | Value *Col = II->getOperand(i_nocapture: 1); |
1196 | std::array<Value *, 5> Args = { |
1197 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()), Src}; |
1198 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Types: std::nullopt, |
1199 | Args); |
1200 | Value *NewInst = Builder.CreateLoad(Ty: AMXCast->getType(), Ptr: AllocaAddr); |
1201 | AMXCast->replaceAllUsesWith(V: NewInst); |
1202 | AMXCast->eraseFromParent(); |
1203 | } |
1204 | |
1205 | return true; |
1206 | } |
1207 | |
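// Lower every remaining AMX cast in the function by spilling through a stack
// slot (see transformAMXCast).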
1208 | bool X86LowerAMXCast::transformAllAMXCast() { |
1209 | bool Change = false; |
1210 | // Collect tile cast instructions. |
1211 | SmallVector<Instruction *, 8> WorkLists; |
1212 | for (BasicBlock &BB : Func) { |
1213 | for (Instruction &I : BB) { |
1214 | if (isAMXCast(II: &I)) |
1215 | WorkLists.push_back(Elt: &I); |
1216 | } |
1217 | } |
1218 | |
1219 | for (auto *Inst : WorkLists) { |
1220 | Change |= transformAMXCast(AMXCast: cast<IntrinsicInst>(Val: Inst)); |
1221 | } |
1222 | |
1223 | return Change; |
1224 | } |
1225 | |
1226 | } // anonymous namespace |
1227 | |
1228 | namespace { |
1229 | |
1230 | class X86LowerAMXTypeLegacyPass : public FunctionPass { |
1231 | public: |
1232 | static char ID; |
1233 | |
1234 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { |
1235 | initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); |
1236 | } |
1237 | |
1238 | bool runOnFunction(Function &F) override { |
1239 | // Performance optimization: most code doesn't use AMX, so return early if |
1240 | // there are no instructions that produce AMX values. This is sufficient, as |
1241 | // AMX arguments and constants are not allowed -- so any producer of an AMX |
1242 | // value must be an instruction. |
1243 | // TODO: find a cheaper way for this, without looking at all instructions. |
1244 | if (!containsAMXCode(F)) |
1245 | return false; |
1246 | |
1247 | bool C = false; |
1248 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
1249 | TargetLibraryInfo *TLI = |
1250 | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
1251 | |
1252 | X86LowerAMXCast LAC(F); |
1253 | C |= LAC.combineAMXcast(TLI); |
1254 | // There might be remaining AMXcasts after combineAMXcast, and they should |
1255 | // be handled gracefully. |
1256 | C |= LAC.transformAllAMXCast(); |
1257 | |
1258 | X86LowerAMXType LAT(F); |
1259 | C |= LAT.visit(); |
1260 | |
1261 | // Prepare for fast register allocation at O0. |
1262 | // TODO: It may be better to check the volatile model of the AMX code |
1263 | // itself, not just Attribute::OptimizeNone and CodeGenOptLevel::None. |
1264 | if (TM->getOptLevel() == CodeGenOptLevel::None) { |
1265 | // If the front end does not use O0 but the mid/back end does (e.g. |
1266 | // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make |
1267 | // sure the amx data is volatile; that is necessary for AMX fast |
1268 | // register allocation. |
1269 | if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone)) { |
1270 | X86VolatileTileData VTD(F); |
1271 | C = VTD.volatileTileData() || C; |
1272 | } |
1273 | } |
1274 | |
1275 | return C; |
1276 | } |
1277 | |
1278 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
1279 | AU.setPreservesCFG(); |
1280 | AU.addRequired<TargetPassConfig>(); |
1281 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
1282 | } |
1283 | }; |
1284 | |
1285 | } // anonymous namespace |
1286 | |
1287 | static const char PassName[] = "Lower AMX type for load/store" ; |
1288 | char X86LowerAMXTypeLegacyPass::ID = 0; |
1289 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, |
1290 | false) |
1291 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) |
1292 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) |
1293 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, |
1294 | false) |
1295 | |
1296 | FunctionPass *llvm::createX86LowerAMXTypePass() { |
1297 | return new X86LowerAMXTypeLegacyPass(); |
1298 | } |
1299 | |