1//===-- SPIRVAuxDataHandler.cpp - NonSemantic.AuxData emitter -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "SPIRVAuxDataHandler.h"
10#include "MCTargetDesc/SPIRVMCTargetDesc.h"
11#include "SPIRVSubtarget.h"
12#include "SPIRVUtils.h"
13#include "llvm/CodeGen/AsmPrinter.h"
14#include "llvm/IR/Attributes.h"
15#include "llvm/IR/Constants.h"
16#include "llvm/IR/Function.h"
17#include "llvm/IR/GlobalObject.h"
18#include "llvm/IR/GlobalVariable.h"
19#include "llvm/IR/LLVMContext.h"
20#include "llvm/IR/Metadata.h"
21#include "llvm/IR/Module.h"
22#include "llvm/MC/MCInst.h"
23#include "llvm/MC/MCStreamer.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/ErrorHandling.h"
26
27using namespace llvm;
28
29static cl::opt<bool> SPVPreserveAuxData(
30 "spirv-preserve-auxdata",
31 cl::desc("Preserve LLVM attributes and metadata as "
32 "NonSemantic.AuxData ExtInst annotations (requires "
33 "SPV_KHR_non_semantic_info)"),
34 cl::Optional, cl::Hidden, cl::init(Val: false));
35
36namespace {
37enum AuxDataLinkageType : uint32_t {
38 AvailableExternally = 0,
39};
40
41constexpr unsigned NonSemanticAuxDataSet =
42 static_cast<unsigned>(SPIRV::InstructionSet::NonSemantic_AuxData);
43
44AttributeSet getGOAttrs(const GlobalObject *GO) {
45 if (const auto *F = dyn_cast<Function>(Val: GO))
46 return F->getAttributes().getFnAttrs();
47 return cast<GlobalVariable>(Val: GO)->getAttributes();
48}
49} // namespace
50
51static bool wasAvailableExternally(const GlobalObject *GO) {
52 if (const auto *F = dyn_cast<Function>(Val: GO))
53 return F->hasFnAttribute(SPIRV_WAS_AVAILABLE_EXTERNALLY_ATTR);
54 return cast<GlobalVariable>(Val: GO)->getAttributes().hasAttribute(
55 SPIRV_WAS_AVAILABLE_EXTERNALLY_ATTR);
56}
57
58SPIRVAuxDataHandler::SPIRVAuxDataHandler(AsmPrinter &AP, const Module &M)
59 : Asm(AP), Mod(M) {
60 for (const GlobalObject &GO : M.global_objects())
61 if (wasAvailableExternally(GO: &GO))
62 LinkagePreservedGOs.push_back(Elt: &GO);
63}
64
65bool SPIRVAuxDataHandler::hasWork() const { return SPVPreserveAuxData; }
66
67void SPIRVAuxDataHandler::prepareModuleOutput(const SPIRVSubtarget &ST,
68 SPIRV::ModuleAnalysisInfo &MAI) {
69 if (!hasWork())
70 return;
71 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_non_semantic_info)) {
72 if (SPVPreserveAuxData)
73 report_fatal_error(reason: "-spirv-preserve-auxdata requires the "
74 "SPV_KHR_non_semantic_info extension to be enabled.");
75 return;
76 }
77 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_non_semantic_info);
78 if (!MAI.ExtInstSetMap.count(Val: NonSemanticAuxDataSet))
79 MAI.ExtInstSetMap[NonSemanticAuxDataSet] = MAI.getNextIDRegister();
80}
81
82MCRegister
83SPIRVAuxDataHandler::getOrEmitString(StringRef S,
84 SPIRV::ModuleAnalysisInfo &MAI) {
85 auto [It, Inserted] = StringRegs.try_emplace(Key: S);
86 if (!Inserted)
87 return It->second;
88 MCRegister Reg = MAI.getNextIDRegister();
89 It->second = Reg;
90 MCInst Inst;
91 Inst.setOpcode(SPIRV::OpString);
92 Inst.addOperand(Op: MCOperand::createReg(Reg));
93 addStringImm(Str: S, Inst);
94 emitMCInst(Inst);
95 return Reg;
96}
97
98void SPIRVAuxDataHandler::collectAttributesFor(const GlobalObject *GO,
99 SPIRV::ModuleAnalysisInfo &MAI) {
100 AuxDataOpcode Opcode = isa<Function>(Val: GO) ? FunctionAttributeOpcode
101 : GlobalVariableAttributeOpcode;
102 for (const Attribute &A : getGOAttrs(GO)) {
103 if (A.isStringAttribute() &&
104 A.getKindAsString() == SPIRV_WAS_AVAILABLE_EXTERNALLY_ATTR)
105 continue;
106 ExtInstRecord Rec;
107 Rec.Opcode = Opcode;
108 Rec.Target = GO;
109 if (A.isStringAttribute()) {
110 Rec.Operands.push_back(Elt: {.Reg: getOrEmitString(S: A.getKindAsString(), MAI)});
111 StringRef Val = A.getValueAsString();
112 if (!Val.empty())
113 Rec.Operands.push_back(Elt: {.Reg: getOrEmitString(S: Val, MAI)});
114 } else {
115 Rec.Operands.push_back(
116 Elt: {.Reg: getOrEmitString(S: StringPool.save(S: A.getAsString()), MAI)});
117 }
118 PendingRecords.push_back(Elt: std::move(Rec));
119 }
120}
121
122void SPIRVAuxDataHandler::collectMetadataFor(const GlobalObject *GO,
123 ArrayRef<StringRef> MDNames,
124 SPIRV::ModuleAnalysisInfo &MAI) {
125 SmallVector<std::pair<unsigned, MDNode *>> AllMD;
126 GO->getAllMetadata(MDs&: AllMD);
127 if (AllMD.empty())
128 return;
129 AuxDataOpcode Opcode =
130 isa<Function>(Val: GO) ? FunctionMetadataOpcode : GlobalVariableMetadataOpcode;
131 // MDString operands become OpStrings; ValueAsMetadata constants (e.g.
132 // !{i32 5}) become OpConstants emitted at section 10. Any other operand
133 // kind would need full value translation, so skip the whole node.
134 auto CollectOperands =
135 [&](MDNode *MD) -> std::optional<SmallVector<Operand, 4>> {
136 SmallVector<Operand, 4> Out;
137 for (const MDOperand &MdOp : MD->operands()) {
138 Metadata *Md = MdOp.get();
139 if (auto *MDStr = dyn_cast_or_null<MDString>(Val: Md)) {
140 Out.push_back(Elt: {.Reg: getOrEmitString(S: MDStr->getString(), MAI)});
141 } else if (auto *VAM = dyn_cast_or_null<ValueAsMetadata>(Val: Md)) {
142 auto *C = dyn_cast<Constant>(Val: VAM->getValue());
143 if (!C || !(isa<ConstantInt>(Val: C) || isa<ConstantFP>(Val: C)))
144 return std::nullopt;
145 Out.push_back(Elt: {.Reg: MCRegister(), .Const: C});
146 } else {
147 return std::nullopt;
148 }
149 }
150 return Out;
151 };
152 for (const auto &MD : AllMD) {
153 if (MD.first == LLVMContext::MD_dbg)
154 continue;
155 StringRef MDName = MDNames[MD.first];
156 if (MDName == "spirv.Decorations" || MDName == "spirv.ParameterDecorations")
157 continue;
158 auto Operands = CollectOperands(MD.second);
159 if (!Operands)
160 continue;
161 ExtInstRecord Rec;
162 Rec.Opcode = Opcode;
163 Rec.Target = GO;
164 Rec.Operands.push_back(Elt: {.Reg: getOrEmitString(S: MDName, MAI)});
165 Rec.Operands.append(in_start: Operands->begin(), in_end: Operands->end());
166 PendingRecords.push_back(Elt: std::move(Rec));
167 }
168}
169
170void SPIRVAuxDataHandler::emitAuxDataStrings(SPIRV::ModuleAnalysisInfo &MAI) {
171 if (!SPVPreserveAuxData)
172 return;
173 if (!MAI.getExtInstSetReg(SetNum: NonSemanticAuxDataSet).isValid())
174 return;
175 SmallVector<StringRef> MDNames;
176 Mod.getContext().getMDKindNames(Result&: MDNames);
177 for (const GlobalObject &GO : Mod.global_objects()) {
178 if (GO.isDeclaration())
179 continue;
180 collectAttributesFor(GO: &GO, MAI);
181 collectMetadataFor(GO: &GO, MDNames, MAI);
182 }
183}
184
185void SPIRVAuxDataHandler::emitAuxData(SPIRV::ModuleAnalysisInfo &MAI) {
186 MCRegister ExtSetReg = MAI.getExtInstSetReg(SetNum: NonSemanticAuxDataSet);
187 if (!ExtSetReg.isValid())
188 return;
189
190 MCRegister VoidTypeReg = findOrEmitOpTypeVoid(MAI);
191
192 for (const ExtInstRecord &Rec : PendingRecords) {
193 MCRegister TargetReg = MAI.getGlobalObjReg(GO: Rec.Target);
194 if (!TargetReg.isValid())
195 continue;
196 SmallVector<MCRegister, 5> Operands;
197 Operands.push_back(Elt: TargetReg);
198 for (const Operand &Op : Rec.Operands)
199 Operands.push_back(Elt: Op.Const ? emitConstant(C: Op.Const, MAI) : Op.Reg);
200 emitAuxDataExtInst(Opcode: Rec.Opcode, VoidTypeReg, ExtSetReg, Operands, MAI);
201 }
202
203 if (LinkagePreservedGOs.empty())
204 return;
205
206 MCRegister UInt32TypeReg = findOrEmitOpTypeUInt32(MAI);
207 MCRegister AEConstReg;
208 for (const GlobalObject *GO : LinkagePreservedGOs) {
209 MCRegister TargetReg = MAI.getGlobalObjReg(GO);
210 if (!TargetReg.isValid())
211 continue;
212 if (!AEConstReg.isValid())
213 AEConstReg =
214 emitOpConstantUInt32(Value: AvailableExternally, UInt32TypeReg, MAI);
215 emitAuxDataExtInst(Opcode: LinkageOpcode, VoidTypeReg, ExtSetReg,
216 Operands: {TargetReg, AEConstReg}, MAI);
217 }
218}
219
220void SPIRVAuxDataHandler::emitAuxDataExtInst(AuxDataOpcode Opcode,
221 MCRegister VoidTypeReg,
222 MCRegister ExtSetReg,
223 ArrayRef<MCRegister> Operands,
224 SPIRV::ModuleAnalysisInfo &MAI) {
225 MCInst Inst;
226 Inst.setOpcode(SPIRV::OpExtInst);
227 Inst.addOperand(Op: MCOperand::createReg(Reg: MAI.getNextIDRegister()));
228 Inst.addOperand(Op: MCOperand::createReg(Reg: VoidTypeReg));
229 Inst.addOperand(Op: MCOperand::createReg(Reg: ExtSetReg));
230 Inst.addOperand(Op: MCOperand::createImm(Val: Opcode));
231 for (MCRegister R : Operands)
232 Inst.addOperand(Op: MCOperand::createReg(Reg: R));
233 emitMCInst(Inst);
234}
235
236void SPIRVAuxDataHandler::emitMCInst(MCInst &Inst) {
237 Asm.OutStreamer->emitInstruction(Inst, STI: Asm.getSubtargetInfo());
238}
239
240MCRegister
241SPIRVAuxDataHandler::findOrEmitOpTypeVoid(SPIRV::ModuleAnalysisInfo &MAI) {
242 for (const MachineInstr *MI : MAI.getMSInstrs(MSType: SPIRV::MB_TypeConstVars))
243 if (MI->getOpcode() == SPIRV::OpTypeVoid)
244 return MAI.getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg());
245 MCRegister Reg = MAI.getNextIDRegister();
246 MCInst Inst;
247 Inst.setOpcode(SPIRV::OpTypeVoid);
248 Inst.addOperand(Op: MCOperand::createReg(Reg));
249 emitMCInst(Inst);
250 return Reg;
251}
252
253MCRegister
254SPIRVAuxDataHandler::findOrEmitOpTypeInt(unsigned BitWidth,
255 SPIRV::ModuleAnalysisInfo &MAI) {
256 // SPIR-V OpTypeInt: <width>, <signedness>. Signedness 0 = unsigned, 1 =
257 // signed; we always emit unsigned.
258 constexpr int64_t UnsignedSignedness = 0;
259 for (const MachineInstr *MI : MAI.getMSInstrs(MSType: SPIRV::MB_TypeConstVars))
260 if (MI->getOpcode() == SPIRV::OpTypeInt &&
261 MI->getOperand(i: 1).getImm() == static_cast<int64_t>(BitWidth) &&
262 MI->getOperand(i: 2).getImm() == UnsignedSignedness)
263 return MAI.getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg());
264 MCRegister Reg = MAI.getNextIDRegister();
265 MCInst Inst;
266 Inst.setOpcode(SPIRV::OpTypeInt);
267 Inst.addOperand(Op: MCOperand::createReg(Reg));
268 Inst.addOperand(Op: MCOperand::createImm(Val: BitWidth));
269 Inst.addOperand(Op: MCOperand::createImm(Val: UnsignedSignedness));
270 emitMCInst(Inst);
271 return Reg;
272}
273
274MCRegister
275SPIRVAuxDataHandler::findOrEmitOpTypeUInt32(SPIRV::ModuleAnalysisInfo &MAI) {
276 return findOrEmitOpTypeInt(BitWidth: 32, MAI);
277}
278
279MCRegister
280SPIRVAuxDataHandler::findOrEmitOpTypeFloat(unsigned BitWidth,
281 SPIRV::ModuleAnalysisInfo &MAI) {
282 for (const MachineInstr *MI : MAI.getMSInstrs(MSType: SPIRV::MB_TypeConstVars))
283 if (MI->getOpcode() == SPIRV::OpTypeFloat &&
284 MI->getOperand(i: 1).getImm() == static_cast<int64_t>(BitWidth))
285 return MAI.getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg());
286 MCRegister Reg = MAI.getNextIDRegister();
287 MCInst Inst;
288 Inst.setOpcode(SPIRV::OpTypeFloat);
289 Inst.addOperand(Op: MCOperand::createReg(Reg));
290 Inst.addOperand(Op: MCOperand::createImm(Val: BitWidth));
291 emitMCInst(Inst);
292 return Reg;
293}
294
295MCRegister SPIRVAuxDataHandler::emitConstant(const Constant *C,
296 SPIRV::ModuleAnalysisInfo &MAI) {
297 auto [It, Inserted] = ConstantRegs.try_emplace(Key: C);
298 if (!Inserted)
299 return It->second;
300
301 APInt Bits;
302 unsigned Opcode;
303 MCRegister TypeReg;
304 if (const auto *CI = dyn_cast<ConstantInt>(Val: C)) {
305 Bits = CI->getValue();
306 Opcode = SPIRV::OpConstantI;
307 TypeReg = findOrEmitOpTypeInt(BitWidth: Bits.getBitWidth(), MAI);
308 } else {
309 const auto *CF = cast<ConstantFP>(Val: C);
310 Bits = CF->getValueAPF().bitcastToAPInt();
311 Opcode = SPIRV::OpConstantF;
312 TypeReg = findOrEmitOpTypeFloat(BitWidth: Bits.getBitWidth(), MAI);
313 }
314
315 MCRegister Reg = MAI.getNextIDRegister();
316 It->second = Reg;
317 MCInst Inst;
318 Inst.setOpcode(Opcode);
319 Inst.addOperand(Op: MCOperand::createReg(Reg));
320 Inst.addOperand(Op: MCOperand::createReg(Reg: TypeReg));
321 // SPIR-V encodes the literal as ceil(width/32) little-endian 32-bit words.
322 unsigned NumWords = std::max(a: 1u, b: (Bits.getBitWidth() + 31) / 32);
323 for (unsigned I = 0; I < NumWords; ++I)
324 Inst.addOperand(Op: MCOperand::createImm(Val: Bits.extractBitsAsZExtValue(
325 numBits: std::min(a: 32u, b: Bits.getBitWidth() - I * 32), bitPosition: I * 32)));
326 // The asm printer needs this hint to render an f16 literal correctly.
327 if (Opcode == SPIRV::OpConstantF && Bits.getBitWidth() == 16)
328 Inst.setFlags(SPIRV::INST_PRINTER_WIDTH16);
329 emitMCInst(Inst);
330 return Reg;
331}
332
333MCRegister SPIRVAuxDataHandler::emitOpConstantUInt32(
334 uint32_t Value, MCRegister UInt32TypeReg, SPIRV::ModuleAnalysisInfo &MAI) {
335 MCRegister Reg = MAI.getNextIDRegister();
336 MCInst Inst;
337 Inst.setOpcode(SPIRV::OpConstantI);
338 Inst.addOperand(Op: MCOperand::createReg(Reg));
339 Inst.addOperand(Op: MCOperand::createReg(Reg: UInt32TypeReg));
340 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<int64_t>(Value)));
341 emitMCInst(Inst);
342 return Reg;
343}
344