| 1 | //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===// | 
|---|
| 2 | // | 
|---|
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|---|
| 4 | // See https://llvm.org/LICENSE.txt for license information. | 
|---|
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|---|
| 6 | // | 
|---|
| 7 | //===----------------------------------------------------------------------===// | 
|---|
| 8 | // | 
|---|
| 9 | /// \file Pass to transform amx intrinsics to scalar operations. | 
|---|
| 10 | /// This pass is always enabled and it skips when it is not -O0 and has no | 
|---|
| 11 | /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx | 
|---|
| 12 | /// intrinsics is near the amx intrinsics code. We are not able to find a | 
|---|
| 13 | /// point which post-dominate all the shape and dominate all amx intrinsics. | 
|---|
| 14 | /// To decouple the dependency of the shape, we transform amx intrinsics | 
|---|
| 15 | /// to scalar operation, so that compiling doesn't fail. In long term, we | 
|---|
| 16 | /// should improve fast register allocation to allocate amx register. | 
|---|
| 17 | //===----------------------------------------------------------------------===// | 
|---|
| 18 | // | 
|---|
| 19 | #include "X86.h" | 
|---|
| 20 | #include "llvm/Analysis/DomTreeUpdater.h" | 
|---|
| 21 | #include "llvm/Analysis/LoopInfo.h" | 
|---|
| 22 | #include "llvm/Analysis/TargetTransformInfo.h" | 
|---|
| 23 | #include "llvm/CodeGen/Passes.h" | 
|---|
| 24 | #include "llvm/CodeGen/TargetPassConfig.h" | 
|---|
| 25 | #include "llvm/CodeGen/ValueTypes.h" | 
|---|
| 26 | #include "llvm/IR/DataLayout.h" | 
|---|
| 27 | #include "llvm/IR/Function.h" | 
|---|
| 28 | #include "llvm/IR/IRBuilder.h" | 
|---|
| 29 | #include "llvm/IR/Instructions.h" | 
|---|
| 30 | #include "llvm/IR/IntrinsicInst.h" | 
|---|
| 31 | #include "llvm/IR/IntrinsicsX86.h" | 
|---|
| 32 | #include "llvm/IR/PatternMatch.h" | 
|---|
| 33 | #include "llvm/InitializePasses.h" | 
|---|
| 34 | #include "llvm/Pass.h" | 
|---|
| 35 | #include "llvm/Support/CommandLine.h" | 
|---|
| 36 | #include "llvm/Target/TargetMachine.h" | 
|---|
| 37 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" | 
|---|
| 38 | #include "llvm/Transforms/Utils/LoopUtils.h" | 
|---|
| 39 |  | 
|---|
| 40 | using namespace llvm; | 
|---|
| 41 | using namespace PatternMatch; | 
|---|
| 42 |  | 
|---|
| 43 | #define DEBUG_TYPE "lower-amx-intrinsics" | 
|---|
| 44 |  | 
|---|
| 45 | #ifndef NDEBUG | 
|---|
| 46 | static bool isV256I32Ty(Type *Ty) { | 
|---|
| 47 | if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) | 
|---|
| 48 | return FVT->getNumElements() == 256 && | 
|---|
| 49 | FVT->getElementType()->isIntegerTy(32); | 
|---|
| 50 | return false; | 
|---|
| 51 | } | 
|---|
| 52 | #endif | 
|---|
| 53 |  | 
|---|
| 54 | static cl::opt<bool> | 
|---|
| 55 | X86ScalarizeAMX( "enable-x86-scalar-amx", cl::init(Val: false), cl::Hidden, | 
|---|
| 56 | cl::desc( "X86: enable AMX scalarizition.")); | 
|---|
| 57 |  | 
|---|
| 58 | namespace { | 
|---|
| 59 | class X86LowerAMXIntrinsics { | 
|---|
| 60 | Function &Func; | 
|---|
| 61 |  | 
|---|
| 62 | public: | 
|---|
| 63 | X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) | 
|---|
| 64 | : Func(F), DTU(DomTU), LI(LoopI) {} | 
|---|
| 65 | bool visit(); | 
|---|
| 66 |  | 
|---|
| 67 | private: | 
|---|
| 68 | DomTreeUpdater &DTU; | 
|---|
| 69 | LoopInfo *LI; | 
|---|
| 70 | BasicBlock *createLoop(BasicBlock *, BasicBlock *Exit, Value *Bound, | 
|---|
| 71 | Value *Step, StringRef Name, IRBuilderBase &B, | 
|---|
| 72 | Loop *L); | 
|---|
| 73 | template <bool IsTileLoad> | 
|---|
| 74 | Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, | 
|---|
| 75 | IRBuilderBase &B, Value *Row, Value *Col, | 
|---|
| 76 | Value *Ptr, Value *Stride, Value *Tile); | 
|---|
| 77 | template <Intrinsic::ID IntrID> | 
|---|
| 78 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || | 
|---|
| 79 | IntrID == Intrinsic::x86_tdpbsud_internal || | 
|---|
| 80 | IntrID == Intrinsic::x86_tdpbusd_internal || | 
|---|
| 81 | IntrID == Intrinsic::x86_tdpbuud_internal || | 
|---|
| 82 | IntrID == Intrinsic::x86_tdpbf16ps_internal, | 
|---|
| 83 | Value *> | 
|---|
| 84 | createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, | 
|---|
| 85 | Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, | 
|---|
| 86 | Value *RHS); | 
|---|
| 87 | template <bool IsTileLoad> | 
|---|
| 88 | bool lowerTileLoadStore(Instruction *TileLoadStore); | 
|---|
| 89 | template <Intrinsic::ID IntrID> | 
|---|
| 90 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || | 
|---|
| 91 | IntrID == Intrinsic::x86_tdpbsud_internal || | 
|---|
| 92 | IntrID == Intrinsic::x86_tdpbusd_internal || | 
|---|
| 93 | IntrID == Intrinsic::x86_tdpbuud_internal || | 
|---|
| 94 | IntrID == Intrinsic::x86_tdpbf16ps_internal, | 
|---|
| 95 | bool> | 
|---|
| 96 | lowerTileDP(Instruction *TileDP); | 
|---|
| 97 | bool lowerTileZero(Instruction *TileZero); | 
|---|
| 98 | }; | 
|---|
| 99 | } // anonymous namespace | 
|---|
| 100 |  | 
|---|
| 101 | BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *, | 
|---|
| 102 | BasicBlock *Exit, Value *Bound, | 
|---|
| 103 | Value *Step, StringRef Name, | 
|---|
| 104 | IRBuilderBase &B, Loop *L) { | 
|---|
| 105 | LLVMContext &Ctx = Preheader->getContext(); | 
|---|
| 106 | BasicBlock * = | 
|---|
| 107 | BasicBlock::Create(Context&: Ctx, Name: Name + ".header", Parent: Preheader->getParent(), InsertBefore: Exit); | 
|---|
| 108 | BasicBlock *Body = | 
|---|
| 109 | BasicBlock::Create(Context&: Ctx, Name: Name + ".body", Parent: Header->getParent(), InsertBefore: Exit); | 
|---|
| 110 | BasicBlock *Latch = | 
|---|
| 111 | BasicBlock::Create(Context&: Ctx, Name: Name + ".latch", Parent: Header->getParent(), InsertBefore: Exit); | 
|---|
| 112 |  | 
|---|
| 113 | Type *I16Ty = Type::getInt16Ty(C&: Ctx); | 
|---|
| 114 | BranchInst::Create(IfTrue: Body, InsertBefore: Header); | 
|---|
| 115 | BranchInst::Create(IfTrue: Latch, InsertBefore: Body); | 
|---|
| 116 | PHINode *IV = | 
|---|
| 117 | PHINode::Create(Ty: I16Ty, NumReservedValues: 2, NameStr: Name + ".iv", InsertBefore: Header->getTerminator()->getIterator()); | 
|---|
| 118 | IV->addIncoming(V: ConstantInt::get(Ty: I16Ty, V: 0), BB: Preheader); | 
|---|
| 119 |  | 
|---|
| 120 | B.SetInsertPoint(Latch); | 
|---|
| 121 | Value *Inc = B.CreateAdd(LHS: IV, RHS: Step, Name: Name + ".step"); | 
|---|
| 122 | Value *Cond = B.CreateICmpNE(LHS: Inc, RHS: Bound, Name: Name + ".cond"); | 
|---|
| 123 | BranchInst::Create(IfTrue: Header, IfFalse: Exit, Cond, InsertBefore: Latch); | 
|---|
| 124 | IV->addIncoming(V: Inc, BB: Latch); | 
|---|
| 125 |  | 
|---|
| 126 | BranchInst * = cast<BranchInst>(Val: Preheader->getTerminator()); | 
|---|
| 127 | BasicBlock *Tmp = PreheaderBr->getSuccessor(i: 0); | 
|---|
| 128 | PreheaderBr->setSuccessor(idx: 0, NewSucc: Header); | 
|---|
| 129 | DTU.applyUpdatesPermissive(Updates: { | 
|---|
| 130 | {DominatorTree::Delete, Preheader, Tmp}, | 
|---|
| 131 | {DominatorTree::Insert, Header, Body}, | 
|---|
| 132 | {DominatorTree::Insert, Body, Latch}, | 
|---|
| 133 | {DominatorTree::Insert, Latch, Header}, | 
|---|
| 134 | {DominatorTree::Insert, Latch, Exit}, | 
|---|
| 135 | {DominatorTree::Insert, Preheader, Header}, | 
|---|
| 136 | }); | 
|---|
| 137 | if (LI) { | 
|---|
| 138 | L->addBasicBlockToLoop(NewBB: Header, LI&: *LI); | 
|---|
| 139 | L->addBasicBlockToLoop(NewBB: Body, LI&: *LI); | 
|---|
| 140 | L->addBasicBlockToLoop(NewBB: Latch, LI&: *LI); | 
|---|
| 141 | } | 
|---|
| 142 | return Body; | 
|---|
| 143 | } | 
|---|
| 144 |  | 
|---|
| 145 | template <bool IsTileLoad> | 
|---|
| 146 | Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( | 
|---|
| 147 | BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, | 
|---|
| 148 | Value *Col, Value *Ptr, Value *Stride, Value *Tile) { | 
|---|
| 149 | std::string IntrinName = IsTileLoad ? "tileload": "tilestore"; | 
|---|
| 150 | Loop *RowLoop = nullptr; | 
|---|
| 151 | Loop *ColLoop = nullptr; | 
|---|
| 152 | if (LI) { | 
|---|
| 153 | RowLoop = LI->AllocateLoop(); | 
|---|
| 154 | ColLoop = LI->AllocateLoop(); | 
|---|
| 155 | RowLoop->addChildLoop(NewChild: ColLoop); | 
|---|
| 156 | if (Loop *ParentL = LI->getLoopFor(BB: Start)) | 
|---|
| 157 | ParentL->addChildLoop(NewChild: RowLoop); | 
|---|
| 158 | else | 
|---|
| 159 | LI->addTopLevelLoop(New: RowLoop); | 
|---|
| 160 | } | 
|---|
| 161 |  | 
|---|
| 162 | BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1), | 
|---|
| 163 | Name: IntrinName + ".scalarize.rows", B, L: RowLoop); | 
|---|
| 164 | BasicBlock *RowLatch = RowBody->getSingleSuccessor(); | 
|---|
| 165 |  | 
|---|
| 166 | BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1), | 
|---|
| 167 | Name: IntrinName + ".scalarize.cols", B, L: ColLoop); | 
|---|
| 168 |  | 
|---|
| 169 | BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); | 
|---|
| 170 | BasicBlock * = ColBody->getSinglePredecessor(); | 
|---|
| 171 | BasicBlock * = RowBody->getSinglePredecessor(); | 
|---|
| 172 | Value *CurrentRow = &*RowLoopHeader->begin(); | 
|---|
| 173 | Value *CurrentCol = &*ColLoopHeader->begin(); | 
|---|
| 174 | Type *EltTy = B.getInt32Ty(); | 
|---|
| 175 | FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: EltTy, NumElts: 256); | 
|---|
| 176 |  | 
|---|
| 177 | // Common part for tileload and tilestore | 
|---|
| 178 | // *.scalarize.cols.body: | 
|---|
| 179 | // Calculate %idxmem and %idxvec | 
|---|
| 180 | B.SetInsertPoint(ColBody->getTerminator()); | 
|---|
| 181 | Value *CurrentRowZExt = B.CreateZExt(V: CurrentRow, DestTy: Stride->getType()); | 
|---|
| 182 | Value *CurrentColZExt = B.CreateZExt(V: CurrentCol, DestTy: Stride->getType()); | 
|---|
| 183 | Value *Offset = | 
|---|
| 184 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRowZExt, RHS: Stride), RHS: CurrentColZExt); | 
|---|
| 185 | Value *EltPtr = B.CreateGEP(Ty: EltTy, Ptr, IdxList: Offset); | 
|---|
| 186 | Value *Idx = B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol); | 
|---|
| 187 | if (IsTileLoad) { | 
|---|
| 188 | // tileload.scalarize.rows.header: | 
|---|
| 189 | // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, | 
|---|
| 190 | // %tileload.scalarize.rows.latch ] | 
|---|
| 191 | B.SetInsertPoint(RowLoopHeader->getTerminator()); | 
|---|
| 192 | Value *VecZero = Constant::getNullValue(Ty: V256I32Ty); | 
|---|
| 193 | PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi.row"); | 
|---|
| 194 | VecCPhiRowLoop->addIncoming(V: VecZero, BB: Start); | 
|---|
| 195 |  | 
|---|
| 196 | // tileload.scalarize.cols.header: | 
|---|
| 197 | // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body | 
|---|
| 198 | // ], [ %ResVec, %tileload.scalarize.cols.latch ] | 
|---|
| 199 | B.SetInsertPoint(ColLoopHeader->getTerminator()); | 
|---|
| 200 | PHINode *VecPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi"); | 
|---|
| 201 | VecPhi->addIncoming(V: VecCPhiRowLoop, BB: RowBody); | 
|---|
| 202 |  | 
|---|
| 203 | // tileload.scalarize.cols.body: | 
|---|
| 204 | // Calculate %idxmem and %idxvec | 
|---|
| 205 | // %eltptr = getelementptr i32, i32* %base, i64 %idxmem | 
|---|
| 206 | // %elt = load i32, i32* %ptr | 
|---|
| 207 | // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec | 
|---|
| 208 | B.SetInsertPoint(ColBody->getTerminator()); | 
|---|
| 209 | Value *Elt = B.CreateLoad(Ty: EltTy, Ptr: EltPtr); | 
|---|
| 210 | Value *ResVec = B.CreateInsertElement(Vec: VecPhi, NewElt: Elt, Idx); | 
|---|
| 211 | VecPhi->addIncoming(V: ResVec, BB: ColLoopLatch); | 
|---|
| 212 | VecCPhiRowLoop->addIncoming(V: ResVec, BB: RowLatch); | 
|---|
| 213 |  | 
|---|
| 214 | return ResVec; | 
|---|
| 215 | } else { | 
|---|
| 216 | auto *BitCast = cast<BitCastInst>(Val: Tile); | 
|---|
| 217 | Value *Vec = BitCast->getOperand(i_nocapture: 0); | 
|---|
| 218 | assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); | 
|---|
| 219 | // tilestore.scalarize.cols.body: | 
|---|
| 220 | // %mul = mul i16 %row.iv, i16 16 | 
|---|
| 221 | // %idx = add i16 %mul, i16 %col.iv | 
|---|
| 222 | // %vec = extractelement <16 x i32> %vec, i16 %idx | 
|---|
| 223 | // store i32 %vec, i32* %ptr | 
|---|
| 224 | B.SetInsertPoint(ColBody->getTerminator()); | 
|---|
| 225 | Value *Elt = B.CreateExtractElement(Vec, Idx); | 
|---|
| 226 |  | 
|---|
| 227 | B.CreateStore(Val: Elt, Ptr: EltPtr); | 
|---|
| 228 | return nullptr; | 
|---|
| 229 | } | 
|---|
| 230 | } | 
|---|
| 231 |  | 
|---|
| 232 | template <Intrinsic::ID IntrID> | 
|---|
| 233 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || | 
|---|
| 234 | IntrID == Intrinsic::x86_tdpbsud_internal || | 
|---|
| 235 | IntrID == Intrinsic::x86_tdpbusd_internal || | 
|---|
| 236 | IntrID == Intrinsic::x86_tdpbuud_internal || | 
|---|
| 237 | IntrID == Intrinsic::x86_tdpbf16ps_internal, | 
|---|
| 238 | Value *> | 
|---|
| 239 | X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, | 
|---|
| 240 | IRBuilderBase &B, Value *Row, | 
|---|
| 241 | Value *Col, Value *K, Value *Acc, | 
|---|
| 242 | Value *LHS, Value *RHS) { | 
|---|
| 243 | std::string IntrinName; | 
|---|
| 244 | switch (IntrID) { | 
|---|
| 245 | case Intrinsic::x86_tdpbssd_internal: | 
|---|
| 246 | IntrinName = "tiledpbssd"; | 
|---|
| 247 | break; | 
|---|
| 248 | case Intrinsic::x86_tdpbsud_internal: | 
|---|
| 249 | IntrinName = "tiledpbsud"; | 
|---|
| 250 | break; | 
|---|
| 251 | case Intrinsic::x86_tdpbusd_internal: | 
|---|
| 252 | IntrinName = "tiledpbusd"; | 
|---|
| 253 | break; | 
|---|
| 254 | case Intrinsic::x86_tdpbuud_internal: | 
|---|
| 255 | IntrinName = "tiledpbuud"; | 
|---|
| 256 | break; | 
|---|
| 257 | case Intrinsic::x86_tdpbf16ps_internal: | 
|---|
| 258 | IntrinName = "tiledpbf16ps"; | 
|---|
| 259 | break; | 
|---|
| 260 | } | 
|---|
| 261 | Loop *RowLoop = nullptr; | 
|---|
| 262 | Loop *ColLoop = nullptr; | 
|---|
| 263 | Loop *InnerLoop = nullptr; | 
|---|
| 264 | if (LI) { | 
|---|
| 265 | RowLoop = LI->AllocateLoop(); | 
|---|
| 266 | ColLoop = LI->AllocateLoop(); | 
|---|
| 267 | InnerLoop = LI->AllocateLoop(); | 
|---|
| 268 | ColLoop->addChildLoop(NewChild: InnerLoop); | 
|---|
| 269 | RowLoop->addChildLoop(NewChild: ColLoop); | 
|---|
| 270 | if (Loop *ParentL = LI->getLoopFor(BB: Start)) | 
|---|
| 271 | ParentL->addChildLoop(NewChild: RowLoop); | 
|---|
| 272 | else | 
|---|
| 273 | LI->addTopLevelLoop(New: RowLoop); | 
|---|
| 274 | } | 
|---|
| 275 |  | 
|---|
| 276 | BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1), | 
|---|
| 277 | Name: IntrinName + ".scalarize.rows", B, L: RowLoop); | 
|---|
| 278 | BasicBlock *RowLatch = RowBody->getSingleSuccessor(); | 
|---|
| 279 |  | 
|---|
| 280 | BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1), | 
|---|
| 281 | Name: IntrinName + ".scalarize.cols", B, L: ColLoop); | 
|---|
| 282 |  | 
|---|
| 283 | BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); | 
|---|
| 284 |  | 
|---|
| 285 | B.SetInsertPoint(ColBody->getTerminator()); | 
|---|
| 286 | BasicBlock *InnerBody = | 
|---|
| 287 | createLoop(Preheader: ColBody, Exit: ColLoopLatch, Bound: K, Step: B.getInt16(C: 1), | 
|---|
| 288 | Name: IntrinName + ".scalarize.inner", B, L: InnerLoop); | 
|---|
| 289 |  | 
|---|
| 290 | BasicBlock * = ColBody->getSinglePredecessor(); | 
|---|
| 291 | BasicBlock * = RowBody->getSinglePredecessor(); | 
|---|
| 292 | BasicBlock * = InnerBody->getSinglePredecessor(); | 
|---|
| 293 | BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); | 
|---|
| 294 | Value *CurrentRow = &*RowLoopHeader->begin(); | 
|---|
| 295 | Value *CurrentCol = &*ColLoopHeader->begin(); | 
|---|
| 296 | Value *CurrentInner = &*InnerLoopHeader->begin(); | 
|---|
| 297 |  | 
|---|
| 298 | FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 256); | 
|---|
| 299 | auto *BitCastAcc = cast<BitCastInst>(Val: Acc); | 
|---|
| 300 | Value *VecC = BitCastAcc->getOperand(i_nocapture: 0); | 
|---|
| 301 | assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx"); | 
|---|
| 302 | // TODO else create BitCast from x86amx to v256i32. | 
|---|
| 303 | // Store x86amx to memory, and reload from memory | 
|---|
| 304 | // to vector. However with -O0, it doesn't happen. | 
|---|
| 305 | auto *BitCastLHS = cast<BitCastInst>(Val: LHS); | 
|---|
| 306 | Value *VecA = BitCastLHS->getOperand(i_nocapture: 0); | 
|---|
| 307 | assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); | 
|---|
| 308 | auto *BitCastRHS = cast<BitCastInst>(Val: RHS); | 
|---|
| 309 | Value *VecB = BitCastRHS->getOperand(i_nocapture: 0); | 
|---|
| 310 | assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); | 
|---|
| 311 |  | 
|---|
| 312 | // tiledpbssd.scalarize.rows.header: | 
|---|
| 313 | // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, | 
|---|
| 314 | // %tiledpbssd.scalarize.rows.latch ] | 
|---|
| 315 |  | 
|---|
| 316 | // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ | 
|---|
| 317 | // %NewVecD, %tiledpbssd.scalarize.rows.latch ] | 
|---|
| 318 | B.SetInsertPoint(RowLoopHeader->getTerminator()); | 
|---|
| 319 | PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.row"); | 
|---|
| 320 | VecCPhiRowLoop->addIncoming(V: VecC, BB: Start); | 
|---|
| 321 | Value *VecZero = Constant::getNullValue(Ty: V256I32Ty); | 
|---|
| 322 | PHINode *VecDPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.row"); | 
|---|
| 323 | VecDPhiRowLoop->addIncoming(V: VecZero, BB: Start); | 
|---|
| 324 |  | 
|---|
| 325 | // tiledpbssd.scalarize.cols.header: | 
|---|
| 326 | // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, | 
|---|
| 327 | // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, | 
|---|
| 328 | // %tiledpbssd.scalarize.cols.latch ] | 
|---|
| 329 |  | 
|---|
| 330 | // %vec.d.phi.col = phi <256 x i32> [ | 
|---|
| 331 | // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, | 
|---|
| 332 | // %tiledpbssd.scalarize.cols.latch ] | 
|---|
| 333 |  | 
|---|
| 334 | // calculate idxc. | 
|---|
| 335 | B.SetInsertPoint(ColLoopHeader->getTerminator()); | 
|---|
| 336 | PHINode *VecCPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.col"); | 
|---|
| 337 | VecCPhiColLoop->addIncoming(V: VecCPhiRowLoop, BB: RowBody); | 
|---|
| 338 | PHINode *VecDPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.col"); | 
|---|
| 339 | VecDPhiColLoop->addIncoming(V: VecDPhiRowLoop, BB: RowBody); | 
|---|
| 340 | Value *IdxC = | 
|---|
| 341 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol); | 
|---|
| 342 |  | 
|---|
| 343 | // tiledpbssd.scalarize.inner.header: | 
|---|
| 344 | // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, | 
|---|
| 345 | // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, | 
|---|
| 346 | // %tiledpbssd.scalarize.inner.latch ] | 
|---|
| 347 |  | 
|---|
| 348 | B.SetInsertPoint(InnerLoopHeader->getTerminator()); | 
|---|
| 349 | PHINode *VecCPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.inner.phi"); | 
|---|
| 350 | VecCPhi->addIncoming(V: VecCPhiColLoop, BB: ColBody); | 
|---|
| 351 |  | 
|---|
| 352 | B.SetInsertPoint(InnerBody->getTerminator()); | 
|---|
| 353 | Value *IdxA = | 
|---|
| 354 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentInner); | 
|---|
| 355 | Value *IdxB = | 
|---|
| 356 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentInner, RHS: B.getInt16(C: 16)), RHS: CurrentCol); | 
|---|
| 357 | Value *NewVecC = nullptr; | 
|---|
| 358 |  | 
|---|
| 359 | if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { | 
|---|
| 360 | // tiledpbssd.scalarize.inner.body: | 
|---|
| 361 | // calculate idxa, idxb | 
|---|
| 362 | // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc | 
|---|
| 363 | // %elta = extractelement <256 x i32> %veca, i16 %idxa | 
|---|
| 364 | // %eltav4i8 = bitcast i32 %elta to <4 x i8> | 
|---|
| 365 | // %eltb = extractelement <256 x i32> %vecb, i16 %idxb | 
|---|
| 366 | // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> | 
|---|
| 367 | // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> | 
|---|
| 368 | // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> | 
|---|
| 369 | // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 | 
|---|
| 370 | // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) | 
|---|
| 371 | // %neweltc = add i32 %elt, %acc | 
|---|
| 372 | // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, | 
|---|
| 373 | // i16 %idxc | 
|---|
| 374 | FixedVectorType *V4I8Ty = FixedVectorType::get(ElementType: B.getInt8Ty(), NumElts: 4); | 
|---|
| 375 | FixedVectorType *V4I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 4); | 
|---|
| 376 | Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC); | 
|---|
| 377 | Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA); | 
|---|
| 378 | Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V4I8Ty); | 
|---|
| 379 | Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB); | 
|---|
| 380 | Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V4I8Ty); | 
|---|
| 381 | Value *SEXTSubVecB = nullptr; | 
|---|
| 382 | Value *SEXTSubVecA = nullptr; | 
|---|
| 383 | switch (IntrID) { | 
|---|
| 384 | case Intrinsic::x86_tdpbssd_internal: | 
|---|
| 385 | SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty); | 
|---|
| 386 | SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty); | 
|---|
| 387 | break; | 
|---|
| 388 | case Intrinsic::x86_tdpbsud_internal: | 
|---|
| 389 | SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty); | 
|---|
| 390 | SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty); | 
|---|
| 391 | break; | 
|---|
| 392 | case Intrinsic::x86_tdpbusd_internal: | 
|---|
| 393 | SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty); | 
|---|
| 394 | SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty); | 
|---|
| 395 | break; | 
|---|
| 396 | case Intrinsic::x86_tdpbuud_internal: | 
|---|
| 397 | SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty); | 
|---|
| 398 | SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty); | 
|---|
| 399 | break; | 
|---|
| 400 | default: | 
|---|
| 401 | llvm_unreachable( "Invalid intrinsic ID!"); | 
|---|
| 402 | } | 
|---|
| 403 | Value *SubVecR = B.CreateAddReduce(Src: B.CreateMul(LHS: SEXTSubVecA, RHS: SEXTSubVecB)); | 
|---|
| 404 | Value *ResElt = B.CreateAdd(LHS: EltC, RHS: SubVecR); | 
|---|
| 405 | NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC); | 
|---|
| 406 | } else { | 
|---|
| 407 | // tiledpbf16ps.scalarize.inner.body: | 
|---|
| 408 | // calculate idxa, idxb, idxc | 
|---|
| 409 | // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc | 
|---|
| 410 | // %eltcf32 = bitcast i32 %eltc to float | 
|---|
| 411 | // %elta = extractelement <256 x i32> %veca, i16 %idxa | 
|---|
| 412 | // %eltav2i16 = bitcast i32 %elta to <2 x i16> | 
|---|
| 413 | // %eltb = extractelement <256 x i32> %vecb, i16 %idxb | 
|---|
| 414 | // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> | 
|---|
| 415 | // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 | 
|---|
| 416 | // x i32> <i32 2, i32 0, i32 3, i32 1> | 
|---|
| 417 | // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> | 
|---|
| 418 | // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x | 
|---|
| 419 | // i32> <i32 2, i32 0, i32 3, i32 1> | 
|---|
| 420 | // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> | 
|---|
| 421 | // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 | 
|---|
| 422 | // %acc = call float | 
|---|
| 423 | // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) | 
|---|
| 424 | // %neweltc = bitcast float %acc to i32 | 
|---|
| 425 | // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, | 
|---|
| 426 | // i16 %idxc | 
|---|
| 427 | // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, | 
|---|
| 428 | // i16 %idxc | 
|---|
| 429 | FixedVectorType *V2I16Ty = FixedVectorType::get(ElementType: B.getInt16Ty(), NumElts: 2); | 
|---|
| 430 | FixedVectorType *V2F32Ty = FixedVectorType::get(ElementType: B.getFloatTy(), NumElts: 2); | 
|---|
| 431 | Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC); | 
|---|
| 432 | Value *EltCF32 = B.CreateBitCast(V: EltC, DestTy: B.getFloatTy()); | 
|---|
| 433 | Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA); | 
|---|
| 434 | Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V2I16Ty); | 
|---|
| 435 | Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB); | 
|---|
| 436 | Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V2I16Ty); | 
|---|
| 437 | Value *ZeroV2I16 = Constant::getNullValue(Ty: V2I16Ty); | 
|---|
| 438 | int ShuffleMask[4] = {2, 0, 3, 1}; | 
|---|
| 439 | auto ShuffleArray = ArrayRef(ShuffleMask); | 
|---|
| 440 | Value *AV2F32 = B.CreateBitCast( | 
|---|
| 441 | V: B.CreateShuffleVector(V1: SubVecA, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty); | 
|---|
| 442 | Value *BV2F32 = B.CreateBitCast( | 
|---|
| 443 | V: B.CreateShuffleVector(V1: SubVecB, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty); | 
|---|
| 444 | Value *SubVecR = B.CreateFAddReduce(Acc: EltCF32, Src: B.CreateFMul(L: AV2F32, R: BV2F32)); | 
|---|
| 445 | Value *ResElt = B.CreateBitCast(V: SubVecR, DestTy: B.getInt32Ty()); | 
|---|
| 446 | NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC); | 
|---|
| 447 | } | 
|---|
| 448 |  | 
|---|
| 449 | // tiledpbssd.scalarize.cols.latch: | 
|---|
| 450 | // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc | 
|---|
| 451 | // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, | 
|---|
| 452 | // i16 %idxc | 
|---|
| 453 | B.SetInsertPoint(ColLoopLatch->getTerminator()); | 
|---|
| 454 | Value *NewEltC = B.CreateExtractElement(Vec: NewVecC, Idx: IdxC); | 
|---|
| 455 | Value *NewVecD = B.CreateInsertElement(Vec: VecDPhiColLoop, NewElt: NewEltC, Idx: IdxC); | 
|---|
| 456 |  | 
|---|
| 457 | VecCPhi->addIncoming(V: NewVecC, BB: InnerLoopLatch); | 
|---|
| 458 | VecCPhiRowLoop->addIncoming(V: NewVecC, BB: RowLatch); | 
|---|
| 459 | VecCPhiColLoop->addIncoming(V: NewVecC, BB: ColLoopLatch); | 
|---|
| 460 | VecDPhiRowLoop->addIncoming(V: NewVecD, BB: RowLatch); | 
|---|
| 461 | VecDPhiColLoop->addIncoming(V: NewVecD, BB: ColLoopLatch); | 
|---|
| 462 |  | 
|---|
| 463 | return NewVecD; | 
|---|
| 464 | } | 
|---|
| 465 |  | 
|---|
| 466 | template <Intrinsic::ID IntrID> | 
|---|
| 467 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || | 
|---|
| 468 | IntrID == Intrinsic::x86_tdpbsud_internal || | 
|---|
| 469 | IntrID == Intrinsic::x86_tdpbusd_internal || | 
|---|
| 470 | IntrID == Intrinsic::x86_tdpbuud_internal || | 
|---|
| 471 | IntrID == Intrinsic::x86_tdpbf16ps_internal, | 
|---|
| 472 | bool> | 
|---|
| 473 | X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { | 
|---|
| 474 | Value *M, *N, *K, *C, *A, *B; | 
|---|
| 475 | match(TileDP, m_Intrinsic<IntrID>(m_Value(V&: M), m_Value(V&: N), m_Value(V&: K), | 
|---|
| 476 | m_Value(V&: C), m_Value(V&: A), m_Value(V&: B))); | 
|---|
| 477 | Instruction *InsertI = TileDP; | 
|---|
| 478 | IRBuilder<> PreBuilder(TileDP); | 
|---|
| 479 | PreBuilder.SetInsertPoint(TileDP); | 
|---|
| 480 | // We visit the loop with (m, n/4, k/4): | 
|---|
| 481 | // %n_dword = lshr i16 %n, 2 | 
|---|
| 482 | // %k_dword = lshr i16 %k, 2 | 
|---|
| 483 | Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2)); | 
|---|
| 484 | Value *KDWord = PreBuilder.CreateLShr(LHS: K, RHS: PreBuilder.getInt16(C: 2)); | 
|---|
| 485 | BasicBlock *Start = InsertI->getParent(); | 
|---|
| 486 | BasicBlock *End = | 
|---|
| 487 | SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue"); | 
|---|
| 488 | IRBuilder<> Builder(TileDP); | 
|---|
| 489 | Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, | 
|---|
| 490 | KDWord, C, A, B); | 
|---|
| 491 | // we cannot assume there always be bitcast after tiledpbssd. So we need to | 
|---|
| 492 | // insert one bitcast as required | 
|---|
| 493 | Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt()); | 
|---|
| 494 | Value *ResAMX = | 
|---|
| 495 | Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext())); | 
|---|
| 496 | // Delete TileDP intrinsic and do some clean-up. | 
|---|
| 497 | for (Use &U : llvm::make_early_inc_range(Range: TileDP->uses())) { | 
|---|
| 498 | Instruction *I = cast<Instruction>(Val: U.getUser()); | 
|---|
| 499 | Value *Vec; | 
|---|
| 500 | if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) { | 
|---|
| 501 | I->replaceAllUsesWith(V: ResVec); | 
|---|
| 502 | I->eraseFromParent(); | 
|---|
| 503 | } | 
|---|
| 504 | } | 
|---|
| 505 | TileDP->replaceAllUsesWith(V: ResAMX); | 
|---|
| 506 | TileDP->eraseFromParent(); | 
|---|
| 507 | return true; | 
|---|
| 508 | } | 
|---|
| 509 |  | 
|---|
| 510 | template <bool IsTileLoad> | 
|---|
| 511 | bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { | 
|---|
| 512 | Value *M, *N, *Ptr, *Stride, *Tile; | 
|---|
| 513 | if (IsTileLoad) | 
|---|
| 514 | match(V: TileLoadStore, | 
|---|
| 515 | P: m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( | 
|---|
| 516 | Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr), Op3: m_Value(V&: Stride))); | 
|---|
| 517 | else | 
|---|
| 518 | match(V: TileLoadStore, P: m_Intrinsic<Intrinsic::x86_tilestored64_internal>( | 
|---|
| 519 | Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr), | 
|---|
| 520 | Op3: m_Value(V&: Stride), Op4: m_Value(V&: Tile))); | 
|---|
| 521 |  | 
|---|
| 522 | Instruction *InsertI = TileLoadStore; | 
|---|
| 523 | IRBuilder<> PreBuilder(TileLoadStore); | 
|---|
| 524 | PreBuilder.SetInsertPoint(TileLoadStore); | 
|---|
| 525 | Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2)); | 
|---|
| 526 | Value *StrideDWord = PreBuilder.CreateLShr(LHS: Stride, RHS: PreBuilder.getInt64(C: 2)); | 
|---|
| 527 | BasicBlock *Start = InsertI->getParent(); | 
|---|
| 528 | BasicBlock *End = | 
|---|
| 529 | SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue"); | 
|---|
| 530 | IRBuilder<> Builder(TileLoadStore); | 
|---|
| 531 | Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( | 
|---|
| 532 | Start, End, Builder, M, NDWord, Ptr, StrideDWord, | 
|---|
| 533 | IsTileLoad ? nullptr : Tile); | 
|---|
| 534 | if (IsTileLoad) { | 
|---|
| 535 | // we cannot assume there always be bitcast after tileload. So we need to | 
|---|
| 536 | // insert one bitcast as required | 
|---|
| 537 | Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt()); | 
|---|
| 538 | Value *ResAMX = | 
|---|
| 539 | Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext())); | 
|---|
| 540 | // Delete tileloadd6 intrinsic and do some clean-up | 
|---|
| 541 | for (Use &U : llvm::make_early_inc_range(Range: TileLoadStore->uses())) { | 
|---|
| 542 | Instruction *I = cast<Instruction>(Val: U.getUser()); | 
|---|
| 543 | Value *Vec; | 
|---|
| 544 | if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) { | 
|---|
| 545 | I->replaceAllUsesWith(V: ResVec); | 
|---|
| 546 | I->eraseFromParent(); | 
|---|
| 547 | } | 
|---|
| 548 | } | 
|---|
| 549 | TileLoadStore->replaceAllUsesWith(V: ResAMX); | 
|---|
| 550 | } | 
|---|
| 551 | TileLoadStore->eraseFromParent(); | 
|---|
| 552 | return true; | 
|---|
| 553 | } | 
|---|
| 554 |  | 
|---|
| 555 | bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { | 
|---|
| 556 | IRBuilder<> Builder(TileZero); | 
|---|
| 557 | FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: Builder.getInt32Ty(), NumElts: 256); | 
|---|
| 558 | Value *VecZero = Constant::getNullValue(Ty: V256I32Ty); | 
|---|
| 559 | for (Use &U : llvm::make_early_inc_range(Range: TileZero->uses())) { | 
|---|
| 560 | Instruction *I = cast<Instruction>(Val: U.getUser()); | 
|---|
| 561 | Value *Vec; | 
|---|
| 562 | if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) { | 
|---|
| 563 | I->replaceAllUsesWith(V: VecZero); | 
|---|
| 564 | I->eraseFromParent(); | 
|---|
| 565 | } | 
|---|
| 566 | } | 
|---|
| 567 | TileZero->eraseFromParent(); | 
|---|
| 568 | return true; | 
|---|
| 569 | } | 
|---|
| 570 |  | 
|---|
| 571 | bool X86LowerAMXIntrinsics::visit() { | 
|---|
| 572 | bool C = false; | 
|---|
| 573 | SmallVector<IntrinsicInst *, 8> WorkList; | 
|---|
| 574 | for (BasicBlock *BB : depth_first(G: &Func)) { | 
|---|
| 575 | for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { | 
|---|
| 576 | if (auto *Inst = dyn_cast<IntrinsicInst>(Val: &*II++)) { | 
|---|
| 577 | switch (Inst->getIntrinsicID()) { | 
|---|
| 578 | case Intrinsic::x86_tdpbssd_internal: | 
|---|
| 579 | case Intrinsic::x86_tdpbsud_internal: | 
|---|
| 580 | case Intrinsic::x86_tdpbusd_internal: | 
|---|
| 581 | case Intrinsic::x86_tdpbuud_internal: | 
|---|
| 582 | case Intrinsic::x86_tileloadd64_internal: | 
|---|
| 583 | case Intrinsic::x86_tilestored64_internal: | 
|---|
| 584 | case Intrinsic::x86_tilezero_internal: | 
|---|
| 585 | case Intrinsic::x86_tdpbf16ps_internal: | 
|---|
| 586 | WorkList.push_back(Elt: Inst); | 
|---|
| 587 | break; | 
|---|
| 588 | default: | 
|---|
| 589 | break; | 
|---|
| 590 | } | 
|---|
| 591 | } | 
|---|
| 592 | } | 
|---|
| 593 | } | 
|---|
| 594 |  | 
|---|
| 595 | for (auto *Inst : WorkList) { | 
|---|
| 596 | switch (Inst->getIntrinsicID()) { | 
|---|
| 597 | case Intrinsic::x86_tdpbssd_internal: | 
|---|
| 598 | C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(TileDP: Inst) || C; | 
|---|
| 599 | break; | 
|---|
| 600 | case Intrinsic::x86_tdpbsud_internal: | 
|---|
| 601 | C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(TileDP: Inst) || C; | 
|---|
| 602 | break; | 
|---|
| 603 | case Intrinsic::x86_tdpbusd_internal: | 
|---|
| 604 | C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(TileDP: Inst) || C; | 
|---|
| 605 | break; | 
|---|
| 606 | case Intrinsic::x86_tdpbuud_internal: | 
|---|
| 607 | C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(TileDP: Inst) || C; | 
|---|
| 608 | break; | 
|---|
| 609 | case Intrinsic::x86_tdpbf16ps_internal: | 
|---|
| 610 | C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(TileDP: Inst) || C; | 
|---|
| 611 | break; | 
|---|
| 612 | case Intrinsic::x86_tileloadd64_internal: | 
|---|
| 613 | C = lowerTileLoadStore<true>(TileLoadStore: Inst) || C; | 
|---|
| 614 | break; | 
|---|
| 615 | case Intrinsic::x86_tilestored64_internal: | 
|---|
| 616 | C = lowerTileLoadStore<false>(TileLoadStore: Inst) || C; | 
|---|
| 617 | break; | 
|---|
| 618 | case Intrinsic::x86_tilezero_internal: | 
|---|
| 619 | C = lowerTileZero(TileZero: Inst) || C; | 
|---|
| 620 | break; | 
|---|
| 621 | default: | 
|---|
| 622 | llvm_unreachable( "invalid amx intrinsics!"); | 
|---|
| 623 | } | 
|---|
| 624 | } | 
|---|
| 625 |  | 
|---|
| 626 | return C; | 
|---|
| 627 | } | 
|---|
| 628 |  | 
|---|
| 629 | namespace { | 
|---|
| 630 | class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { | 
|---|
| 631 | public: | 
|---|
| 632 | static char ID; | 
|---|
| 633 |  | 
|---|
| 634 | X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {} | 
|---|
| 635 |  | 
|---|
| 636 | bool runOnFunction(Function &F) override { | 
|---|
| 637 | if (!X86ScalarizeAMX) | 
|---|
| 638 | return false; | 
|---|
| 639 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); | 
|---|
| 640 | if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone) && | 
|---|
| 641 | TM->getOptLevel() != CodeGenOptLevel::None) | 
|---|
| 642 | return false; | 
|---|
| 643 |  | 
|---|
| 644 | auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); | 
|---|
| 645 | auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; | 
|---|
| 646 | auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); | 
|---|
| 647 | auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; | 
|---|
| 648 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); | 
|---|
| 649 |  | 
|---|
| 650 | X86LowerAMXIntrinsics LAT(F, DTU, LI); | 
|---|
| 651 | return LAT.visit(); | 
|---|
| 652 | } | 
|---|
| 653 | StringRef getPassName() const override { return "Lower AMX intrinsics"; } | 
|---|
| 654 |  | 
|---|
| 655 | void getAnalysisUsage(AnalysisUsage &AU) const override { | 
|---|
| 656 | AU.addPreserved<DominatorTreeWrapperPass>(); | 
|---|
| 657 | AU.addPreserved<LoopInfoWrapperPass>(); | 
|---|
| 658 | AU.addRequired<TargetPassConfig>(); | 
|---|
| 659 | } | 
|---|
| 660 | }; | 
|---|
| 661 | } // namespace | 
|---|
| 662 |  | 
|---|
| 663 | static const char PassName[] = "Lower AMX intrinsics"; | 
|---|
| 664 | char X86LowerAMXIntrinsicsLegacyPass::ID = 0; | 
|---|
| 665 | INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, | 
|---|
| 666 | false, false) | 
|---|
| 667 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) | 
|---|
| 668 | INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, | 
|---|
| 669 | false, false) | 
|---|
| 670 |  | 
|---|
| 671 | FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { | 
|---|
| 672 | return new X86LowerAMXIntrinsicsLegacyPass(); | 
|---|
| 673 | } | 
|---|
| 674 |  | 
|---|