//===- SPIRVPushConstantAccess.cpp - Push Constant Access -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass changes the type of every global in the PushConstant address
// space into a target extension type, and makes all references to such a
// global go through a custom SPIR-V intrinsic.
12//
13// This allows the backend to properly lower the push constant struct type
14// to a fully laid out type, and generate the proper OpAccessChain.
15//
16//===----------------------------------------------------------------------===//
17
#include "SPIRVPushConstantAccess.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ReplaceConstant.h"
27
28#define DEBUG_TYPE "spirv-pushconstant-access"
29using namespace llvm;
30
31static bool replacePushConstantAccesses(Module &M, SPIRVGlobalRegistry *GR) {
32 bool Changed = false;
33 for (GlobalVariable &GV : make_early_inc_range(Range: M.globals())) {
34 if (GV.getAddressSpace() !=
35 storageClassToAddressSpace(SC: SPIRV::StorageClass::PushConstant))
36 continue;
37
38 GV.removeDeadConstantUsers();
39
40 Type *PCType = llvm::TargetExtType::get(
41 Context&: M.getContext(), Name: "spirv.PushConstant", Types: {GV.getValueType()});
42 GlobalVariable *NewGV =
43 new GlobalVariable(M, PCType, GV.isConstant(), GV.getLinkage(),
44 /* initializer= */ nullptr, GV.getName(),
45 /* InsertBefore= */ &GV, GV.getThreadLocalMode(),
46 GV.getAddressSpace(), GV.isExternallyInitialized());
47 NewGV->setVisibility(GV.getVisibility());
48
49 for (User *U : make_early_inc_range(Range: GV.users())) {
50 Instruction *I = cast<Instruction>(Val: U);
51 IRBuilder<> Builder(I);
52 Value *GetPointerCall = Builder.CreateIntrinsic(
53 RetTy: NewGV->getType(), ID: Intrinsic::spv_pushconstant_getpointer, Args: {NewGV});
54 GR->buildAssignPtr(B&: Builder, ElemTy: GV.getValueType(), Arg: GetPointerCall);
55
56 I->replaceUsesOfWith(From: &GV, To: GetPointerCall);
57 }
58
59 GV.eraseFromParent();
60 Changed = true;
61 }
62
63 return Changed;
64}
65
66PreservedAnalyses SPIRVPushConstantAccess::run(Module &M,
67 ModuleAnalysisManager &AM) {
68 const SPIRVSubtarget *ST = TM.getSubtargetImpl();
69 SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
70 return replacePushConstantAccesses(M, GR) ? PreservedAnalyses::none()
71 : PreservedAnalyses::all();
72}
73
74namespace {
75class SPIRVPushConstantAccessLegacy : public ModulePass {
76 SPIRVTargetMachine *TM = nullptr;
77
78public:
79 bool runOnModule(Module &M) override {
80 const SPIRVSubtarget *ST = TM->getSubtargetImpl();
81 SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
82 return replacePushConstantAccesses(M, GR);
83 }
84 StringRef getPassName() const override {
85 return "SPIRV push constant Access";
86 }
87 SPIRVPushConstantAccessLegacy(SPIRVTargetMachine *TM)
88 : ModulePass(ID), TM(TM) {}
89
90 static char ID; // Pass identification.
91};
92char SPIRVPushConstantAccessLegacy::ID = 0;
93} // end anonymous namespace
94
// Register the legacy pass with the pass registry. DEBUG_TYPE
// ("spirv-pushconstant-access") is used as the pass argument name, and the
// description string matches the one returned by getPassName().
INITIALIZE_PASS(SPIRVPushConstantAccessLegacy, DEBUG_TYPE,
                "SPIRV push constant Access", false, false)
97
98ModulePass *
99llvm::createSPIRVPushConstantAccessLegacyPass(SPIRVTargetMachine *TM) {
100 return new SPIRVPushConstantAccessLegacy(TM);
101}
102