| 1 | //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
/// \file Pass to transform AMX intrinsics to scalar operations.
/// This pass is always enabled, but it does nothing unless the hidden option
/// -enable-x86-scalar-amx is given and the function is compiled at -O0 or has
/// the optnone attribute. With -O0 or optnone, the definitions of the shapes
/// used by AMX intrinsics are placed near the AMX intrinsic calls, so we are
/// not able to find a point that post-dominates all the shape definitions and
/// dominates all AMX intrinsics. To decouple the intrinsics from their shapes,
/// we transform AMX intrinsics into scalar operations so that compilation
/// doesn't fail. In the long term, we should improve the fast register
/// allocator so that it can allocate AMX registers.
| 17 | //===----------------------------------------------------------------------===// |
| 18 | // |
| 19 | #include "X86.h" |
| 20 | #include "llvm/Analysis/DomTreeUpdater.h" |
| 21 | #include "llvm/Analysis/LoopInfo.h" |
| 22 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 23 | #include "llvm/CodeGen/Passes.h" |
| 24 | #include "llvm/CodeGen/TargetPassConfig.h" |
| 25 | #include "llvm/CodeGen/ValueTypes.h" |
| 26 | #include "llvm/IR/DataLayout.h" |
| 27 | #include "llvm/IR/Function.h" |
| 28 | #include "llvm/IR/IRBuilder.h" |
| 29 | #include "llvm/IR/Instructions.h" |
| 30 | #include "llvm/IR/IntrinsicInst.h" |
| 31 | #include "llvm/IR/IntrinsicsX86.h" |
| 32 | #include "llvm/IR/PatternMatch.h" |
| 33 | #include "llvm/InitializePasses.h" |
| 34 | #include "llvm/Pass.h" |
| 35 | #include "llvm/Support/CommandLine.h" |
| 36 | #include "llvm/Target/TargetMachine.h" |
| 37 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 38 | #include "llvm/Transforms/Utils/LoopUtils.h" |
| 39 | |
| 40 | using namespace llvm; |
| 41 | using namespace PatternMatch; |
| 42 | |
| 43 | #define DEBUG_TYPE "lower-amx-intrinsics" |
| 44 | |
| 45 | #ifndef NDEBUG |
| 46 | static bool isV256I32Ty(Type *Ty) { |
| 47 | if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) |
| 48 | return FVT->getNumElements() == 256 && |
| 49 | FVT->getElementType()->isIntegerTy(32); |
| 50 | return false; |
| 51 | } |
| 52 | #endif |
| 53 | |
| 54 | static cl::opt<bool> |
| 55 | X86ScalarizeAMX("enable-x86-scalar-amx" , cl::init(Val: false), cl::Hidden, |
| 56 | cl::desc("X86: enable AMX scalarizition." )); |
| 57 | |
| 58 | namespace { |
| 59 | class X86LowerAMXIntrinsics { |
| 60 | Function &Func; |
| 61 | |
| 62 | public: |
| 63 | X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) |
| 64 | : Func(F), DTU(DomTU), LI(LoopI) {} |
| 65 | bool visit(); |
| 66 | |
| 67 | private: |
| 68 | DomTreeUpdater &DTU; |
| 69 | LoopInfo *LI; |
| 70 | BasicBlock *createLoop(BasicBlock *, BasicBlock *Exit, Value *Bound, |
| 71 | Value *Step, StringRef Name, IRBuilderBase &B, |
| 72 | Loop *L); |
| 73 | template <bool IsTileLoad> |
| 74 | Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, |
| 75 | IRBuilderBase &B, Value *Row, Value *Col, |
| 76 | Value *Ptr, Value *Stride, Value *Tile); |
| 77 | template <Intrinsic::ID IntrID> |
| 78 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
| 79 | IntrID == Intrinsic::x86_tdpbsud_internal || |
| 80 | IntrID == Intrinsic::x86_tdpbusd_internal || |
| 81 | IntrID == Intrinsic::x86_tdpbuud_internal || |
| 82 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
| 83 | Value *> |
| 84 | createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, |
| 85 | Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, |
| 86 | Value *RHS); |
| 87 | template <bool IsTileLoad> |
| 88 | bool lowerTileLoadStore(Instruction *TileLoadStore); |
| 89 | template <Intrinsic::ID IntrID> |
| 90 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
| 91 | IntrID == Intrinsic::x86_tdpbsud_internal || |
| 92 | IntrID == Intrinsic::x86_tdpbusd_internal || |
| 93 | IntrID == Intrinsic::x86_tdpbuud_internal || |
| 94 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
| 95 | bool> |
| 96 | lowerTileDP(Instruction *TileDP); |
| 97 | bool lowerTileZero(Instruction *TileZero); |
| 98 | }; |
| 99 | } // anonymous namespace |
| 100 | |
| 101 | BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *, |
| 102 | BasicBlock *Exit, Value *Bound, |
| 103 | Value *Step, StringRef Name, |
| 104 | IRBuilderBase &B, Loop *L) { |
| 105 | LLVMContext &Ctx = Preheader->getContext(); |
| 106 | BasicBlock * = |
| 107 | BasicBlock::Create(Context&: Ctx, Name: Name + ".header" , Parent: Preheader->getParent(), InsertBefore: Exit); |
| 108 | BasicBlock *Body = |
| 109 | BasicBlock::Create(Context&: Ctx, Name: Name + ".body" , Parent: Header->getParent(), InsertBefore: Exit); |
| 110 | BasicBlock *Latch = |
| 111 | BasicBlock::Create(Context&: Ctx, Name: Name + ".latch" , Parent: Header->getParent(), InsertBefore: Exit); |
| 112 | |
| 113 | Type *I16Ty = Type::getInt16Ty(C&: Ctx); |
| 114 | BranchInst::Create(IfTrue: Body, InsertBefore: Header); |
| 115 | BranchInst::Create(IfTrue: Latch, InsertBefore: Body); |
| 116 | PHINode *IV = |
| 117 | PHINode::Create(Ty: I16Ty, NumReservedValues: 2, NameStr: Name + ".iv" , InsertBefore: Header->getTerminator()->getIterator()); |
| 118 | IV->addIncoming(V: ConstantInt::get(Ty: I16Ty, V: 0), BB: Preheader); |
| 119 | |
| 120 | B.SetInsertPoint(Latch); |
| 121 | Value *Inc = B.CreateAdd(LHS: IV, RHS: Step, Name: Name + ".step" ); |
| 122 | Value *Cond = B.CreateICmpNE(LHS: Inc, RHS: Bound, Name: Name + ".cond" ); |
| 123 | BranchInst::Create(IfTrue: Header, IfFalse: Exit, Cond, InsertBefore: Latch); |
| 124 | IV->addIncoming(V: Inc, BB: Latch); |
| 125 | |
| 126 | BranchInst * = cast<BranchInst>(Val: Preheader->getTerminator()); |
| 127 | BasicBlock *Tmp = PreheaderBr->getSuccessor(i: 0); |
| 128 | PreheaderBr->setSuccessor(idx: 0, NewSucc: Header); |
| 129 | DTU.applyUpdatesPermissive(Updates: { |
| 130 | {DominatorTree::Delete, Preheader, Tmp}, |
| 131 | {DominatorTree::Insert, Header, Body}, |
| 132 | {DominatorTree::Insert, Body, Latch}, |
| 133 | {DominatorTree::Insert, Latch, Header}, |
| 134 | {DominatorTree::Insert, Latch, Exit}, |
| 135 | {DominatorTree::Insert, Preheader, Header}, |
| 136 | }); |
| 137 | if (LI) { |
| 138 | L->addBasicBlockToLoop(NewBB: Header, LI&: *LI); |
| 139 | L->addBasicBlockToLoop(NewBB: Body, LI&: *LI); |
| 140 | L->addBasicBlockToLoop(NewBB: Latch, LI&: *LI); |
| 141 | } |
| 142 | return Body; |
| 143 | } |
| 144 | |
| 145 | template <bool IsTileLoad> |
| 146 | Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( |
| 147 | BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, |
| 148 | Value *Col, Value *Ptr, Value *Stride, Value *Tile) { |
| 149 | std::string IntrinName = IsTileLoad ? "tileload" : "tilestore" ; |
| 150 | Loop *RowLoop = nullptr; |
| 151 | Loop *ColLoop = nullptr; |
| 152 | if (LI) { |
| 153 | RowLoop = LI->AllocateLoop(); |
| 154 | ColLoop = LI->AllocateLoop(); |
| 155 | RowLoop->addChildLoop(NewChild: ColLoop); |
| 156 | if (Loop *ParentL = LI->getLoopFor(BB: Start)) |
| 157 | ParentL->addChildLoop(NewChild: RowLoop); |
| 158 | else |
| 159 | LI->addTopLevelLoop(New: RowLoop); |
| 160 | } |
| 161 | |
| 162 | BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1), |
| 163 | Name: IntrinName + ".scalarize.rows" , B, L: RowLoop); |
| 164 | BasicBlock *RowLatch = RowBody->getSingleSuccessor(); |
| 165 | |
| 166 | BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1), |
| 167 | Name: IntrinName + ".scalarize.cols" , B, L: ColLoop); |
| 168 | |
| 169 | BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); |
| 170 | BasicBlock * = ColBody->getSinglePredecessor(); |
| 171 | BasicBlock * = RowBody->getSinglePredecessor(); |
| 172 | Value *CurrentRow = &*RowLoopHeader->begin(); |
| 173 | Value *CurrentCol = &*ColLoopHeader->begin(); |
| 174 | Type *EltTy = B.getInt32Ty(); |
| 175 | FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: EltTy, NumElts: 256); |
| 176 | |
| 177 | // Common part for tileload and tilestore |
| 178 | // *.scalarize.cols.body: |
| 179 | // Calculate %idxmem and %idxvec |
| 180 | B.SetInsertPoint(ColBody->getTerminator()); |
| 181 | Value *CurrentRowZExt = B.CreateZExt(V: CurrentRow, DestTy: Stride->getType()); |
| 182 | Value *CurrentColZExt = B.CreateZExt(V: CurrentCol, DestTy: Stride->getType()); |
| 183 | Value *Offset = |
| 184 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRowZExt, RHS: Stride), RHS: CurrentColZExt); |
| 185 | Value *EltPtr = B.CreateGEP(Ty: EltTy, Ptr, IdxList: Offset); |
| 186 | Value *Idx = B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol); |
| 187 | if (IsTileLoad) { |
| 188 | // tileload.scalarize.rows.header: |
| 189 | // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, |
| 190 | // %tileload.scalarize.rows.latch ] |
| 191 | B.SetInsertPoint(RowLoopHeader->getTerminator()); |
| 192 | Value *VecZero = Constant::getNullValue(Ty: V256I32Ty); |
| 193 | PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi.row" ); |
| 194 | VecCPhiRowLoop->addIncoming(V: VecZero, BB: Start); |
| 195 | |
| 196 | // tileload.scalarize.cols.header: |
| 197 | // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body |
| 198 | // ], [ %ResVec, %tileload.scalarize.cols.latch ] |
| 199 | B.SetInsertPoint(ColLoopHeader->getTerminator()); |
| 200 | PHINode *VecPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.phi" ); |
| 201 | VecPhi->addIncoming(V: VecCPhiRowLoop, BB: RowBody); |
| 202 | |
| 203 | // tileload.scalarize.cols.body: |
| 204 | // Calculate %idxmem and %idxvec |
| 205 | // %eltptr = getelementptr i32, i32* %base, i64 %idxmem |
| 206 | // %elt = load i32, i32* %ptr |
| 207 | // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec |
| 208 | B.SetInsertPoint(ColBody->getTerminator()); |
| 209 | Value *Elt = B.CreateLoad(Ty: EltTy, Ptr: EltPtr); |
| 210 | Value *ResVec = B.CreateInsertElement(Vec: VecPhi, NewElt: Elt, Idx); |
| 211 | VecPhi->addIncoming(V: ResVec, BB: ColLoopLatch); |
| 212 | VecCPhiRowLoop->addIncoming(V: ResVec, BB: RowLatch); |
| 213 | |
| 214 | return ResVec; |
| 215 | } else { |
| 216 | auto *BitCast = cast<BitCastInst>(Val: Tile); |
| 217 | Value *Vec = BitCast->getOperand(i_nocapture: 0); |
| 218 | assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx" ); |
| 219 | // tilestore.scalarize.cols.body: |
| 220 | // %mul = mul i16 %row.iv, i16 16 |
| 221 | // %idx = add i16 %mul, i16 %col.iv |
| 222 | // %vec = extractelement <16 x i32> %vec, i16 %idx |
| 223 | // store i32 %vec, i32* %ptr |
| 224 | B.SetInsertPoint(ColBody->getTerminator()); |
| 225 | Value *Elt = B.CreateExtractElement(Vec, Idx); |
| 226 | |
| 227 | B.CreateStore(Val: Elt, Ptr: EltPtr); |
| 228 | return nullptr; |
| 229 | } |
| 230 | } |
| 231 | |
| 232 | template <Intrinsic::ID IntrID> |
| 233 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
| 234 | IntrID == Intrinsic::x86_tdpbsud_internal || |
| 235 | IntrID == Intrinsic::x86_tdpbusd_internal || |
| 236 | IntrID == Intrinsic::x86_tdpbuud_internal || |
| 237 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
| 238 | Value *> |
| 239 | X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, |
| 240 | IRBuilderBase &B, Value *Row, |
| 241 | Value *Col, Value *K, Value *Acc, |
| 242 | Value *LHS, Value *RHS) { |
| 243 | std::string IntrinName; |
| 244 | switch (IntrID) { |
| 245 | case Intrinsic::x86_tdpbssd_internal: |
| 246 | IntrinName = "tiledpbssd" ; |
| 247 | break; |
| 248 | case Intrinsic::x86_tdpbsud_internal: |
| 249 | IntrinName = "tiledpbsud" ; |
| 250 | break; |
| 251 | case Intrinsic::x86_tdpbusd_internal: |
| 252 | IntrinName = "tiledpbusd" ; |
| 253 | break; |
| 254 | case Intrinsic::x86_tdpbuud_internal: |
| 255 | IntrinName = "tiledpbuud" ; |
| 256 | break; |
| 257 | case Intrinsic::x86_tdpbf16ps_internal: |
| 258 | IntrinName = "tiledpbf16ps" ; |
| 259 | break; |
| 260 | } |
| 261 | Loop *RowLoop = nullptr; |
| 262 | Loop *ColLoop = nullptr; |
| 263 | Loop *InnerLoop = nullptr; |
| 264 | if (LI) { |
| 265 | RowLoop = LI->AllocateLoop(); |
| 266 | ColLoop = LI->AllocateLoop(); |
| 267 | InnerLoop = LI->AllocateLoop(); |
| 268 | ColLoop->addChildLoop(NewChild: InnerLoop); |
| 269 | RowLoop->addChildLoop(NewChild: ColLoop); |
| 270 | if (Loop *ParentL = LI->getLoopFor(BB: Start)) |
| 271 | ParentL->addChildLoop(NewChild: RowLoop); |
| 272 | else |
| 273 | LI->addTopLevelLoop(New: RowLoop); |
| 274 | } |
| 275 | |
| 276 | BasicBlock *RowBody = createLoop(Preheader: Start, Exit: End, Bound: Row, Step: B.getInt16(C: 1), |
| 277 | Name: IntrinName + ".scalarize.rows" , B, L: RowLoop); |
| 278 | BasicBlock *RowLatch = RowBody->getSingleSuccessor(); |
| 279 | |
| 280 | BasicBlock *ColBody = createLoop(Preheader: RowBody, Exit: RowLatch, Bound: Col, Step: B.getInt16(C: 1), |
| 281 | Name: IntrinName + ".scalarize.cols" , B, L: ColLoop); |
| 282 | |
| 283 | BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); |
| 284 | |
| 285 | B.SetInsertPoint(ColBody->getTerminator()); |
| 286 | BasicBlock *InnerBody = |
| 287 | createLoop(Preheader: ColBody, Exit: ColLoopLatch, Bound: K, Step: B.getInt16(C: 1), |
| 288 | Name: IntrinName + ".scalarize.inner" , B, L: InnerLoop); |
| 289 | |
| 290 | BasicBlock * = ColBody->getSinglePredecessor(); |
| 291 | BasicBlock * = RowBody->getSinglePredecessor(); |
| 292 | BasicBlock * = InnerBody->getSinglePredecessor(); |
| 293 | BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); |
| 294 | Value *CurrentRow = &*RowLoopHeader->begin(); |
| 295 | Value *CurrentCol = &*ColLoopHeader->begin(); |
| 296 | Value *CurrentInner = &*InnerLoopHeader->begin(); |
| 297 | |
| 298 | FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 256); |
| 299 | auto *BitCastAcc = cast<BitCastInst>(Val: Acc); |
| 300 | Value *VecC = BitCastAcc->getOperand(i_nocapture: 0); |
| 301 | assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx" ); |
| 302 | // TODO else create BitCast from x86amx to v256i32. |
| 303 | // Store x86amx to memory, and reload from memory |
| 304 | // to vector. However with -O0, it doesn't happen. |
| 305 | auto *BitCastLHS = cast<BitCastInst>(Val: LHS); |
| 306 | Value *VecA = BitCastLHS->getOperand(i_nocapture: 0); |
| 307 | assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx" ); |
| 308 | auto *BitCastRHS = cast<BitCastInst>(Val: RHS); |
| 309 | Value *VecB = BitCastRHS->getOperand(i_nocapture: 0); |
| 310 | assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx" ); |
| 311 | |
| 312 | // tiledpbssd.scalarize.rows.header: |
| 313 | // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, |
| 314 | // %tiledpbssd.scalarize.rows.latch ] |
| 315 | |
| 316 | // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ |
| 317 | // %NewVecD, %tiledpbssd.scalarize.rows.latch ] |
| 318 | B.SetInsertPoint(RowLoopHeader->getTerminator()); |
| 319 | PHINode *VecCPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.row" ); |
| 320 | VecCPhiRowLoop->addIncoming(V: VecC, BB: Start); |
| 321 | Value *VecZero = Constant::getNullValue(Ty: V256I32Ty); |
| 322 | PHINode *VecDPhiRowLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.row" ); |
| 323 | VecDPhiRowLoop->addIncoming(V: VecZero, BB: Start); |
| 324 | |
| 325 | // tiledpbssd.scalarize.cols.header: |
| 326 | // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, |
| 327 | // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, |
| 328 | // %tiledpbssd.scalarize.cols.latch ] |
| 329 | |
| 330 | // %vec.d.phi.col = phi <256 x i32> [ |
| 331 | // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, |
| 332 | // %tiledpbssd.scalarize.cols.latch ] |
| 333 | |
| 334 | // calculate idxc. |
| 335 | B.SetInsertPoint(ColLoopHeader->getTerminator()); |
| 336 | PHINode *VecCPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.phi.col" ); |
| 337 | VecCPhiColLoop->addIncoming(V: VecCPhiRowLoop, BB: RowBody); |
| 338 | PHINode *VecDPhiColLoop = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.d.phi.col" ); |
| 339 | VecDPhiColLoop->addIncoming(V: VecDPhiRowLoop, BB: RowBody); |
| 340 | Value *IdxC = |
| 341 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentCol); |
| 342 | |
| 343 | // tiledpbssd.scalarize.inner.header: |
| 344 | // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, |
| 345 | // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, |
| 346 | // %tiledpbssd.scalarize.inner.latch ] |
| 347 | |
| 348 | B.SetInsertPoint(InnerLoopHeader->getTerminator()); |
| 349 | PHINode *VecCPhi = B.CreatePHI(Ty: V256I32Ty, NumReservedValues: 2, Name: "vec.c.inner.phi" ); |
| 350 | VecCPhi->addIncoming(V: VecCPhiColLoop, BB: ColBody); |
| 351 | |
| 352 | B.SetInsertPoint(InnerBody->getTerminator()); |
| 353 | Value *IdxA = |
| 354 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentRow, RHS: B.getInt16(C: 16)), RHS: CurrentInner); |
| 355 | Value *IdxB = |
| 356 | B.CreateAdd(LHS: B.CreateMul(LHS: CurrentInner, RHS: B.getInt16(C: 16)), RHS: CurrentCol); |
| 357 | Value *NewVecC = nullptr; |
| 358 | |
| 359 | if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { |
| 360 | // tiledpbssd.scalarize.inner.body: |
| 361 | // calculate idxa, idxb |
| 362 | // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc |
| 363 | // %elta = extractelement <256 x i32> %veca, i16 %idxa |
| 364 | // %eltav4i8 = bitcast i32 %elta to <4 x i8> |
| 365 | // %eltb = extractelement <256 x i32> %vecb, i16 %idxb |
| 366 | // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> |
| 367 | // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> |
| 368 | // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> |
| 369 | // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 |
| 370 | // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) |
| 371 | // %neweltc = add i32 %elt, %acc |
| 372 | // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, |
| 373 | // i16 %idxc |
| 374 | FixedVectorType *V4I8Ty = FixedVectorType::get(ElementType: B.getInt8Ty(), NumElts: 4); |
| 375 | FixedVectorType *V4I32Ty = FixedVectorType::get(ElementType: B.getInt32Ty(), NumElts: 4); |
| 376 | Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC); |
| 377 | Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA); |
| 378 | Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V4I8Ty); |
| 379 | Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB); |
| 380 | Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V4I8Ty); |
| 381 | Value *SEXTSubVecB = nullptr; |
| 382 | Value *SEXTSubVecA = nullptr; |
| 383 | switch (IntrID) { |
| 384 | case Intrinsic::x86_tdpbssd_internal: |
| 385 | SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty); |
| 386 | SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty); |
| 387 | break; |
| 388 | case Intrinsic::x86_tdpbsud_internal: |
| 389 | SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty); |
| 390 | SEXTSubVecA = B.CreateSExt(V: SubVecA, DestTy: V4I32Ty); |
| 391 | break; |
| 392 | case Intrinsic::x86_tdpbusd_internal: |
| 393 | SEXTSubVecB = B.CreateSExt(V: SubVecB, DestTy: V4I32Ty); |
| 394 | SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty); |
| 395 | break; |
| 396 | case Intrinsic::x86_tdpbuud_internal: |
| 397 | SEXTSubVecB = B.CreateZExt(V: SubVecB, DestTy: V4I32Ty); |
| 398 | SEXTSubVecA = B.CreateZExt(V: SubVecA, DestTy: V4I32Ty); |
| 399 | break; |
| 400 | default: |
| 401 | llvm_unreachable("Invalid intrinsic ID!" ); |
| 402 | } |
| 403 | Value *SubVecR = B.CreateAddReduce(Src: B.CreateMul(LHS: SEXTSubVecA, RHS: SEXTSubVecB)); |
| 404 | Value *ResElt = B.CreateAdd(LHS: EltC, RHS: SubVecR); |
| 405 | NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC); |
| 406 | } else { |
| 407 | // tiledpbf16ps.scalarize.inner.body: |
| 408 | // calculate idxa, idxb, idxc |
| 409 | // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc |
| 410 | // %eltcf32 = bitcast i32 %eltc to float |
| 411 | // %elta = extractelement <256 x i32> %veca, i16 %idxa |
| 412 | // %eltav2i16 = bitcast i32 %elta to <2 x i16> |
| 413 | // %eltb = extractelement <256 x i32> %vecb, i16 %idxb |
| 414 | // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> |
| 415 | // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 |
| 416 | // x i32> <i32 2, i32 0, i32 3, i32 1> |
| 417 | // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> |
| 418 | // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x |
| 419 | // i32> <i32 2, i32 0, i32 3, i32 1> |
| 420 | // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> |
| 421 | // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 |
| 422 | // %acc = call float |
| 423 | // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) |
| 424 | // %neweltc = bitcast float %acc to i32 |
| 425 | // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, |
| 426 | // i16 %idxc |
| 427 | // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, |
| 428 | // i16 %idxc |
| 429 | FixedVectorType *V2I16Ty = FixedVectorType::get(ElementType: B.getInt16Ty(), NumElts: 2); |
| 430 | FixedVectorType *V2F32Ty = FixedVectorType::get(ElementType: B.getFloatTy(), NumElts: 2); |
| 431 | Value *EltC = B.CreateExtractElement(Vec: VecCPhi, Idx: IdxC); |
| 432 | Value *EltCF32 = B.CreateBitCast(V: EltC, DestTy: B.getFloatTy()); |
| 433 | Value *EltA = B.CreateExtractElement(Vec: VecA, Idx: IdxA); |
| 434 | Value *SubVecA = B.CreateBitCast(V: EltA, DestTy: V2I16Ty); |
| 435 | Value *EltB = B.CreateExtractElement(Vec: VecB, Idx: IdxB); |
| 436 | Value *SubVecB = B.CreateBitCast(V: EltB, DestTy: V2I16Ty); |
| 437 | Value *ZeroV2I16 = Constant::getNullValue(Ty: V2I16Ty); |
| 438 | int ShuffleMask[4] = {2, 0, 3, 1}; |
| 439 | auto ShuffleArray = ArrayRef(ShuffleMask); |
| 440 | Value *AV2F32 = B.CreateBitCast( |
| 441 | V: B.CreateShuffleVector(V1: SubVecA, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty); |
| 442 | Value *BV2F32 = B.CreateBitCast( |
| 443 | V: B.CreateShuffleVector(V1: SubVecB, V2: ZeroV2I16, Mask: ShuffleArray), DestTy: V2F32Ty); |
| 444 | Value *SubVecR = B.CreateFAddReduce(Acc: EltCF32, Src: B.CreateFMul(L: AV2F32, R: BV2F32)); |
| 445 | Value *ResElt = B.CreateBitCast(V: SubVecR, DestTy: B.getInt32Ty()); |
| 446 | NewVecC = B.CreateInsertElement(Vec: VecCPhi, NewElt: ResElt, Idx: IdxC); |
| 447 | } |
| 448 | |
| 449 | // tiledpbssd.scalarize.cols.latch: |
| 450 | // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc |
| 451 | // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, |
| 452 | // i16 %idxc |
| 453 | B.SetInsertPoint(ColLoopLatch->getTerminator()); |
| 454 | Value *NewEltC = B.CreateExtractElement(Vec: NewVecC, Idx: IdxC); |
| 455 | Value *NewVecD = B.CreateInsertElement(Vec: VecDPhiColLoop, NewElt: NewEltC, Idx: IdxC); |
| 456 | |
| 457 | VecCPhi->addIncoming(V: NewVecC, BB: InnerLoopLatch); |
| 458 | VecCPhiRowLoop->addIncoming(V: NewVecC, BB: RowLatch); |
| 459 | VecCPhiColLoop->addIncoming(V: NewVecC, BB: ColLoopLatch); |
| 460 | VecDPhiRowLoop->addIncoming(V: NewVecD, BB: RowLatch); |
| 461 | VecDPhiColLoop->addIncoming(V: NewVecD, BB: ColLoopLatch); |
| 462 | |
| 463 | return NewVecD; |
| 464 | } |
| 465 | |
| 466 | template <Intrinsic::ID IntrID> |
| 467 | std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || |
| 468 | IntrID == Intrinsic::x86_tdpbsud_internal || |
| 469 | IntrID == Intrinsic::x86_tdpbusd_internal || |
| 470 | IntrID == Intrinsic::x86_tdpbuud_internal || |
| 471 | IntrID == Intrinsic::x86_tdpbf16ps_internal, |
| 472 | bool> |
| 473 | X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { |
| 474 | Value *M, *N, *K, *C, *A, *B; |
| 475 | match(TileDP, m_Intrinsic<IntrID>(m_Value(V&: M), m_Value(V&: N), m_Value(V&: K), |
| 476 | m_Value(V&: C), m_Value(V&: A), m_Value(V&: B))); |
| 477 | Instruction *InsertI = TileDP; |
| 478 | IRBuilder<> PreBuilder(TileDP); |
| 479 | PreBuilder.SetInsertPoint(TileDP); |
| 480 | // We visit the loop with (m, n/4, k/4): |
| 481 | // %n_dword = lshr i16 %n, 2 |
| 482 | // %k_dword = lshr i16 %k, 2 |
| 483 | Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2)); |
| 484 | Value *KDWord = PreBuilder.CreateLShr(LHS: K, RHS: PreBuilder.getInt16(C: 2)); |
| 485 | BasicBlock *Start = InsertI->getParent(); |
| 486 | BasicBlock *End = |
| 487 | SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue" ); |
| 488 | IRBuilder<> Builder(TileDP); |
| 489 | Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, |
| 490 | KDWord, C, A, B); |
| 491 | // we cannot assume there always be bitcast after tiledpbssd. So we need to |
| 492 | // insert one bitcast as required |
| 493 | Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt()); |
| 494 | Value *ResAMX = |
| 495 | Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext())); |
| 496 | // Delete TileDP intrinsic and do some clean-up. |
| 497 | for (Use &U : llvm::make_early_inc_range(Range: TileDP->uses())) { |
| 498 | Instruction *I = cast<Instruction>(Val: U.getUser()); |
| 499 | Value *Vec; |
| 500 | if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) { |
| 501 | I->replaceAllUsesWith(V: ResVec); |
| 502 | I->eraseFromParent(); |
| 503 | } |
| 504 | } |
| 505 | TileDP->replaceAllUsesWith(V: ResAMX); |
| 506 | TileDP->eraseFromParent(); |
| 507 | return true; |
| 508 | } |
| 509 | |
| 510 | template <bool IsTileLoad> |
| 511 | bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { |
| 512 | Value *M, *N, *Ptr, *Stride, *Tile; |
| 513 | if (IsTileLoad) |
| 514 | match(V: TileLoadStore, |
| 515 | P: m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( |
| 516 | Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr), Op3: m_Value(V&: Stride))); |
| 517 | else |
| 518 | match(V: TileLoadStore, P: m_Intrinsic<Intrinsic::x86_tilestored64_internal>( |
| 519 | Op0: m_Value(V&: M), Op1: m_Value(V&: N), Op2: m_Value(V&: Ptr), |
| 520 | Op3: m_Value(V&: Stride), Op4: m_Value(V&: Tile))); |
| 521 | |
| 522 | Instruction *InsertI = TileLoadStore; |
| 523 | IRBuilder<> PreBuilder(TileLoadStore); |
| 524 | PreBuilder.SetInsertPoint(TileLoadStore); |
| 525 | Value *NDWord = PreBuilder.CreateLShr(LHS: N, RHS: PreBuilder.getInt16(C: 2)); |
| 526 | Value *StrideDWord = PreBuilder.CreateLShr(LHS: Stride, RHS: PreBuilder.getInt64(C: 2)); |
| 527 | BasicBlock *Start = InsertI->getParent(); |
| 528 | BasicBlock *End = |
| 529 | SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DTU: &DTU, LI, MSSAU: nullptr, BBName: "continue" ); |
| 530 | IRBuilder<> Builder(TileLoadStore); |
| 531 | Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( |
| 532 | Start, End, Builder, M, NDWord, Ptr, StrideDWord, |
| 533 | IsTileLoad ? nullptr : Tile); |
| 534 | if (IsTileLoad) { |
| 535 | // we cannot assume there always be bitcast after tileload. So we need to |
| 536 | // insert one bitcast as required |
| 537 | Builder.SetInsertPoint(TheBB: End, IP: End->getFirstNonPHIIt()); |
| 538 | Value *ResAMX = |
| 539 | Builder.CreateBitCast(V: ResVec, DestTy: Type::getX86_AMXTy(C&: Builder.getContext())); |
| 540 | // Delete tileloadd6 intrinsic and do some clean-up |
| 541 | for (Use &U : llvm::make_early_inc_range(Range: TileLoadStore->uses())) { |
| 542 | Instruction *I = cast<Instruction>(Val: U.getUser()); |
| 543 | Value *Vec; |
| 544 | if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) { |
| 545 | I->replaceAllUsesWith(V: ResVec); |
| 546 | I->eraseFromParent(); |
| 547 | } |
| 548 | } |
| 549 | TileLoadStore->replaceAllUsesWith(V: ResAMX); |
| 550 | } |
| 551 | TileLoadStore->eraseFromParent(); |
| 552 | return true; |
| 553 | } |
| 554 | |
| 555 | bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { |
| 556 | IRBuilder<> Builder(TileZero); |
| 557 | FixedVectorType *V256I32Ty = FixedVectorType::get(ElementType: Builder.getInt32Ty(), NumElts: 256); |
| 558 | Value *VecZero = Constant::getNullValue(Ty: V256I32Ty); |
| 559 | for (Use &U : llvm::make_early_inc_range(Range: TileZero->uses())) { |
| 560 | Instruction *I = cast<Instruction>(Val: U.getUser()); |
| 561 | Value *Vec; |
| 562 | if (match(V: I, P: m_BitCast(Op: m_Value(V&: Vec)))) { |
| 563 | I->replaceAllUsesWith(V: VecZero); |
| 564 | I->eraseFromParent(); |
| 565 | } |
| 566 | } |
| 567 | TileZero->eraseFromParent(); |
| 568 | return true; |
| 569 | } |
| 570 | |
| 571 | bool X86LowerAMXIntrinsics::visit() { |
| 572 | bool C = false; |
| 573 | SmallVector<IntrinsicInst *, 8> WorkList; |
| 574 | for (BasicBlock *BB : depth_first(G: &Func)) { |
| 575 | for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { |
| 576 | if (auto *Inst = dyn_cast<IntrinsicInst>(Val: &*II++)) { |
| 577 | switch (Inst->getIntrinsicID()) { |
| 578 | case Intrinsic::x86_tdpbssd_internal: |
| 579 | case Intrinsic::x86_tdpbsud_internal: |
| 580 | case Intrinsic::x86_tdpbusd_internal: |
| 581 | case Intrinsic::x86_tdpbuud_internal: |
| 582 | case Intrinsic::x86_tileloadd64_internal: |
| 583 | case Intrinsic::x86_tilestored64_internal: |
| 584 | case Intrinsic::x86_tilezero_internal: |
| 585 | case Intrinsic::x86_tdpbf16ps_internal: |
| 586 | WorkList.push_back(Elt: Inst); |
| 587 | break; |
| 588 | default: |
| 589 | break; |
| 590 | } |
| 591 | } |
| 592 | } |
| 593 | } |
| 594 | |
| 595 | for (auto *Inst : WorkList) { |
| 596 | switch (Inst->getIntrinsicID()) { |
| 597 | case Intrinsic::x86_tdpbssd_internal: |
| 598 | C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(TileDP: Inst) || C; |
| 599 | break; |
| 600 | case Intrinsic::x86_tdpbsud_internal: |
| 601 | C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(TileDP: Inst) || C; |
| 602 | break; |
| 603 | case Intrinsic::x86_tdpbusd_internal: |
| 604 | C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(TileDP: Inst) || C; |
| 605 | break; |
| 606 | case Intrinsic::x86_tdpbuud_internal: |
| 607 | C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(TileDP: Inst) || C; |
| 608 | break; |
| 609 | case Intrinsic::x86_tdpbf16ps_internal: |
| 610 | C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(TileDP: Inst) || C; |
| 611 | break; |
| 612 | case Intrinsic::x86_tileloadd64_internal: |
| 613 | C = lowerTileLoadStore<true>(TileLoadStore: Inst) || C; |
| 614 | break; |
| 615 | case Intrinsic::x86_tilestored64_internal: |
| 616 | C = lowerTileLoadStore<false>(TileLoadStore: Inst) || C; |
| 617 | break; |
| 618 | case Intrinsic::x86_tilezero_internal: |
| 619 | C = lowerTileZero(TileZero: Inst) || C; |
| 620 | break; |
| 621 | default: |
| 622 | llvm_unreachable("invalid amx intrinsics!" ); |
| 623 | } |
| 624 | } |
| 625 | |
| 626 | return C; |
| 627 | } |
| 628 | |
| 629 | namespace { |
| 630 | class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { |
| 631 | public: |
| 632 | static char ID; |
| 633 | |
| 634 | X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {} |
| 635 | |
| 636 | bool runOnFunction(Function &F) override { |
| 637 | if (!X86ScalarizeAMX) |
| 638 | return false; |
| 639 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
| 640 | if (!F.hasFnAttribute(Kind: Attribute::OptimizeNone) && |
| 641 | TM->getOptLevel() != CodeGenOptLevel::None) |
| 642 | return false; |
| 643 | |
| 644 | auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); |
| 645 | auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; |
| 646 | auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); |
| 647 | auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; |
| 648 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
| 649 | |
| 650 | X86LowerAMXIntrinsics LAT(F, DTU, LI); |
| 651 | return LAT.visit(); |
| 652 | } |
| 653 | StringRef getPassName() const override { return "Lower AMX intrinsics" ; } |
| 654 | |
| 655 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 656 | AU.addPreserved<DominatorTreeWrapperPass>(); |
| 657 | AU.addPreserved<LoopInfoWrapperPass>(); |
| 658 | AU.addRequired<TargetPassConfig>(); |
| 659 | } |
| 660 | }; |
| 661 | } // namespace |
| 662 | |
// Display name used when registering the legacy pass below.
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
// Register the legacy pass under DEBUG_TYPE ("lower-amx-intrinsics"). It
// depends on TargetPassConfig, which runOnFunction uses to obtain the
// TargetMachine. The two trailing `false` flags mark the pass as neither
// CFG-only nor an analysis.
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

/// Returns a newly heap-allocated instance of the legacy pass. Ownership
/// passes to the caller, presumably a legacy pass manager.
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
| 674 | |