| 1 | //===-- SPIRVAuxDataHandler.cpp - NonSemantic.AuxData emitter -*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "SPIRVAuxDataHandler.h" |
| 10 | #include "MCTargetDesc/SPIRVMCTargetDesc.h" |
| 11 | #include "SPIRVSubtarget.h" |
| 12 | #include "SPIRVUtils.h" |
| 13 | #include "llvm/CodeGen/AsmPrinter.h" |
| 14 | #include "llvm/IR/Attributes.h" |
| 15 | #include "llvm/IR/Constants.h" |
| 16 | #include "llvm/IR/Function.h" |
| 17 | #include "llvm/IR/GlobalObject.h" |
| 18 | #include "llvm/IR/GlobalVariable.h" |
| 19 | #include "llvm/IR/LLVMContext.h" |
| 20 | #include "llvm/IR/Metadata.h" |
| 21 | #include "llvm/IR/Module.h" |
| 22 | #include "llvm/MC/MCInst.h" |
| 23 | #include "llvm/MC/MCStreamer.h" |
| 24 | #include "llvm/Support/CommandLine.h" |
| 25 | #include "llvm/Support/ErrorHandling.h" |
| 26 | |
| 27 | using namespace llvm; |
| 28 | |
| 29 | static cl::opt<bool> SPVPreserveAuxData( |
| 30 | "spirv-preserve-auxdata" , |
| 31 | cl::desc("Preserve LLVM attributes and metadata as " |
| 32 | "NonSemantic.AuxData ExtInst annotations (requires " |
| 33 | "SPV_KHR_non_semantic_info)" ), |
| 34 | cl::Optional, cl::Hidden, cl::init(Val: false)); |
| 35 | |
| 36 | namespace { |
| 37 | enum AuxDataLinkageType : uint32_t { |
| 38 | AvailableExternally = 0, |
| 39 | }; |
| 40 | |
| 41 | constexpr unsigned NonSemanticAuxDataSet = |
| 42 | static_cast<unsigned>(SPIRV::InstructionSet::NonSemantic_AuxData); |
| 43 | |
| 44 | AttributeSet getGOAttrs(const GlobalObject *GO) { |
| 45 | if (const auto *F = dyn_cast<Function>(Val: GO)) |
| 46 | return F->getAttributes().getFnAttrs(); |
| 47 | return cast<GlobalVariable>(Val: GO)->getAttributes(); |
| 48 | } |
| 49 | } // namespace |
| 50 | |
| 51 | static bool wasAvailableExternally(const GlobalObject *GO) { |
| 52 | if (const auto *F = dyn_cast<Function>(Val: GO)) |
| 53 | return F->hasFnAttribute(SPIRV_WAS_AVAILABLE_EXTERNALLY_ATTR); |
| 54 | return cast<GlobalVariable>(Val: GO)->getAttributes().hasAttribute( |
| 55 | SPIRV_WAS_AVAILABLE_EXTERNALLY_ATTR); |
| 56 | } |
| 57 | |
| 58 | SPIRVAuxDataHandler::SPIRVAuxDataHandler(AsmPrinter &AP, const Module &M) |
| 59 | : Asm(AP), Mod(M) { |
| 60 | for (const GlobalObject &GO : M.global_objects()) |
| 61 | if (wasAvailableExternally(GO: &GO)) |
| 62 | LinkagePreservedGOs.push_back(Elt: &GO); |
| 63 | } |
| 64 | |
| 65 | bool SPIRVAuxDataHandler::hasWork() const { return SPVPreserveAuxData; } |
| 66 | |
| 67 | void SPIRVAuxDataHandler::prepareModuleOutput(const SPIRVSubtarget &ST, |
| 68 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 69 | if (!hasWork()) |
| 70 | return; |
| 71 | if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_non_semantic_info)) { |
| 72 | if (SPVPreserveAuxData) |
| 73 | report_fatal_error(reason: "-spirv-preserve-auxdata requires the " |
| 74 | "SPV_KHR_non_semantic_info extension to be enabled." ); |
| 75 | return; |
| 76 | } |
| 77 | MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_non_semantic_info); |
| 78 | if (!MAI.ExtInstSetMap.count(Val: NonSemanticAuxDataSet)) |
| 79 | MAI.ExtInstSetMap[NonSemanticAuxDataSet] = MAI.getNextIDRegister(); |
| 80 | } |
| 81 | |
| 82 | MCRegister |
| 83 | SPIRVAuxDataHandler::getOrEmitString(StringRef S, |
| 84 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 85 | auto [It, Inserted] = StringRegs.try_emplace(Key: S); |
| 86 | if (!Inserted) |
| 87 | return It->second; |
| 88 | MCRegister Reg = MAI.getNextIDRegister(); |
| 89 | It->second = Reg; |
| 90 | MCInst Inst; |
| 91 | Inst.setOpcode(SPIRV::OpString); |
| 92 | Inst.addOperand(Op: MCOperand::createReg(Reg)); |
| 93 | addStringImm(Str: S, Inst); |
| 94 | emitMCInst(Inst); |
| 95 | return Reg; |
| 96 | } |
| 97 | |
| 98 | void SPIRVAuxDataHandler::collectAttributesFor(const GlobalObject *GO, |
| 99 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 100 | AuxDataOpcode Opcode = isa<Function>(Val: GO) ? FunctionAttributeOpcode |
| 101 | : GlobalVariableAttributeOpcode; |
| 102 | for (const Attribute &A : getGOAttrs(GO)) { |
| 103 | if (A.isStringAttribute() && |
| 104 | A.getKindAsString() == SPIRV_WAS_AVAILABLE_EXTERNALLY_ATTR) |
| 105 | continue; |
| 106 | ExtInstRecord Rec; |
| 107 | Rec.Opcode = Opcode; |
| 108 | Rec.Target = GO; |
| 109 | if (A.isStringAttribute()) { |
| 110 | Rec.Operands.push_back(Elt: {.Reg: getOrEmitString(S: A.getKindAsString(), MAI)}); |
| 111 | StringRef Val = A.getValueAsString(); |
| 112 | if (!Val.empty()) |
| 113 | Rec.Operands.push_back(Elt: {.Reg: getOrEmitString(S: Val, MAI)}); |
| 114 | } else { |
| 115 | Rec.Operands.push_back( |
| 116 | Elt: {.Reg: getOrEmitString(S: StringPool.save(S: A.getAsString()), MAI)}); |
| 117 | } |
| 118 | PendingRecords.push_back(Elt: std::move(Rec)); |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | void SPIRVAuxDataHandler::collectMetadataFor(const GlobalObject *GO, |
| 123 | ArrayRef<StringRef> MDNames, |
| 124 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 125 | SmallVector<std::pair<unsigned, MDNode *>> AllMD; |
| 126 | GO->getAllMetadata(MDs&: AllMD); |
| 127 | if (AllMD.empty()) |
| 128 | return; |
| 129 | AuxDataOpcode Opcode = |
| 130 | isa<Function>(Val: GO) ? FunctionMetadataOpcode : GlobalVariableMetadataOpcode; |
| 131 | // MDString operands become OpStrings; ValueAsMetadata constants (e.g. |
| 132 | // !{i32 5}) become OpConstants emitted at section 10. Any other operand |
| 133 | // kind would need full value translation, so skip the whole node. |
| 134 | auto CollectOperands = |
| 135 | [&](MDNode *MD) -> std::optional<SmallVector<Operand, 4>> { |
| 136 | SmallVector<Operand, 4> Out; |
| 137 | for (const MDOperand &MdOp : MD->operands()) { |
| 138 | Metadata *Md = MdOp.get(); |
| 139 | if (auto *MDStr = dyn_cast_or_null<MDString>(Val: Md)) { |
| 140 | Out.push_back(Elt: {.Reg: getOrEmitString(S: MDStr->getString(), MAI)}); |
| 141 | } else if (auto *VAM = dyn_cast_or_null<ValueAsMetadata>(Val: Md)) { |
| 142 | auto *C = dyn_cast<Constant>(Val: VAM->getValue()); |
| 143 | if (!C || !(isa<ConstantInt>(Val: C) || isa<ConstantFP>(Val: C))) |
| 144 | return std::nullopt; |
| 145 | Out.push_back(Elt: {.Reg: MCRegister(), .Const: C}); |
| 146 | } else { |
| 147 | return std::nullopt; |
| 148 | } |
| 149 | } |
| 150 | return Out; |
| 151 | }; |
| 152 | for (const auto &MD : AllMD) { |
| 153 | if (MD.first == LLVMContext::MD_dbg) |
| 154 | continue; |
| 155 | StringRef MDName = MDNames[MD.first]; |
| 156 | if (MDName == "spirv.Decorations" || MDName == "spirv.ParameterDecorations" ) |
| 157 | continue; |
| 158 | auto Operands = CollectOperands(MD.second); |
| 159 | if (!Operands) |
| 160 | continue; |
| 161 | ExtInstRecord Rec; |
| 162 | Rec.Opcode = Opcode; |
| 163 | Rec.Target = GO; |
| 164 | Rec.Operands.push_back(Elt: {.Reg: getOrEmitString(S: MDName, MAI)}); |
| 165 | Rec.Operands.append(in_start: Operands->begin(), in_end: Operands->end()); |
| 166 | PendingRecords.push_back(Elt: std::move(Rec)); |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | void SPIRVAuxDataHandler::emitAuxDataStrings(SPIRV::ModuleAnalysisInfo &MAI) { |
| 171 | if (!SPVPreserveAuxData) |
| 172 | return; |
| 173 | if (!MAI.getExtInstSetReg(SetNum: NonSemanticAuxDataSet).isValid()) |
| 174 | return; |
| 175 | SmallVector<StringRef> MDNames; |
| 176 | Mod.getContext().getMDKindNames(Result&: MDNames); |
| 177 | for (const GlobalObject &GO : Mod.global_objects()) { |
| 178 | if (GO.isDeclaration()) |
| 179 | continue; |
| 180 | collectAttributesFor(GO: &GO, MAI); |
| 181 | collectMetadataFor(GO: &GO, MDNames, MAI); |
| 182 | } |
| 183 | } |
| 184 | |
| 185 | void SPIRVAuxDataHandler::emitAuxData(SPIRV::ModuleAnalysisInfo &MAI) { |
| 186 | MCRegister ExtSetReg = MAI.getExtInstSetReg(SetNum: NonSemanticAuxDataSet); |
| 187 | if (!ExtSetReg.isValid()) |
| 188 | return; |
| 189 | |
| 190 | MCRegister VoidTypeReg = findOrEmitOpTypeVoid(MAI); |
| 191 | |
| 192 | for (const ExtInstRecord &Rec : PendingRecords) { |
| 193 | MCRegister TargetReg = MAI.getGlobalObjReg(GO: Rec.Target); |
| 194 | if (!TargetReg.isValid()) |
| 195 | continue; |
| 196 | SmallVector<MCRegister, 5> Operands; |
| 197 | Operands.push_back(Elt: TargetReg); |
| 198 | for (const Operand &Op : Rec.Operands) |
| 199 | Operands.push_back(Elt: Op.Const ? emitConstant(C: Op.Const, MAI) : Op.Reg); |
| 200 | emitAuxDataExtInst(Opcode: Rec.Opcode, VoidTypeReg, ExtSetReg, Operands, MAI); |
| 201 | } |
| 202 | |
| 203 | if (LinkagePreservedGOs.empty()) |
| 204 | return; |
| 205 | |
| 206 | MCRegister UInt32TypeReg = findOrEmitOpTypeUInt32(MAI); |
| 207 | MCRegister AEConstReg; |
| 208 | for (const GlobalObject *GO : LinkagePreservedGOs) { |
| 209 | MCRegister TargetReg = MAI.getGlobalObjReg(GO); |
| 210 | if (!TargetReg.isValid()) |
| 211 | continue; |
| 212 | if (!AEConstReg.isValid()) |
| 213 | AEConstReg = |
| 214 | emitOpConstantUInt32(Value: AvailableExternally, UInt32TypeReg, MAI); |
| 215 | emitAuxDataExtInst(Opcode: LinkageOpcode, VoidTypeReg, ExtSetReg, |
| 216 | Operands: {TargetReg, AEConstReg}, MAI); |
| 217 | } |
| 218 | } |
| 219 | |
| 220 | void SPIRVAuxDataHandler::emitAuxDataExtInst(AuxDataOpcode Opcode, |
| 221 | MCRegister VoidTypeReg, |
| 222 | MCRegister ExtSetReg, |
| 223 | ArrayRef<MCRegister> Operands, |
| 224 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 225 | MCInst Inst; |
| 226 | Inst.setOpcode(SPIRV::OpExtInst); |
| 227 | Inst.addOperand(Op: MCOperand::createReg(Reg: MAI.getNextIDRegister())); |
| 228 | Inst.addOperand(Op: MCOperand::createReg(Reg: VoidTypeReg)); |
| 229 | Inst.addOperand(Op: MCOperand::createReg(Reg: ExtSetReg)); |
| 230 | Inst.addOperand(Op: MCOperand::createImm(Val: Opcode)); |
| 231 | for (MCRegister R : Operands) |
| 232 | Inst.addOperand(Op: MCOperand::createReg(Reg: R)); |
| 233 | emitMCInst(Inst); |
| 234 | } |
| 235 | |
| 236 | void SPIRVAuxDataHandler::emitMCInst(MCInst &Inst) { |
| 237 | Asm.OutStreamer->emitInstruction(Inst, STI: Asm.getSubtargetInfo()); |
| 238 | } |
| 239 | |
| 240 | MCRegister |
| 241 | SPIRVAuxDataHandler::findOrEmitOpTypeVoid(SPIRV::ModuleAnalysisInfo &MAI) { |
| 242 | for (const MachineInstr *MI : MAI.getMSInstrs(MSType: SPIRV::MB_TypeConstVars)) |
| 243 | if (MI->getOpcode() == SPIRV::OpTypeVoid) |
| 244 | return MAI.getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg()); |
| 245 | MCRegister Reg = MAI.getNextIDRegister(); |
| 246 | MCInst Inst; |
| 247 | Inst.setOpcode(SPIRV::OpTypeVoid); |
| 248 | Inst.addOperand(Op: MCOperand::createReg(Reg)); |
| 249 | emitMCInst(Inst); |
| 250 | return Reg; |
| 251 | } |
| 252 | |
| 253 | MCRegister |
| 254 | SPIRVAuxDataHandler::findOrEmitOpTypeInt(unsigned BitWidth, |
| 255 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 256 | // SPIR-V OpTypeInt: <width>, <signedness>. Signedness 0 = unsigned, 1 = |
| 257 | // signed; we always emit unsigned. |
| 258 | constexpr int64_t UnsignedSignedness = 0; |
| 259 | for (const MachineInstr *MI : MAI.getMSInstrs(MSType: SPIRV::MB_TypeConstVars)) |
| 260 | if (MI->getOpcode() == SPIRV::OpTypeInt && |
| 261 | MI->getOperand(i: 1).getImm() == static_cast<int64_t>(BitWidth) && |
| 262 | MI->getOperand(i: 2).getImm() == UnsignedSignedness) |
| 263 | return MAI.getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg()); |
| 264 | MCRegister Reg = MAI.getNextIDRegister(); |
| 265 | MCInst Inst; |
| 266 | Inst.setOpcode(SPIRV::OpTypeInt); |
| 267 | Inst.addOperand(Op: MCOperand::createReg(Reg)); |
| 268 | Inst.addOperand(Op: MCOperand::createImm(Val: BitWidth)); |
| 269 | Inst.addOperand(Op: MCOperand::createImm(Val: UnsignedSignedness)); |
| 270 | emitMCInst(Inst); |
| 271 | return Reg; |
| 272 | } |
| 273 | |
| 274 | MCRegister |
| 275 | SPIRVAuxDataHandler::findOrEmitOpTypeUInt32(SPIRV::ModuleAnalysisInfo &MAI) { |
| 276 | return findOrEmitOpTypeInt(BitWidth: 32, MAI); |
| 277 | } |
| 278 | |
| 279 | MCRegister |
| 280 | SPIRVAuxDataHandler::findOrEmitOpTypeFloat(unsigned BitWidth, |
| 281 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 282 | for (const MachineInstr *MI : MAI.getMSInstrs(MSType: SPIRV::MB_TypeConstVars)) |
| 283 | if (MI->getOpcode() == SPIRV::OpTypeFloat && |
| 284 | MI->getOperand(i: 1).getImm() == static_cast<int64_t>(BitWidth)) |
| 285 | return MAI.getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg()); |
| 286 | MCRegister Reg = MAI.getNextIDRegister(); |
| 287 | MCInst Inst; |
| 288 | Inst.setOpcode(SPIRV::OpTypeFloat); |
| 289 | Inst.addOperand(Op: MCOperand::createReg(Reg)); |
| 290 | Inst.addOperand(Op: MCOperand::createImm(Val: BitWidth)); |
| 291 | emitMCInst(Inst); |
| 292 | return Reg; |
| 293 | } |
| 294 | |
| 295 | MCRegister SPIRVAuxDataHandler::emitConstant(const Constant *C, |
| 296 | SPIRV::ModuleAnalysisInfo &MAI) { |
| 297 | auto [It, Inserted] = ConstantRegs.try_emplace(Key: C); |
| 298 | if (!Inserted) |
| 299 | return It->second; |
| 300 | |
| 301 | APInt Bits; |
| 302 | unsigned Opcode; |
| 303 | MCRegister TypeReg; |
| 304 | if (const auto *CI = dyn_cast<ConstantInt>(Val: C)) { |
| 305 | Bits = CI->getValue(); |
| 306 | Opcode = SPIRV::OpConstantI; |
| 307 | TypeReg = findOrEmitOpTypeInt(BitWidth: Bits.getBitWidth(), MAI); |
| 308 | } else { |
| 309 | const auto *CF = cast<ConstantFP>(Val: C); |
| 310 | Bits = CF->getValueAPF().bitcastToAPInt(); |
| 311 | Opcode = SPIRV::OpConstantF; |
| 312 | TypeReg = findOrEmitOpTypeFloat(BitWidth: Bits.getBitWidth(), MAI); |
| 313 | } |
| 314 | |
| 315 | MCRegister Reg = MAI.getNextIDRegister(); |
| 316 | It->second = Reg; |
| 317 | MCInst Inst; |
| 318 | Inst.setOpcode(Opcode); |
| 319 | Inst.addOperand(Op: MCOperand::createReg(Reg)); |
| 320 | Inst.addOperand(Op: MCOperand::createReg(Reg: TypeReg)); |
| 321 | // SPIR-V encodes the literal as ceil(width/32) little-endian 32-bit words. |
| 322 | unsigned NumWords = std::max(a: 1u, b: (Bits.getBitWidth() + 31) / 32); |
| 323 | for (unsigned I = 0; I < NumWords; ++I) |
| 324 | Inst.addOperand(Op: MCOperand::createImm(Val: Bits.extractBitsAsZExtValue( |
| 325 | numBits: std::min(a: 32u, b: Bits.getBitWidth() - I * 32), bitPosition: I * 32))); |
| 326 | // The asm printer needs this hint to render an f16 literal correctly. |
| 327 | if (Opcode == SPIRV::OpConstantF && Bits.getBitWidth() == 16) |
| 328 | Inst.setFlags(SPIRV::INST_PRINTER_WIDTH16); |
| 329 | emitMCInst(Inst); |
| 330 | return Reg; |
| 331 | } |
| 332 | |
| 333 | MCRegister SPIRVAuxDataHandler::emitOpConstantUInt32( |
| 334 | uint32_t Value, MCRegister UInt32TypeReg, SPIRV::ModuleAnalysisInfo &MAI) { |
| 335 | MCRegister Reg = MAI.getNextIDRegister(); |
| 336 | MCInst Inst; |
| 337 | Inst.setOpcode(SPIRV::OpConstantI); |
| 338 | Inst.addOperand(Op: MCOperand::createReg(Reg)); |
| 339 | Inst.addOperand(Op: MCOperand::createReg(Reg: UInt32TypeReg)); |
| 340 | Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<int64_t>(Value))); |
| 341 | emitMCInst(Inst); |
| 342 | return Reg; |
| 343 | } |
| 344 | |