//===- SPIRVPushConstantAccess.cpp - Push Constant Access -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass changes the types of all the globals in the PushConstant
10// address space into a target extension type, and makes all references
// to these globals go through a custom SPIR-V intrinsic.
12//
13// This allows the backend to properly lower the push constant struct type
14// to a fully laid out type, and generate the proper OpAccessChain.
15//
16//===----------------------------------------------------------------------===//
17
18#include "SPIRVPushConstantAccess.h"
19#include "SPIRV.h"
20#include "SPIRVSubtarget.h"
21#include "SPIRVTargetMachine.h"
22#include "SPIRVUtils.h"
23#include "llvm/Frontend/HLSL/CBuffer.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/IntrinsicsSPIRV.h"
26#include "llvm/IR/Module.h"
27#include "llvm/IR/ReplaceConstant.h"
28
29#define DEBUG_TYPE "spirv-pushconstant-access"
30using namespace llvm;
31
32static bool replacePushConstantAccesses(Module &M, SPIRVGlobalRegistry *GR) {
33 bool Changed = false;
34 for (GlobalVariable &GV : make_early_inc_range(Range: M.globals())) {
35 if (GV.getAddressSpace() !=
36 storageClassToAddressSpace(SC: SPIRV::StorageClass::PushConstant))
37 continue;
38
39 convertUsersOfConstantsToInstructions(
40 Consts: llvm::SmallVector<Constant *, 1>(1, &GV));
41
42 Type *PCType = llvm::TargetExtType::get(
43 Context&: M.getContext(), Name: "spirv.PushConstant", Types: {GV.getValueType()});
44 GlobalVariable *NewGV =
45 new GlobalVariable(M, PCType, GV.isConstant(), GV.getLinkage(),
46 /* initializer= */ nullptr, GV.getName(),
47 /* InsertBefore= */ &GV, GV.getThreadLocalMode(),
48 GV.getAddressSpace(), GV.isExternallyInitialized());
49 NewGV->setVisibility(GV.getVisibility());
50
51 for (User *U : make_early_inc_range(Range: GV.users())) {
52 Instruction *I = cast<Instruction>(Val: U);
53 IRBuilder<> Builder(I);
54 Value *GetPointerCall = Builder.CreateIntrinsic(
55 RetTy: NewGV->getType(), ID: Intrinsic::spv_pushconstant_getpointer, Args: {NewGV});
56 GR->buildAssignPtr(B&: Builder, ElemTy: GV.getValueType(), Arg: GetPointerCall);
57
58 I->replaceUsesOfWith(From: &GV, To: GetPointerCall);
59 }
60
61 GV.eraseFromParent();
62 Changed = true;
63 }
64
65 return Changed;
66}
67
68PreservedAnalyses SPIRVPushConstantAccess::run(Module &M,
69 ModuleAnalysisManager &AM) {
70 const SPIRVSubtarget *ST = TM.getSubtargetImpl();
71 SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
72 return replacePushConstantAccesses(M, GR) ? PreservedAnalyses::none()
73 : PreservedAnalyses::all();
74}
75
76namespace {
77class SPIRVPushConstantAccessLegacy : public ModulePass {
78 SPIRVTargetMachine *TM = nullptr;
79
80public:
81 bool runOnModule(Module &M) override {
82 const SPIRVSubtarget *ST = TM->getSubtargetImpl();
83 SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
84 return replacePushConstantAccesses(M, GR);
85 }
86 StringRef getPassName() const override {
87 return "SPIRV push constant Access";
88 }
89 SPIRVPushConstantAccessLegacy(SPIRVTargetMachine *TM)
90 : ModulePass(ID), TM(TM) {}
91
92 static char ID; // Pass identification.
93};
94char SPIRVPushConstantAccessLegacy::ID = 0;
95} // end anonymous namespace
96
97INITIALIZE_PASS(SPIRVPushConstantAccessLegacy, DEBUG_TYPE,
98 "SPIRV push constant Access", false, false)
99
100ModulePass *
101llvm::createSPIRVPushConstantAccessLegacyPass(SPIRVTargetMachine *TM) {
102 return new SPIRVPushConstantAccessLegacy(TM);
103}
104