| 1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions into AMX load/store. If the bitcast
/// cannot be combined with a load/store, we transform the bitcast into an
/// AMX load/store paired with a <256 x i32> store/load.
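/// For example (a sketch of the fallback path; the real row/col operands are
/// computed by getShape()):
///   %2 = bitcast <256 x i32> %src to x86_amx
/// becomes
///   %addr = alloca <256 x i32>, align 64
///   store <256 x i32> %src, <256 x i32>* %addr, align 64
///   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                    i8* %addr, i64 64)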
| 17 | /// |
/// If the front end does not use O0 but the mid/back end does (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
| 24 | /// e.g. |
| 25 | /// ---------------------------------------------------------- |
| 26 | /// | def %td = ... | |
| 27 | /// | ... | |
| 28 | /// | "use %td" | |
| 29 | /// ---------------------------------------------------------- |
/// will be transformed to -->
| 31 | /// ---------------------------------------------------------- |
| 32 | /// | def %td = ... | |
| 33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | |
| 34 | /// | ... | |
| 35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| |
| 36 | /// | "use %td2" | |
| 37 | /// ---------------------------------------------------------- |
| 38 | // |
| 39 | //===----------------------------------------------------------------------===// |
| 40 | // |
| 41 | #include "X86.h" |
| 42 | #include "llvm/ADT/PostOrderIterator.h" |
| 43 | #include "llvm/ADT/SetVector.h" |
| 44 | #include "llvm/Analysis/TargetLibraryInfo.h" |
| 45 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 46 | #include "llvm/CodeGen/Passes.h" |
| 47 | #include "llvm/CodeGen/TargetPassConfig.h" |
| 48 | #include "llvm/CodeGen/ValueTypes.h" |
| 49 | #include "llvm/IR/Analysis.h" |
| 50 | #include "llvm/IR/DataLayout.h" |
| 51 | #include "llvm/IR/Function.h" |
| 52 | #include "llvm/IR/IRBuilder.h" |
| 53 | #include "llvm/IR/Instructions.h" |
| 54 | #include "llvm/IR/IntrinsicInst.h" |
| 55 | #include "llvm/IR/IntrinsicsX86.h" |
| 56 | #include "llvm/IR/PassManager.h" |
| 57 | #include "llvm/IR/PatternMatch.h" |
| 58 | #include "llvm/InitializePasses.h" |
| 59 | #include "llvm/Pass.h" |
| 60 | #include "llvm/Target/TargetMachine.h" |
| 61 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" |
| 62 | #include "llvm/Transforms/Utils/Local.h" |
| 63 | |
| 64 | #include <map> |
| 65 | |
| 66 | using namespace llvm; |
| 67 | using namespace PatternMatch; |
| 68 | |
| 69 | #define DEBUG_TYPE "x86-lower-amx-type" |
| 70 | |
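// Return true if \p II is one of the AMX cast intrinsics,
// i.e. llvm.x86.cast.vector.to.tile or llvm.x86.cast.tile.to.vector.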
| 71 | static bool isAMXCast(Instruction *II) { |
| 72 | return match(V: II, |
| 73 | P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value())) || |
| 74 | match(V: II, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(Op0: m_Value())); |
| 75 | } |
| 76 | |
| 77 | static bool isAMXIntrinsic(Value *I) { |
| 78 | auto *II = dyn_cast<IntrinsicInst>(Val: I); |
| 79 | if (!II) |
| 80 | return false; |
| 81 | if (isAMXCast(II)) |
| 82 | return false; |
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
| 85 | if (II->getType()->isX86_AMXTy()) |
| 86 | return true; |
| 87 | for (Value *V : II->args()) { |
| 88 | if (V->getType()->isX86_AMXTy()) |
| 89 | return true; |
| 90 | } |
| 91 | |
| 92 | return false; |
| 93 | } |
| 94 | |
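// Return true if any instruction in \p F produces a value of x86_amx type.
// This is used to bail out early for functions that contain no AMX code.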
| 95 | static bool containsAMXCode(Function &F) { |
| 96 | for (BasicBlock &BB : F) |
| 97 | for (Instruction &I : BB) |
| 98 | if (I.getType()->isX86_AMXTy()) |
| 99 | return true; |
| 100 | return false; |
| 101 | } |
| 102 | |
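// Create an alloca of type \p Ty at the start of the entry block, aligned to
// the preferred alignment of x86_amx, so that tile data can be staged through
// memory.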
| 103 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, |
| 104 | Type *Ty) { |
| 105 | Function &F = *BB->getParent(); |
| 106 | const DataLayout &DL = F.getDataLayout(); |
| 107 | |
| 108 | LLVMContext &Ctx = Builder.getContext(); |
| 109 | auto AllocaAlignment = DL.getPrefTypeAlign(Ty: Type::getX86_AMXTy(C&: Ctx)); |
| 110 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
| 111 | AllocaInst *AllocaRes = |
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
| 113 | AllocaRes->setAlignment(AllocaAlignment); |
| 114 | return AllocaRes; |
| 115 | } |
| 116 | |
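// Return the first non-alloca instruction in the entry block. It serves as
// the insertion point for code that must follow all allocas, e.g. the row
// computation for a shape that is a function argument.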
| 117 | static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { |
| 118 | for (Instruction &I : F.getEntryBlock()) |
| 119 | if (!isa<AllocaInst>(Val: &I)) |
| 120 | return &I; |
  llvm_unreachable("No terminator in the entry block!");
| 122 | } |
| 123 | |
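// Derive a row value from the column value \p V (Row = Col / Granularity).
// For non-constant \p V the udiv is inserted right after the definition of
// \p V, or in the entry block if \p V is a function argument, so that it
// dominates its users.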
| 124 | static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) { |
| 125 | IRBuilder<> Builder(II); |
| 126 | Value *RealRow = nullptr; |
| 127 | if (isa<ConstantInt>(Val: V)) |
| 128 | RealRow = |
| 129 | Builder.getInt16(C: (cast<ConstantInt>(Val: V)->getSExtValue()) / Granularity); |
| 130 | else if (isa<Instruction>(Val: V)) { |
    // When it is not a constant value and not a function argument, we
    // create Row after the definition of V instead of before II. For
    // example, II is %118 and we try to get the shape for %117:
    //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
    //          i32> %115).
    //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
    //          %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
    //          %117).
    // If we created %row = udiv i16 %106, 4 before %118 (aka. II), its
    // definition would come after its user (the new tileload for %117).
    // So, the best choice is to create %row right after the definition of
    // %106.
| 143 | Builder.SetInsertPoint(cast<Instruction>(Val: V)); |
| 144 | RealRow = Builder.CreateUDiv(LHS: V, RHS: Builder.getInt16(C: 4)); |
| 145 | cast<Instruction>(Val: RealRow)->moveAfter(MovePos: cast<Instruction>(Val: V)); |
| 146 | } else { |
| 147 | // When it is not a const value and it is a function argument, we create |
| 148 | // Row at the entry bb. |
| 149 | IRBuilder<> NewBuilder( |
| 150 | getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction())); |
| 151 | RealRow = NewBuilder.CreateUDiv(LHS: V, RHS: NewBuilder.getInt16(C: Granularity)); |
| 152 | } |
| 153 | return RealRow; |
| 154 | } |
| 155 | |
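// Return the {Row, Col} shape of the tile at operand \p OpNo of the AMX
// intrinsic \p II. For tile load/store intrinsics the shape is carried by the
// first two operands; for the dot-product style intrinsics it depends on
// whether the queried operand is the accumulator or one of the multiplicands.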
| 156 | // TODO: Refine the row and col-in-bytes of tile to row and col of matrix. |
| 157 | std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { |
| 158 | IRBuilder<> Builder(II); |
| 159 | Value *Row = nullptr, *Col = nullptr; |
| 160 | switch (II->getIntrinsicID()) { |
| 161 | default: |
    llvm_unreachable("Expect amx intrinsics");
| 163 | case Intrinsic::x86_tileloadd64_internal: |
| 164 | case Intrinsic::x86_tileloaddt164_internal: |
| 165 | case Intrinsic::x86_tilestored64_internal: |
| 166 | case Intrinsic::x86_tileloaddrs64_internal: |
| 167 | case Intrinsic::x86_tileloaddrst164_internal: { |
| 168 | Row = II->getArgOperand(i: 0); |
| 169 | Col = II->getArgOperand(i: 1); |
| 170 | break; |
| 171 | } |
| 172 | // a * b + c |
| 173 | // The shape depends on which operand. |
| 174 | case Intrinsic::x86_tcmmimfp16ps_internal: |
| 175 | case Intrinsic::x86_tcmmrlfp16ps_internal: |
| 176 | case Intrinsic::x86_tdpbssd_internal: |
| 177 | case Intrinsic::x86_tdpbsud_internal: |
| 178 | case Intrinsic::x86_tdpbusd_internal: |
| 179 | case Intrinsic::x86_tdpbuud_internal: |
| 180 | case Intrinsic::x86_tdpbf16ps_internal: |
| 181 | case Intrinsic::x86_tdpfp16ps_internal: |
| 182 | case Intrinsic::x86_tmmultf32ps_internal: |
| 183 | case Intrinsic::x86_tdpbf8ps_internal: |
| 184 | case Intrinsic::x86_tdpbhf8ps_internal: |
| 185 | case Intrinsic::x86_tdphbf8ps_internal: |
| 186 | case Intrinsic::x86_tdphf8ps_internal: { |
| 187 | switch (OpNo) { |
| 188 | case 3: |
| 189 | Row = II->getArgOperand(i: 0); |
| 190 | Col = II->getArgOperand(i: 1); |
| 191 | break; |
| 192 | case 4: |
| 193 | Row = II->getArgOperand(i: 0); |
| 194 | Col = II->getArgOperand(i: 2); |
| 195 | break; |
| 196 | case 5: |
| 197 | Row = getRowFromCol(II, V: II->getArgOperand(i: 2), Granularity: 4); |
| 198 | Col = II->getArgOperand(i: 1); |
| 199 | break; |
| 200 | } |
| 201 | break; |
| 202 | } |
| 203 | case Intrinsic::x86_tcvtrowd2ps_internal: |
| 204 | case Intrinsic::x86_tcvtrowps2bf16h_internal: |
| 205 | case Intrinsic::x86_tcvtrowps2bf16l_internal: |
| 206 | case Intrinsic::x86_tcvtrowps2phh_internal: |
| 207 | case Intrinsic::x86_tcvtrowps2phl_internal: |
| 208 | case Intrinsic::x86_tilemovrow_internal: { |
    assert(OpNo == 2 && "Illegal Operand Number.");
| 210 | Row = II->getArgOperand(i: 0); |
| 211 | Col = II->getArgOperand(i: 1); |
| 212 | break; |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | return std::make_pair(x&: Row, y&: Col); |
| 217 | } |
| 218 | |
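// Try to determine the shape of an AMX-typed PHI node by following its first
// user through AMX casts and other PHI nodes until an AMX intrinsic is found.
// Returns {nullptr, nullptr} if no shape can be determined.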
| 219 | static std::pair<Value *, Value *> getShape(PHINode *Phi) { |
| 220 | Use &U = *(Phi->use_begin()); |
| 221 | unsigned OpNo = U.getOperandNo(); |
| 222 | User *V = U.getUser(); |
  // TODO: We don't traverse all users. To keep the algorithm simple, we only
  // follow the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero will be abandoned.
| 227 | while (V) { |
| 228 | if (isAMXCast(II: dyn_cast<Instruction>(Val: V))) { |
| 229 | if (V->use_empty()) |
| 230 | break; |
| 231 | Use &U = *(V->use_begin()); |
| 232 | OpNo = U.getOperandNo(); |
| 233 | V = U.getUser(); |
| 234 | } else if (isAMXIntrinsic(I: V)) { |
| 235 | return getShape(II: cast<IntrinsicInst>(Val: V), OpNo); |
| 236 | } else if (isa<PHINode>(Val: V)) { |
| 237 | if (V->use_empty()) |
| 238 | break; |
| 239 | Use &U = *(V->use_begin()); |
| 240 | V = U.getUser(); |
| 241 | } else { |
| 242 | break; |
| 243 | } |
| 244 | } |
| 245 | |
| 246 | return std::make_pair(x: nullptr, y: nullptr); |
| 247 | } |
| 248 | |
| 249 | namespace { |
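// Lower bitcasts between <256 x i32> and x86_amx: combine them with adjacent
// vector loads/stores into AMX tile loads/stores, or fall back to spilling
// through a stack slot when no such combination is possible.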
| 250 | class X86LowerAMXType { |
| 251 | Function &Func; |
| 252 | |
  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
| 256 | std::map<Value *, Value *> Col2Row; |
| 257 | |
| 258 | public: |
| 259 | X86LowerAMXType(Function &F) : Func(F) {} |
| 260 | bool visit(); |
| 261 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); |
| 262 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); |
| 263 | bool transformBitcast(BitCastInst *Bitcast); |
| 264 | }; |
| 265 | |
| 266 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
| 267 | // %2 = bitcast <256 x i32> %src to x86_amx |
| 268 | // --> |
| 269 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
| 270 | // i8* %addr, i64 %stride64) |
| 271 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { |
| 272 | Value *Row = nullptr, *Col = nullptr; |
| 273 | Use &U = *(Bitcast->use_begin()); |
| 274 | unsigned OpNo = U.getOperandNo(); |
| 275 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
| 276 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
| 277 | IRBuilder<> Builder(Bitcast); |
  // Use the maximum column as the stride.
| 279 | Value *Stride = Builder.getInt64(C: 64); |
| 280 | Value *I8Ptr = LD->getOperand(i_nocapture: 0); |
| 281 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
| 282 | |
| 283 | Value *NewInst = |
| 284 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
| 285 | Bitcast->replaceAllUsesWith(V: NewInst); |
| 286 | } |
| 287 | |
| 288 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
| 289 | // %stride); |
| 290 | // %13 = bitcast x86_amx %src to <256 x i32> |
| 291 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
| 292 | // --> |
| 293 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
| 294 | // %stride64, %13) |
| 295 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { |
| 296 | |
| 297 | Value *Tile = Bitcast->getOperand(i_nocapture: 0); |
| 298 | auto *II = cast<IntrinsicInst>(Val: Tile); |
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
| 301 | Value *Row = II->getOperand(i_nocapture: 0); |
| 302 | Value *Col = II->getOperand(i_nocapture: 1); |
| 303 | IRBuilder<> Builder(ST); |
  // Use the maximum column as the stride. It must be the same as the
  // load stride.
| 306 | Value *Stride = Builder.getInt64(C: 64); |
| 307 | Value *I8Ptr = ST->getOperand(i_nocapture: 1); |
| 308 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
| 309 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
| 310 | if (Bitcast->hasOneUse()) |
| 311 | return; |
| 312 | // %13 = bitcast x86_amx %src to <256 x i32> |
| 313 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
| 314 | // %add = <256 x i32> %13, <256 x i32> %src2 |
| 315 | // --> |
| 316 | // %13 = bitcast x86_amx %src to <256 x i32> |
| 317 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
| 318 | // %stride64, %13) |
| 319 | // %14 = load <256 x i32>, %addr |
| 320 | // %add = <256 x i32> %14, <256 x i32> %src2 |
| 321 | Value *Vec = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: ST->getOperand(i_nocapture: 1)); |
| 322 | Bitcast->replaceAllUsesWith(V: Vec); |
| 323 | } |
| 324 | |
// Transform the bitcast into a <store, load> pair through a stack slot.
| 326 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { |
| 327 | IRBuilder<> Builder(Bitcast); |
| 328 | AllocaInst *AllocaAddr; |
| 329 | Value *I8Ptr, *Stride; |
| 330 | auto *Src = Bitcast->getOperand(i_nocapture: 0); |
| 331 | |
| 332 | auto Prepare = [&](Type *MemTy) { |
| 333 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: Bitcast->getParent(), Ty: MemTy); |
| 334 | I8Ptr = AllocaAddr; |
| 335 | Stride = Builder.getInt64(C: 64); |
| 336 | }; |
| 337 | |
| 338 | if (Bitcast->getType()->isX86_AMXTy()) { |
| 339 | // %2 = bitcast <256 x i32> %src to x86_amx |
| 340 | // --> |
| 341 | // %addr = alloca <256 x i32>, align 64 |
| 342 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 |
| 343 | // %addr2 = bitcast <256 x i32>* to i8* |
| 344 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
| 345 | // i8* %addr2, |
| 346 | // i64 64) |
| 347 | Use &U = *(Bitcast->use_begin()); |
| 348 | unsigned OpNo = U.getOperandNo(); |
| 349 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
| 350 | if (!II) |
| 351 | return false; // May be bitcast from x86amx to <256 x i32>. |
| 352 | Prepare(Bitcast->getOperand(i_nocapture: 0)->getType()); |
| 353 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
    // TODO: we can pick a constant operand for the shape.
| 355 | Value *Row = nullptr, *Col = nullptr; |
| 356 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
| 357 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
| 358 | Value *NewInst = |
| 359 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
| 360 | Bitcast->replaceAllUsesWith(V: NewInst); |
| 361 | } else { |
| 362 | // %2 = bitcast x86_amx %src to <256 x i32> |
| 363 | // --> |
| 364 | // %addr = alloca <256 x i32>, align 64 |
| 365 | // %addr2 = bitcast <256 x i32>* to i8* |
| 366 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
| 367 | // i8* %addr2, i64 %stride) |
| 368 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 |
| 369 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
| 370 | if (!II) |
| 371 | return false; // May be bitcast from <256 x i32> to x86amx. |
| 372 | Prepare(Bitcast->getType()); |
| 373 | Value *Row = II->getOperand(i_nocapture: 0); |
| 374 | Value *Col = II->getOperand(i_nocapture: 1); |
| 375 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; |
| 376 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
| 377 | Value *NewInst = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: AllocaAddr); |
| 378 | Bitcast->replaceAllUsesWith(V: NewInst); |
| 379 | } |
| 380 | |
| 381 | return true; |
| 382 | } |
| 383 | |
| 384 | bool X86LowerAMXType::visit() { |
| 385 | SmallVector<Instruction *, 8> DeadInsts; |
| 386 | Col2Row.clear(); |
| 387 | |
| 388 | for (BasicBlock *BB : post_order(G: &Func)) { |
| 389 | for (Instruction &Inst : llvm::make_early_inc_range(Range: llvm::reverse(C&: *BB))) { |
| 390 | auto *Bitcast = dyn_cast<BitCastInst>(Val: &Inst); |
| 391 | if (!Bitcast) |
| 392 | continue; |
| 393 | |
| 394 | Value *Src = Bitcast->getOperand(i_nocapture: 0); |
| 395 | if (Bitcast->getType()->isX86_AMXTy()) { |
| 396 | if (Bitcast->user_empty()) { |
| 397 | DeadInsts.push_back(Elt: Bitcast); |
| 398 | continue; |
| 399 | } |
| 400 | LoadInst *LD = dyn_cast<LoadInst>(Val: Src); |
| 401 | if (!LD) { |
| 402 | if (transformBitcast(Bitcast)) |
| 403 | DeadInsts.push_back(Elt: Bitcast); |
| 404 | continue; |
| 405 | } |
        // If the load has multiple users, keep the vector load and add an
        // AMX tile load for the bitcast.
| 407 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
| 408 | // %2 = bitcast <256 x i32> %src to x86_amx |
| 409 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
| 410 | // --> |
| 411 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
| 412 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
| 413 | // i8* %addr, i64 %stride64) |
| 414 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
| 415 | |
        // If the load has a single user, the load will be eliminated in DAG ISel.
| 417 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
| 418 | // %2 = bitcast <256 x i32> %src to x86_amx |
| 419 | // --> |
| 420 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
| 421 | // i8* %addr, i64 %stride64) |
| 422 | combineLoadBitcast(LD, Bitcast); |
| 423 | DeadInsts.push_back(Elt: Bitcast); |
| 424 | if (LD->hasOneUse()) |
| 425 | DeadInsts.push_back(Elt: LD); |
| 426 | } else if (Src->getType()->isX86_AMXTy()) { |
| 427 | if (Bitcast->user_empty()) { |
| 428 | DeadInsts.push_back(Elt: Bitcast); |
| 429 | continue; |
| 430 | } |
| 431 | StoreInst *ST = nullptr; |
| 432 | for (Use &U : Bitcast->uses()) { |
| 433 | ST = dyn_cast<StoreInst>(Val: U.getUser()); |
| 434 | if (ST) |
| 435 | break; |
| 436 | } |
| 437 | if (!ST) { |
| 438 | if (transformBitcast(Bitcast)) |
| 439 | DeadInsts.push_back(Elt: Bitcast); |
| 440 | continue; |
| 441 | } |
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
| 443 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
| 444 | // %stride); |
| 445 | // %13 = bitcast x86_amx %src to <256 x i32> |
| 446 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
| 447 | // --> |
| 448 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
| 449 | // %stride64, %13) |
| 450 | // |
        // If the bitcast (%13) has multiple uses, transform as below.
| 452 | // %13 = bitcast x86_amx %src to <256 x i32> |
| 453 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
| 454 | // %add = <256 x i32> %13, <256 x i32> %src2 |
| 455 | // --> |
| 456 | // %13 = bitcast x86_amx %src to <256 x i32> |
| 457 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
| 458 | // %stride64, %13) |
| 459 | // %14 = load <256 x i32>, %addr |
| 460 | // %add = <256 x i32> %14, <256 x i32> %src2 |
| 461 | // |
| 462 | combineBitcastStore(Bitcast, ST); |
| 463 | // Delete user first. |
| 464 | DeadInsts.push_back(Elt: ST); |
| 465 | DeadInsts.push_back(Elt: Bitcast); |
| 466 | } |
| 467 | } |
| 468 | } |
| 469 | |
| 470 | bool C = !DeadInsts.empty(); |
| 471 | |
| 472 | for (auto *Inst : DeadInsts) |
| 473 | Inst->eraseFromParent(); |
| 474 | |
| 475 | return C; |
| 476 | } |
| 477 | } // anonymous namespace |
| 478 | |
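// Allocate a <256 x i32> slot in the entry block and return a pointer to it.
// This is the memory that tile data is spilled to when it is made volatile.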
| 479 | static Value *getAllocaPos(BasicBlock *BB) { |
| 480 | Function *F = BB->getParent(); |
| 481 | IRBuilder<> Builder(&F->getEntryBlock().front()); |
| 482 | const DataLayout &DL = F->getDataLayout(); |
| 483 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
| 484 | Type *V256I32Ty = VectorType::get(ElementType: Builder.getInt32Ty(), NumElements: 256, Scalable: false); |
| 485 | AllocaInst *AllocaRes = |
| 486 | new AllocaInst(V256I32Ty, AllocaAS, "" , F->getEntryBlock().begin()); |
| 487 | BasicBlock::iterator Iter = AllocaRes->getIterator(); |
| 488 | ++Iter; |
| 489 | Builder.SetInsertPoint(&*Iter); |
| 490 | Value *I8Ptr = Builder.CreateBitCast(V: AllocaRes, DestTy: Builder.getPtrTy()); |
| 491 | return I8Ptr; |
| 492 | } |
| 493 | |
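// Emit a tilestored64 that stores \p TileDef to \p Ptr right after the tile
// definition, reusing the shape operands of the defining AMX intrinsic.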
| 494 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { |
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
| 496 | auto *II = cast<IntrinsicInst>(Val: TileDef); |
| 497 | |
  assert(II && "Not tile intrinsic!");
| 499 | Value *Row = II->getOperand(i_nocapture: 0); |
| 500 | Value *Col = II->getOperand(i_nocapture: 1); |
| 501 | |
| 502 | BasicBlock *BB = TileDef->getParent(); |
| 503 | BasicBlock::iterator Iter = TileDef->getIterator(); |
| 504 | IRBuilder<> Builder(BB, ++Iter); |
| 505 | Value *Stride = Builder.getInt64(C: 64); |
| 506 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; |
| 507 | |
| 508 | Instruction *TileStore = |
| 509 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
| 510 | return TileStore; |
| 511 | } |
| 512 | |
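// Replace the tile value used by \p U with a tileloadd64 from \p Ptr emitted
// right before the user. If \p IsPHI, the used value is a PHI node and the
// shape is taken from its first incoming value.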
| 513 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { |
| 514 | Value *V = U.get(); |
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");
| 516 | |
| 517 | // Get tile shape. |
| 518 | IntrinsicInst *II = nullptr; |
| 519 | if (IsPHI) { |
| 520 | Value *PhiOp = cast<PHINode>(Val: V)->getIncomingValue(i: 0); |
| 521 | II = cast<IntrinsicInst>(Val: PhiOp); |
| 522 | } else { |
| 523 | II = cast<IntrinsicInst>(Val: V); |
| 524 | } |
| 525 | Value *Row = II->getOperand(i_nocapture: 0); |
| 526 | Value *Col = II->getOperand(i_nocapture: 1); |
| 527 | |
| 528 | Instruction *UserI = cast<Instruction>(Val: U.getUser()); |
| 529 | IRBuilder<> Builder(UserI); |
| 530 | Value *Stride = Builder.getInt64(C: 64); |
| 531 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; |
| 532 | |
| 533 | Value *TileLoad = |
| 534 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
| 535 | UserI->replaceUsesOfWith(From: V, To: TileLoad); |
| 536 | } |
| 537 | |
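// Return true if \p I is used as an incoming value of any PHI node.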
| 538 | static bool isIncomingOfPHI(Instruction *I) { |
| 539 | for (Use &U : I->uses()) { |
| 540 | User *V = U.getUser(); |
| 541 | if (isa<PHINode>(Val: V)) |
| 542 | return true; |
| 543 | } |
| 544 | return false; |
| 545 | } |
| 546 | |
// Make all AMX tile data volatile, i.e. shorten the live range
// of each tile register before fast register allocation.
| 549 | namespace { |
| 550 | class X86VolatileTileData { |
| 551 | Function &F; |
| 552 | |
| 553 | public: |
| 554 | X86VolatileTileData(Function &Func) : F(Func) {} |
| 555 | Value *updatePhiIncomings(BasicBlock *BB, |
| 556 | SmallVector<Instruction *, 2> &Incomings); |
| 557 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); |
| 558 | bool volatileTileData(); |
| 559 | void volatileTilePHI(PHINode *PHI); |
| 560 | void volatileTileNonPHI(Instruction *I); |
| 561 | }; |
| 562 | |
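// Store each incoming tile of a PHI to a shared stack slot right after its
// definition and rewrite the incomings' non-PHI uses to reload from that
// slot. Returns the slot pointer so the PHI def itself can be replaced by a
// load.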
| 563 | Value *X86VolatileTileData::updatePhiIncomings( |
| 564 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { |
| 565 | Value *I8Ptr = getAllocaPos(BB); |
| 566 | |
| 567 | for (auto *I : Incomings) { |
| 568 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
| 569 | |
| 570 | // All its uses (except phi) should load from stored mem. |
| 571 | for (Use &U : I->uses()) { |
| 572 | User *V = U.getUser(); |
| 573 | if (isa<PHINode>(Val: V) || V == Store) |
| 574 | continue; |
| 575 | replaceWithTileLoad(U, Ptr: I8Ptr); |
| 576 | } |
| 577 | } |
| 578 | return I8Ptr; |
| 579 | } |
| 580 | |
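// Replace every use of the tile PHI with a tile load from \p StorePtr and
// erase the PHI.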
| 581 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, |
| 582 | Value *StorePtr) { |
| 583 | for (Use &U : PHI->uses()) |
| 584 | replaceWithTileLoad(U, Ptr: StorePtr, IsPHI: true); |
| 585 | PHI->eraseFromParent(); |
| 586 | } |
| 587 | |
// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored right after their defs.
// 3) The memory used by these tileloads and tilestores should be the same.
| 593 | // e.g. |
| 594 | // ------------------------------------------------------ |
| 595 | // bb_dom: |
| 596 | // ... |
| 597 | // br i1 %bool.cond, label %if.else, label %if.then |
| 598 | // |
| 599 | // if.then: |
| 600 | // def %t0 = ... |
| 601 | // ... |
| 602 | // use %t0 |
| 603 | // ... |
| 604 | // br label %if.end |
| 605 | // |
| 606 | // if.else: |
| 607 | // def %t1 = ... |
| 608 | // br label %if.end |
| 609 | // |
| 610 | // if.end: |
| 611 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] |
| 612 | // ... |
| 613 | // use %td |
| 614 | // ------------------------------------------------------ |
| 615 | // --> |
| 616 | // ------------------------------------------------------ |
| 617 | // bb_entry: |
| 618 | // %mem = alloca <256 x i32>, align 1024 * |
| 619 | // ... |
| 620 | // bb_dom: |
| 621 | // ... |
| 622 | // br i1 %bool.cond, label %if.else, label %if.then |
| 623 | // |
| 624 | // if.then: |
| 625 | // def %t0 = ... |
| 626 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * |
| 627 | // ... |
| 628 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* |
| 629 | // use %t0` * |
| 630 | // ... |
| 631 | // br label %if.end |
| 632 | // |
| 633 | // if.else: |
| 634 | // def %t1 = ... |
| 635 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * |
| 636 | // br label %if.end |
| 637 | // |
| 638 | // if.end: |
| 639 | // ... |
| 640 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * |
| 641 | // use %td |
| 642 | // ------------------------------------------------------ |
| 643 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { |
| 644 | BasicBlock *BB = PHI->getParent(); |
| 645 | SmallVector<Instruction *, 2> Incomings; |
| 646 | |
| 647 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { |
| 648 | Value *Op = PHI->getIncomingValue(i: I); |
| 649 | Instruction *Inst = dyn_cast<Instruction>(Val: Op); |
    assert(Inst && "We shouldn't fold AMX instruction!");
| 651 | Incomings.push_back(Elt: Inst); |
| 652 | } |
| 653 | |
| 654 | Value *StorePtr = updatePhiIncomings(BB, Incomings); |
| 655 | replacePhiDefWithLoad(PHI, StorePtr); |
| 656 | } |
| 657 | |
// Store the defined tile and load it back before each use.
// None of its users is a PHI node.
| 660 | // e.g. |
| 661 | // ------------------------------------------------------ |
| 662 | // def %td = ... |
| 663 | // ... |
| 664 | // "use %td" |
| 665 | // ------------------------------------------------------ |
| 666 | // --> |
| 667 | // ------------------------------------------------------ |
| 668 | // def %td = ... |
| 669 | // call void @llvm.x86.tilestored64.internal(mem, %td) |
| 670 | // ... |
| 671 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) |
| 672 | // "use %td2" |
| 673 | // ------------------------------------------------------ |
| 674 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { |
| 675 | BasicBlock *BB = I->getParent(); |
| 676 | Value *I8Ptr = getAllocaPos(BB); |
| 677 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
| 678 | |
| 679 | // All its uses should load from stored mem. |
| 680 | for (Use &U : I->uses()) { |
| 681 | User *V = U.getUser(); |
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
| 683 | if (V != Store) |
| 684 | replaceWithTileLoad(U, Ptr: I8Ptr); |
| 685 | } |
| 686 | } |
| 687 | |
// Volatile Tile Model:
// 1) All uses of tile data come from a tileload right before the use.
// 2) All defs of tile data are tilestored to memory immediately after the def.
| 691 | // For example: |
| 692 | // -------------------------------------------------------------------------- |
| 693 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key |
| 694 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) |
| 695 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx |
| 696 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) |
| 697 | // call void @llvm.x86.tilestored64.internal(... td) area |
| 698 | // -------------------------------------------------------------------------- |
// 3) No terminator, call, or other AMX instructions appear in the key AMX area.
| 700 | bool X86VolatileTileData::volatileTileData() { |
| 701 | bool Changed = false; |
| 702 | for (BasicBlock &BB : F) { |
| 703 | SmallVector<Instruction *, 2> PHIInsts; |
| 704 | SmallVector<Instruction *, 8> AMXDefInsts; |
| 705 | |
| 706 | for (Instruction &I : BB) { |
| 707 | if (!I.getType()->isX86_AMXTy()) |
| 708 | continue; |
| 709 | if (isa<PHINode>(Val: &I)) |
| 710 | PHIInsts.push_back(Elt: &I); |
| 711 | else |
| 712 | AMXDefInsts.push_back(Elt: &I); |
| 713 | } |
| 714 | |
| 715 | // First we "volatile" the non-phi related amx intrinsics. |
| 716 | for (Instruction *I : AMXDefInsts) { |
| 717 | if (isIncomingOfPHI(I)) |
| 718 | continue; |
| 719 | volatileTileNonPHI(I); |
| 720 | Changed = true; |
| 721 | } |
| 722 | |
| 723 | for (Instruction *I : PHIInsts) { |
| 724 | volatileTilePHI(PHI: dyn_cast<PHINode>(Val: I)); |
| 725 | Changed = true; |
| 726 | } |
| 727 | } |
| 728 | return Changed; |
| 729 | } |
| 730 | |
| 731 | } // anonymous namespace |
| 732 | |
| 733 | namespace { |
| 734 | |
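// Lower the AMX cast intrinsics (llvm.x86.cast.vector.to.tile and
// llvm.x86.cast.tile.to.vector): combine them with loads, stores and tilezero,
// fold cast pairs through PHI nodes, and spill any remaining casts through a
// stack slot.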
| 735 | class X86LowerAMXCast { |
| 736 | Function &Func; |
| 737 | std::unique_ptr<DominatorTree> DT; |
| 738 | |
| 739 | public: |
| 740 | X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {} |
| 741 | bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); |
| 742 | bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); |
| 743 | bool combineTilezero(IntrinsicInst *Cast); |
| 744 | bool combineLdSt(SmallVectorImpl<Instruction *> &Casts); |
| 745 | bool combineAMXcast(TargetLibraryInfo *TLI); |
| 746 | bool transformAMXCast(IntrinsicInst *AMXCast); |
| 747 | bool transformAllAMXCast(); |
| 748 | bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN, |
| 749 | SmallSetVector<Instruction *, 16> &DeadInst); |
| 750 | }; |
| 751 | |
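// Erase \p I if it is trivially dead and queue any operands that become
// trivially dead as a result in \p WorkList for later removal.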
| 752 | static bool DCEInstruction(Instruction *I, |
| 753 | SmallSetVector<Instruction *, 16> &WorkList, |
| 754 | const TargetLibraryInfo *TLI) { |
| 755 | if (isInstructionTriviallyDead(I, TLI)) { |
| 756 | salvageDebugInfo(I&: *I); |
| 757 | salvageKnowledge(I); |
| 758 | |
| 759 | // Null out all of the instruction's operands to see if any operand becomes |
| 760 | // dead as we go. |
| 761 | for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { |
| 762 | Value *OpV = I->getOperand(i); |
| 763 | I->setOperand(i, Val: nullptr); |
| 764 | |
| 765 | if (!OpV->use_empty() || I == OpV) |
| 766 | continue; |
| 767 | |
| 768 | // If the operand is an instruction that became dead as we nulled out the |
| 769 | // operand, and if it is 'trivially' dead, delete it in a future loop |
| 770 | // iteration. |
| 771 | if (Instruction *OpI = dyn_cast<Instruction>(Val: OpV)) { |
| 772 | if (isInstructionTriviallyDead(I: OpI, TLI)) { |
| 773 | WorkList.insert(X: OpI); |
| 774 | } |
| 775 | } |
| 776 | } |
| 777 | I->eraseFromParent(); |
| 778 | return true; |
| 779 | } |
| 780 | return false; |
| 781 | } |
| 782 | |
| 783 | /// This function handles following case |
| 784 | /// |
| 785 | /// A -> B amxcast |
| 786 | /// PHI |
| 787 | /// B -> A amxcast |
| 788 | /// |
| 789 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. |
| 790 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. |
| 791 | bool X86LowerAMXCast::optimizeAMXCastFromPhi( |
| 792 | IntrinsicInst *CI, PHINode *PN, |
| 793 | SmallSetVector<Instruction *, 16> &DeadInst) { |
| 794 | IRBuilder<> Builder(CI); |
| 795 | Value *Src = CI->getOperand(i_nocapture: 0); |
| 796 | Type *SrcTy = Src->getType(); // Type B |
| 797 | Type *DestTy = CI->getType(); // Type A |
| 798 | |
| 799 | SmallVector<PHINode *, 4> PhiWorklist; |
| 800 | SmallSetVector<PHINode *, 4> OldPhiNodes; |
| 801 | |
| 802 | // Find all of the A->B casts and PHI nodes. |
| 803 | // We need to inspect all related PHI nodes, but PHIs can be cyclic, so |
| 804 | // OldPhiNodes is used to track all known PHI nodes, before adding a new |
| 805 | // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. |
| 806 | PhiWorklist.push_back(Elt: PN); |
| 807 | OldPhiNodes.insert(X: PN); |
| 808 | while (!PhiWorklist.empty()) { |
| 809 | auto *OldPN = PhiWorklist.pop_back_val(); |
| 810 | for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) { |
| 811 | Value *IncValue = OldPN->getIncomingValue(i: I); |
      // TODO: Currently we ignore cases where the incoming value is a
      // constant other than undef/zero. We might support more constants in
      // the future.
| 814 | if (isa<Constant>(Val: IncValue)) { |
| 815 | auto *IncConst = dyn_cast<Constant>(Val: IncValue); |
| 816 | if (!isa<UndefValue>(Val: IncValue) && !IncConst->isZeroValue()) |
| 817 | return false; |
| 818 | Value *Row = nullptr, *Col = nullptr; |
| 819 | std::tie(args&: Row, args&: Col) = getShape(Phi: OldPN); |
        // TODO: If they are not constant, Row and Col must dominate the
        // tilezero that we are going to create.
| 822 | if (!Row || !Col || !isa<Constant>(Val: Row) || !isa<Constant>(Val: Col)) |
| 823 | return false; |
| 824 | // Create tilezero at the end of incoming block. |
| 825 | auto *Block = OldPN->getIncomingBlock(i: I); |
| 826 | BasicBlock::iterator Iter = Block->getTerminator()->getIterator(); |
| 827 | Instruction *NewInst = Builder.CreateIntrinsic( |
| 828 | ID: Intrinsic::x86_tilezero_internal, Types: {}, Args: {Row, Col}); |
| 829 | NewInst->moveBefore(InsertPos: Iter); |
| 830 | NewInst = Builder.CreateIntrinsic(ID: Intrinsic::x86_cast_tile_to_vector, |
| 831 | Types: {IncValue->getType()}, Args: {NewInst}); |
| 832 | NewInst->moveBefore(InsertPos: Iter); |
| 833 | // Replace InValue with new Value. |
| 834 | OldPN->setIncomingValue(i: I, V: NewInst); |
| 835 | IncValue = NewInst; |
| 836 | } |
| 837 | |
| 838 | if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) { |
| 839 | if (OldPhiNodes.insert(X: PNode)) |
| 840 | PhiWorklist.push_back(Elt: PNode); |
| 841 | continue; |
| 842 | } |
| 843 | Instruction *ACI = dyn_cast<Instruction>(Val: IncValue); |
| 844 | if (ACI && isAMXCast(II: ACI)) { |
| 845 | // Verify it's a A->B cast. |
| 846 | Type *TyA = ACI->getOperand(i: 0)->getType(); |
| 847 | Type *TyB = ACI->getType(); |
| 848 | if (TyA != DestTy || TyB != SrcTy) |
| 849 | return false; |
| 850 | continue; |
| 851 | } |
| 852 | return false; |
| 853 | } |
| 854 | } |
| 855 | |
| 856 | // Check that each user of each old PHI node is something that we can |
| 857 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. |
| 858 | for (auto *OldPN : OldPhiNodes) { |
| 859 | for (User *V : OldPN->users()) { |
| 860 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
| 861 | if (ACI && isAMXCast(II: ACI)) { |
| 862 | // Verify it's a B->A cast. |
| 863 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
| 864 | Type *TyA = ACI->getType(); |
| 865 | if (TyA != DestTy || TyB != SrcTy) |
| 866 | return false; |
| 867 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
| 868 | // As long as the user is another old PHI node, then even if we don't |
| 869 | // rewrite it, the PHI web we're considering won't have any users |
| 870 | // outside itself, so it'll be dead. |
| 871 | // example: |
| 872 | // bb.0: |
| 873 | // %0 = amxcast ... |
| 874 | // bb.1: |
| 875 | // %1 = amxcast ... |
| 876 | // bb.2: |
| 877 | // %goodphi = phi %0, %1 |
| 878 | // %3 = amxcast %goodphi |
| 879 | // bb.3: |
| 880 | // %goodphi2 = phi %0, %goodphi |
| 881 | // %4 = amxcast %goodphi2 |
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
| 886 | if (OldPhiNodes.count(key: PHI) == 0) |
| 887 | return false; |
| 888 | } else |
| 889 | return false; |
| 890 | } |
| 891 | } |
| 892 | |
| 893 | // For each old PHI node, create a corresponding new PHI node with a type A. |
| 894 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; |
| 895 | for (auto *OldPN : OldPhiNodes) { |
| 896 | Builder.SetInsertPoint(OldPN); |
| 897 | PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands()); |
| 898 | NewPNodes[OldPN] = NewPN; |
| 899 | } |
| 900 | |
| 901 | // Fill in the operands of new PHI nodes. |
| 902 | for (auto *OldPN : OldPhiNodes) { |
| 903 | PHINode *NewPN = NewPNodes[OldPN]; |
| 904 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { |
| 905 | Value *V = OldPN->getOperand(i_nocapture: j); |
| 906 | Value *NewV = nullptr; |
| 907 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
      // There should not be an AMXcast from a constant.
| 909 | if (ACI && isAMXCast(II: ACI)) |
| 910 | NewV = ACI->getOperand(i: 0); |
| 911 | else if (auto *PrevPN = dyn_cast<PHINode>(Val: V)) |
| 912 | NewV = NewPNodes[PrevPN]; |
| 913 | assert(NewV); |
| 914 | NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j)); |
| 915 | } |
| 916 | } |
| 917 | |
  // Traverse all accumulated PHI nodes and process their users,
  // which are stores and AMX casts. Without this processing,
  // new PHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.
| 923 | |
| 924 | // Replace users of BitCast B->A with NewPHI. These will help |
| 925 | // later to get rid of a closure formed by OldPHI nodes. |
| 926 | for (auto *OldPN : OldPhiNodes) { |
| 927 | PHINode *NewPN = NewPNodes[OldPN]; |
| 928 | for (User *V : make_early_inc_range(Range: OldPN->users())) { |
| 929 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
| 930 | if (ACI && isAMXCast(II: ACI)) { |
| 931 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
| 932 | Type *TyA = ACI->getType(); |
| 933 | assert(TyA == DestTy && TyB == SrcTy); |
| 934 | (void)TyA; |
| 935 | (void)TyB; |
| 936 | ACI->replaceAllUsesWith(V: NewPN); |
| 937 | DeadInst.insert(X: ACI); |
| 938 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
| 939 | // We don't need to push PHINode into DeadInst since they are operands |
| 940 | // of rootPN DCE can safely delete rootPN's operands if rootPN is dead. |
| 941 | assert(OldPhiNodes.contains(PHI)); |
| 942 | (void)PHI; |
| 943 | } else |
| 944 | llvm_unreachable("all uses should be handled" ); |
| 945 | } |
| 946 | } |
| 947 | return true; |
| 948 | } |
| 949 | |
| 950 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42) |
| 951 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
| 952 | // --> |
| 953 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
| 954 | // i64 64, x86_amx %42) |
| 955 | bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) { |
| 956 | Value *Tile = Cast->getOperand(i_nocapture: 0); |
| 957 | |
  assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
| 959 | |
| 960 | // TODO: Specially handle the multi-use case. |
| 961 | if (!Tile->hasOneUse()) |
| 962 | return false; |
| 963 | |
| 964 | auto *II = cast<IntrinsicInst>(Val: Tile); |
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
| 967 | Value *Row = II->getOperand(i_nocapture: 0); |
| 968 | Value *Col = II->getOperand(i_nocapture: 1); |
| 969 | |
| 970 | IRBuilder<> Builder(ST); |
| 971 | |
  // The stride should be equal to Col (measured in bytes).
| 973 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
| 974 | Value *I8Ptr = Builder.CreateBitCast(V: ST->getOperand(i_nocapture: 1), DestTy: Builder.getPtrTy()); |
| 975 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
| 976 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
| 977 | return true; |
| 978 | } |
| 979 | |
| 980 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
| 981 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
| 982 | // --> |
| 983 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
| 984 | // i8* %p, i64 64) |
| 985 | bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { |
| 986 | bool EraseLoad = true; |
| 987 | Value *Row = nullptr, *Col = nullptr; |
| 988 | Use &U = *(Cast->use_begin()); |
| 989 | unsigned OpNo = U.getOperandNo(); |
| 990 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
  // TODO: If it is a cast intrinsic or a phi node, we can propagate the
  // shape information through the def-use chain.
| 993 | if (!isAMXIntrinsic(I: II)) |
| 994 | return false; |
| 995 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
| 996 | IRBuilder<> Builder(LD); |
| 997 | Value *I8Ptr; |
| 998 | |
  // To save compile time, we only create the dominator tree when it is really needed.
| 1000 | if (!DT) |
| 1001 | DT.reset(p: new DominatorTree(Func)); |
| 1002 | if (!DT->dominates(Def: Row, User: LD) || !DT->dominates(Def: Col, User: LD)) { |
    // Store the value to the stack and reload it from the stack before the cast.
| 1004 | auto *AllocaAddr = |
| 1005 | createAllocaInstAtEntry(Builder, BB: Cast->getParent(), Ty: LD->getType()); |
| 1006 | Builder.SetInsertPoint(&*std::next(x: LD->getIterator())); |
| 1007 | Builder.CreateStore(Val: LD, Ptr: AllocaAddr); |
| 1008 | |
| 1009 | Builder.SetInsertPoint(Cast); |
| 1010 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
| 1011 | EraseLoad = false; |
| 1012 | } else { |
| 1013 | I8Ptr = Builder.CreateBitCast(V: LD->getOperand(i_nocapture: 0), DestTy: Builder.getPtrTy()); |
| 1014 | } |
  // The stride should be equal to Col (measured in bytes).
| 1016 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
| 1017 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
| 1018 | |
| 1019 | Value *NewInst = |
| 1020 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
| 1021 | Cast->replaceAllUsesWith(V: NewInst); |
| 1022 | |
| 1023 | return EraseLoad; |
| 1024 | } |
| 1025 | |
| 1026 | // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer) |
| 1027 | // --> |
| 1028 | // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col) |
| 1029 | bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) { |
| 1030 | Value *Row = nullptr, *Col = nullptr; |
| 1031 | Use &U = *(Cast->use_begin()); |
| 1032 | unsigned OpNo = U.getOperandNo(); |
| 1033 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
| 1034 | if (!isAMXIntrinsic(I: II)) |
| 1035 | return false; |
| 1036 | |
| 1037 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
| 1038 | |
| 1039 | IRBuilder<> Builder(Cast); |
| 1040 | Value *NewInst = |
| 1041 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilezero_internal, Types: {}, Args: {Row, Col}); |
| 1042 | Cast->replaceAllUsesWith(V: NewInst); |
| 1043 | return true; |
| 1044 | } |
| 1045 | |
| 1046 | bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) { |
| 1047 | bool Change = false; |
| 1048 | for (auto *Cast : Casts) { |
| 1049 | auto *II = cast<IntrinsicInst>(Val: Cast); |
| 1050 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42) |
| 1051 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
| 1052 | // --> |
| 1053 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
| 1054 | // i64 64, x86_amx %42) |
| 1055 | if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) { |
| 1056 | SmallVector<Instruction *, 2> DeadStores; |
| 1057 | for (User *U : Cast->users()) { |
| 1058 | StoreInst *Store = dyn_cast<StoreInst>(Val: U); |
| 1059 | if (!Store) |
| 1060 | continue; |
| 1061 | if (combineCastStore(Cast: cast<IntrinsicInst>(Val: Cast), ST: Store)) { |
| 1062 | DeadStores.push_back(Elt: Store); |
| 1063 | Change = true; |
| 1064 | } |
| 1065 | } |
| 1066 | for (auto *Store : DeadStores) |
| 1067 | Store->eraseFromParent(); |
| 1068 | } else { // x86_cast_vector_to_tile |
| 1069 | // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer) |
| 1070 | // --> |
| 1071 | // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col) |
| 1072 | if (isa<ConstantAggregateZero>(Val: Cast->getOperand(i: 0))) { |
| 1073 | Change |= combineTilezero(Cast: cast<IntrinsicInst>(Val: Cast)); |
| 1074 | continue; |
| 1075 | } |
| 1076 | |
| 1077 | auto *Load = dyn_cast<LoadInst>(Val: Cast->getOperand(i: 0)); |
| 1078 | if (!Load || !Load->hasOneUse()) |
| 1079 | continue; |
| 1080 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
| 1081 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
| 1082 | // --> |
| 1083 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
| 1084 | // i8* %p, i64 64) |
| 1085 | if (combineLoadCast(Cast: cast<IntrinsicInst>(Val: Cast), LD: Load)) { |
        // Set the operand to null so that the load instruction can be erased.
| 1087 | Cast->setOperand(i: 0, Val: nullptr); |
| 1088 | Load->eraseFromParent(); |
| 1089 | Change = true; |
| 1090 | } |
| 1091 | } |
| 1092 | } |
| 1093 | return Change; |
| 1094 | } |
| 1095 | |
| 1096 | bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { |
| 1097 | bool Change = false; |
  // Collect tile cast instructions.
| 1099 | SmallVector<Instruction *, 8> Vec2TileInsts; |
| 1100 | SmallVector<Instruction *, 8> Tile2VecInsts; |
| 1101 | SmallVector<Instruction *, 8> PhiCastWorkList; |
| 1102 | SmallSetVector<Instruction *, 16> DeadInst; |
| 1103 | for (BasicBlock &BB : Func) { |
| 1104 | for (Instruction &I : BB) { |
| 1105 | Value *Vec; |
| 1106 | if (match(V: &I, |
| 1107 | P: m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(Op0: m_Value(V&: Vec)))) |
| 1108 | Vec2TileInsts.push_back(Elt: &I); |
| 1109 | else if (match(V: &I, P: m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>( |
| 1110 | Op0: m_Value(V&: Vec)))) |
| 1111 | Tile2VecInsts.push_back(Elt: &I); |
| 1112 | } |
| 1113 | } |
| 1114 | |
| 1115 | auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) { |
| 1116 | for (auto *Inst : Insts) { |
| 1117 | for (User *U : Inst->users()) { |
| 1118 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U); |
| 1119 | if (!II || II->getIntrinsicID() != IID) |
| 1120 | continue; |
| 1121 | // T1 = vec2tile V0 |
| 1122 | // V2 = tile2vec T1 |
| 1123 | // V3 = OP V2 |
| 1124 | // --> |
| 1125 | // T1 = vec2tile V0 |
| 1126 | // V2 = tile2vec T1 |
| 1127 | // V3 = OP V0 |
| 1128 | II->replaceAllUsesWith(V: Inst->getOperand(i: 0)); |
| 1129 | Change = true; |
| 1130 | } |
| 1131 | } |
| 1132 | }; |
| 1133 | |
| 1134 | Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector); |
| 1135 | Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile); |
| 1136 | |
| 1137 | SmallVector<Instruction *, 8> LiveCasts; |
| 1138 | auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) { |
| 1139 | for (auto *Inst : Insts) { |
| 1140 | if (Inst->use_empty()) { |
| 1141 | Inst->eraseFromParent(); |
| 1142 | Change = true; |
| 1143 | } else { |
| 1144 | LiveCasts.push_back(Elt: Inst); |
| 1145 | } |
| 1146 | } |
| 1147 | }; |
| 1148 | |
| 1149 | EraseInst(Vec2TileInsts); |
| 1150 | EraseInst(Tile2VecInsts); |
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
| 1154 | Change |= combineLdSt(Casts&: LiveCasts); |
| 1155 | EraseInst(LiveCasts); |
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());
| 1159 | |
| 1160 | // Handle the A->B->A cast, and there is an intervening PHI node. |
| 1161 | for (BasicBlock &BB : Func) { |
| 1162 | for (Instruction &I : BB) { |
| 1163 | if (isAMXCast(II: &I)) { |
| 1164 | if (isa<PHINode>(Val: I.getOperand(i: 0))) |
| 1165 | PhiCastWorkList.push_back(Elt: &I); |
| 1166 | } |
| 1167 | } |
| 1168 | } |
| 1169 | for (auto *I : PhiCastWorkList) { |
    // Skip dead AMX casts.
| 1171 | if (DeadInst.contains(key: I)) |
| 1172 | continue; |
| 1173 | PHINode *PN = cast<PHINode>(Val: I->getOperand(i: 0)); |
| 1174 | if (optimizeAMXCastFromPhi(CI: cast<IntrinsicInst>(Val: I), PN, DeadInst)) { |
| 1175 | DeadInst.insert(X: PN); |
| 1176 | Change = true; |
| 1177 | } |
| 1178 | } |
| 1179 | |
  // Since we create new PHIs and merge AMX casts, some old PHIs and AMX casts
  // might have no uses. Run dead code elimination on them.
| 1182 | while (!DeadInst.empty()) { |
| 1183 | Instruction *I = DeadInst.pop_back_val(); |
| 1184 | Change |= DCEInstruction(I, WorkList&: DeadInst, TLI); |
| 1185 | } |
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
| 1189 | return Change; |
| 1190 | } |
| 1191 | |
// There might be remaining AMX casts after combineAMXcast, and they should be
// handled elegantly.
| 1194 | bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { |
| 1195 | IRBuilder<> Builder(AMXCast); |
| 1196 | AllocaInst *AllocaAddr; |
| 1197 | Value *I8Ptr, *Stride; |
| 1198 | auto *Src = AMXCast->getOperand(i_nocapture: 0); |
| 1199 | |
| 1200 | auto Prepare = [&](Type *MemTy) { |
| 1201 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: AMXCast->getParent(), Ty: MemTy); |
| 1202 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
| 1203 | Stride = Builder.getInt64(C: 64); |
| 1204 | }; |
| 1205 | |
| 1206 | if (AMXCast->getType()->isX86_AMXTy()) { |
| 1207 | // %2 = amxcast <225 x i32> %src to x86_amx |
| 1208 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
| 1209 | // i8* %addr3, i64 60, x86_amx %2) |
| 1210 | // --> |
| 1211 | // %addr = alloca <225 x i32>, align 64 |
| 1212 | // store <225 x i32> %src, <225 x i32>* %addr, align 64 |
| 1213 | // %addr2 = bitcast <225 x i32>* %addr to i8* |
| 1214 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60, |
| 1215 | // i8* %addr2, |
| 1216 | // i64 60) |
| 1217 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
| 1218 | // i8* %addr3, i64 60, x86_amx %2) |
| 1219 | if (AMXCast->use_empty()) { |
| 1220 | AMXCast->eraseFromParent(); |
| 1221 | return true; |
| 1222 | } |
| 1223 | Use &U = *(AMXCast->use_begin()); |
| 1224 | unsigned OpNo = U.getOperandNo(); |
| 1225 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
| 1226 | if (!II) |
| 1227 | return false; // May be bitcast from x86amx to <256 x i32>. |
| 1228 | Prepare(AMXCast->getOperand(i_nocapture: 0)->getType()); |
| 1229 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
    // TODO: we can pick a constant operand for the shape.
| 1231 | Value *Row = nullptr, *Col = nullptr; |
| 1232 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
| 1233 | std::array<Value *, 4> Args = { |
| 1234 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty())}; |
| 1235 | Value *NewInst = |
| 1236 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tileloadd64_internal, Args); |
| 1237 | AMXCast->replaceAllUsesWith(V: NewInst); |
| 1238 | AMXCast->eraseFromParent(); |
| 1239 | } else { |
| 1240 | // %2 = amxcast x86_amx %src to <225 x i32> |
| 1241 | // --> |
| 1242 | // %addr = alloca <225 x i32>, align 64 |
| 1243 | // %addr2 = bitcast <225 x i32>* to i8* |
| 1244 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
| 1245 | // i8* %addr2, i64 %stride) |
| 1246 | // %2 = load <225 x i32>, <225 x i32>* %addr, align 64 |
| 1247 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
| 1248 | if (!II) |
| 1249 | return false; // May be bitcast from <256 x i32> to x86amx. |
| 1250 | Prepare(AMXCast->getType()); |
| 1251 | Value *Row = II->getOperand(i_nocapture: 0); |
| 1252 | Value *Col = II->getOperand(i_nocapture: 1); |
| 1253 | std::array<Value *, 5> Args = { |
| 1254 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()), Src}; |
| 1255 | Builder.CreateIntrinsic(ID: Intrinsic::x86_tilestored64_internal, Args); |
| 1256 | Value *NewInst = Builder.CreateLoad(Ty: AMXCast->getType(), Ptr: AllocaAddr); |
| 1257 | AMXCast->replaceAllUsesWith(V: NewInst); |
| 1258 | AMXCast->eraseFromParent(); |
| 1259 | } |
| 1260 | |
| 1261 | return true; |
| 1262 | } |
| 1263 | |
| 1264 | bool X86LowerAMXCast::transformAllAMXCast() { |
| 1265 | bool Change = false; |
  // Collect tile cast instructions.
| 1267 | SmallVector<Instruction *, 8> WorkLists; |
| 1268 | for (BasicBlock &BB : Func) { |
| 1269 | for (Instruction &I : BB) { |
| 1270 | if (isAMXCast(II: &I)) |
| 1271 | WorkLists.push_back(Elt: &I); |
| 1272 | } |
| 1273 | } |
| 1274 | |
| 1275 | for (auto *Inst : WorkLists) { |
| 1276 | Change |= transformAMXCast(AMXCast: cast<IntrinsicInst>(Val: Inst)); |
| 1277 | } |
| 1278 | |
| 1279 | return Change; |
| 1280 | } |
| 1281 | |
| 1282 | bool lowerAmxType(Function &F, const TargetMachine *TM, |
| 1283 | TargetLibraryInfo *TLI) { |
| 1284 | // Performance optimization: most code doesn't use AMX, so return early if |
| 1285 | // there are no instructions that produce AMX values. This is sufficient, as |
| 1286 | // AMX arguments and constants are not allowed -- so any producer of an AMX |
| 1287 | // value must be an instruction. |
| 1288 | // TODO: find a cheaper way for this, without looking at all instructions. |
| 1289 | if (!containsAMXCode(F)) |
| 1290 | return false; |
| 1291 | |
| 1292 | bool C = false; |
| 1293 | X86LowerAMXCast LAC(F); |
| 1294 | C |= LAC.combineAMXcast(TLI); |
  // There might be remaining AMX casts after combineAMXcast, and they should
  // be handled elegantly.
| 1297 | C |= LAC.transformAllAMXCast(); |
| 1298 | |
| 1299 | X86LowerAMXType LAT(F); |
| 1300 | C |= LAT.visit(); |
| 1301 | |
  // Prepare for fast register allocation at O0.
  // TODO: It may be better to check the volatile model of the AMX code, not
  // just Attribute::OptimizeNone and CodeGenOptLevel::None.
| 1305 | if (TM->getOptLevel() == CodeGenOptLevel::None) { |
    // If the front end does not use O0 but the mid/back end does (e.g.
    // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
    // the AMX data is volatile; that is necessary for AMX fast register
    // allocation.
| 1310 | if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone)) { |
| 1311 | X86VolatileTileData VTD(F); |
| 1312 | C = VTD.volatileTileData() || C; |
| 1313 | } |
| 1314 | } |
| 1315 | |
| 1316 | return C; |
| 1317 | } |
| 1318 | |
| 1319 | } // anonymous namespace |
| 1320 | |
| 1321 | PreservedAnalyses X86LowerAMXTypePass::run(Function &F, |
| 1322 | FunctionAnalysisManager &FAM) { |
| 1323 | TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(IR&: F); |
| 1324 | bool Changed = lowerAmxType(F, TM, TLI: &TLI); |
| 1325 | if (!Changed) |
| 1326 | return PreservedAnalyses::all(); |
| 1327 | |
| 1328 | PreservedAnalyses PA = PreservedAnalyses::none(); |
| 1329 | PA.preserveSet<CFGAnalyses>(); |
| 1330 | return PA; |
| 1331 | } |
| 1332 | |
| 1333 | namespace { |
| 1334 | |
| 1335 | class X86LowerAMXTypeLegacyPass : public FunctionPass { |
| 1336 | public: |
| 1337 | static char ID; |
| 1338 | |
| 1339 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {} |
| 1340 | |
| 1341 | bool runOnFunction(Function &F) override { |
| 1342 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
| 1343 | TargetLibraryInfo *TLI = |
| 1344 | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
| 1345 | return lowerAmxType(F, TM, TLI); |
| 1346 | } |
| 1347 | |
| 1348 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 1349 | AU.setPreservesCFG(); |
| 1350 | AU.addRequired<TargetPassConfig>(); |
| 1351 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
| 1352 | } |
| 1353 | }; |
| 1354 | |
| 1355 | } // anonymous namespace |
| 1356 | |
static const char PassName[] = "Lower AMX type for load/store";
| 1358 | char X86LowerAMXTypeLegacyPass::ID = 0; |
| 1359 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, |
| 1360 | false) |
| 1361 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) |
| 1362 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) |
| 1363 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, |
| 1364 | false) |
| 1365 | |
| 1366 | FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() { |
| 1367 | return new X86LowerAMXTypeLegacyPass(); |
| 1368 | } |
| 1369 | |