1 | //=== AMDGPUPrintfRuntimeBinding.cpp - OpenCL printf implementation -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // \file |
9 | // |
// This pass binds printf calls to a kernel argument pointer that will be bound
// to a buffer later by the runtime.
12 | // |
13 | // This pass traverses the functions in the module and converts |
14 | // each call to printf to a sequence of operations that |
15 | // store the following into the printf buffer: |
16 | // - format string (passed as a module's metadata unique ID) |
17 | // - bitwise copies of printf arguments |
18 | // The backend passes will need to store metadata in the kernel |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | #include "AMDGPU.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/Analysis/ValueTracking.h" |
24 | #include "llvm/IR/DiagnosticInfo.h" |
25 | #include "llvm/IR/Dominators.h" |
26 | #include "llvm/IR/IRBuilder.h" |
27 | #include "llvm/IR/Instructions.h" |
28 | #include "llvm/IR/Module.h" |
29 | #include "llvm/InitializePasses.h" |
30 | #include "llvm/Support/DataExtractor.h" |
31 | #include "llvm/TargetParser/Triple.h" |
32 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
33 | |
34 | using namespace llvm; |
35 | |
36 | #define DEBUG_TYPE "printfToRuntime" |
37 | enum { DWORD_ALIGN = 4 }; |
38 | |
39 | namespace { |
40 | class AMDGPUPrintfRuntimeBinding final : public ModulePass { |
41 | |
42 | public: |
43 | static char ID; |
44 | |
45 | explicit AMDGPUPrintfRuntimeBinding() : ModulePass(ID) {} |
46 | |
47 | private: |
48 | bool runOnModule(Module &M) override; |
49 | }; |
50 | |
51 | class AMDGPUPrintfRuntimeBindingImpl { |
52 | public: |
53 | AMDGPUPrintfRuntimeBindingImpl() = default; |
54 | bool run(Module &M); |
55 | |
56 | private: |
57 | void getConversionSpecifiers(SmallVectorImpl<char> &OpConvSpecifiers, |
58 | StringRef fmt, size_t num_ops) const; |
59 | |
60 | bool lowerPrintfForGpu(Module &M); |
61 | |
62 | const DataLayout *TD; |
63 | SmallVector<CallInst *, 32> Printfs; |
64 | }; |
65 | } // namespace |
66 | |
67 | char AMDGPUPrintfRuntimeBinding::ID = 0; |
68 | |
69 | INITIALIZE_PASS_BEGIN(AMDGPUPrintfRuntimeBinding, |
70 | "amdgpu-printf-runtime-binding" , "AMDGPU Printf lowering" , |
71 | false, false) |
72 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) |
73 | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
74 | INITIALIZE_PASS_END(AMDGPUPrintfRuntimeBinding, "amdgpu-printf-runtime-binding" , |
75 | "AMDGPU Printf lowering" , false, false) |
76 | |
77 | char &llvm::AMDGPUPrintfRuntimeBindingID = AMDGPUPrintfRuntimeBinding::ID; |
78 | |
79 | ModulePass *llvm::createAMDGPUPrintfRuntimeBinding() { |
80 | return new AMDGPUPrintfRuntimeBinding(); |
81 | } |
82 | |
83 | void AMDGPUPrintfRuntimeBindingImpl::getConversionSpecifiers( |
84 | SmallVectorImpl<char> &OpConvSpecifiers, StringRef Fmt, |
85 | size_t NumOps) const { |
86 | // not all format characters are collected. |
87 | // At this time the format characters of interest |
88 | // are %p and %s, which use to know if we |
89 | // are either storing a literal string or a |
90 | // pointer to the printf buffer. |
91 | static const char ConvSpecifiers[] = "cdieEfgGaosuxXp" ; |
92 | size_t CurFmtSpecifierIdx = 0; |
93 | size_t PrevFmtSpecifierIdx = 0; |
94 | |
95 | while ((CurFmtSpecifierIdx = Fmt.find_first_of( |
96 | Chars: ConvSpecifiers, From: CurFmtSpecifierIdx)) != StringRef::npos) { |
97 | bool ArgDump = false; |
98 | StringRef CurFmt = Fmt.substr(Start: PrevFmtSpecifierIdx, |
99 | N: CurFmtSpecifierIdx - PrevFmtSpecifierIdx); |
100 | size_t pTag = CurFmt.find_last_of(C: '%'); |
101 | if (pTag != StringRef::npos) { |
102 | ArgDump = true; |
103 | while (pTag && CurFmt[--pTag] == '%') { |
104 | ArgDump = !ArgDump; |
105 | } |
106 | } |
107 | |
108 | if (ArgDump) |
109 | OpConvSpecifiers.push_back(Elt: Fmt[CurFmtSpecifierIdx]); |
110 | |
111 | PrevFmtSpecifierIdx = ++CurFmtSpecifierIdx; |
112 | } |
113 | } |
114 | |
115 | static bool shouldPrintAsStr(char Specifier, Type *OpType) { |
116 | return Specifier == 's' && isa<PointerType>(Val: OpType); |
117 | } |
118 | |
119 | constexpr StringLiteral NonLiteralStr("???" ); |
120 | static_assert(NonLiteralStr.size() == 3); |
121 | |
122 | static StringRef getAsConstantStr(Value *V) { |
123 | StringRef S; |
124 | if (!getConstantStringInfo(V, Str&: S)) |
125 | S = NonLiteralStr; |
126 | |
127 | return S; |
128 | } |
129 | |
130 | static void diagnoseInvalidFormatString(const CallBase *CI) { |
131 | CI->getContext().diagnose(DI: DiagnosticInfoUnsupported( |
132 | *CI->getParent()->getParent(), |
133 | "printf format string must be a trivially resolved constant string " |
134 | "global variable" , |
135 | CI->getDebugLoc())); |
136 | } |
137 | |
138 | bool AMDGPUPrintfRuntimeBindingImpl::lowerPrintfForGpu(Module &M) { |
139 | LLVMContext &Ctx = M.getContext(); |
140 | IRBuilder<> Builder(Ctx); |
141 | Type *I32Ty = Type::getInt32Ty(C&: Ctx); |
142 | |
143 | // Instead of creating global variables, the printf format strings are |
144 | // extracted and passed as metadata. This avoids polluting llvm's symbol |
145 | // tables in this module. Metadata is going to be extracted by the backend |
146 | // passes and inserted into the OpenCL binary as appropriate. |
147 | NamedMDNode *metaD = M.getOrInsertNamedMetadata(Name: "llvm.printf.fmts" ); |
148 | unsigned UniqID = metaD->getNumOperands(); |
149 | |
150 | for (auto *CI : Printfs) { |
151 | unsigned NumOps = CI->arg_size(); |
152 | |
153 | SmallString<16> OpConvSpecifiers; |
154 | Value *Op = CI->getArgOperand(i: 0); |
155 | |
156 | StringRef FormatStr; |
157 | if (!getConstantStringInfo(V: Op, Str&: FormatStr)) { |
158 | Value *Stripped = Op->stripPointerCasts(); |
159 | if (!isa<UndefValue>(Val: Stripped) && !isa<ConstantPointerNull>(Val: Stripped)) |
160 | diagnoseInvalidFormatString(CI); |
161 | continue; |
162 | } |
163 | |
164 | // We need this call to ascertain that we are printing a string or a |
165 | // pointer. It takes out the specifiers and fills up the first arg. |
166 | getConversionSpecifiers(OpConvSpecifiers, Fmt: FormatStr, NumOps: NumOps - 1); |
167 | |
168 | // Add metadata for the string |
169 | std::string AStreamHolder; |
170 | raw_string_ostream Sizes(AStreamHolder); |
171 | int Sum = DWORD_ALIGN; |
172 | Sizes << CI->arg_size() - 1; |
173 | Sizes << ':'; |
174 | for (unsigned ArgCount = 1; |
175 | ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); |
176 | ArgCount++) { |
177 | Value *Arg = CI->getArgOperand(i: ArgCount); |
178 | Type *ArgType = Arg->getType(); |
179 | unsigned ArgSize = TD->getTypeAllocSize(Ty: ArgType); |
180 | // |
181 | // ArgSize by design should be a multiple of DWORD_ALIGN, |
182 | // expand the arguments that do not follow this rule. |
183 | // |
184 | if (ArgSize % DWORD_ALIGN != 0) { |
185 | Type *ResType = Type::getInt32Ty(C&: Ctx); |
186 | if (auto *VecType = dyn_cast<VectorType>(Val: ArgType)) |
187 | ResType = VectorType::get(ElementType: ResType, EC: VecType->getElementCount()); |
188 | Builder.SetInsertPoint(CI); |
189 | Builder.SetCurrentDebugLocation(CI->getDebugLoc()); |
190 | |
191 | if (ArgType->isFloatingPointTy()) { |
192 | Arg = Builder.CreateBitCast( |
193 | V: Arg, |
194 | DestTy: IntegerType::getIntNTy(C&: Ctx, N: ArgType->getPrimitiveSizeInBits())); |
195 | } |
196 | |
197 | if (OpConvSpecifiers[ArgCount - 1] == 'x' || |
198 | OpConvSpecifiers[ArgCount - 1] == 'X' || |
199 | OpConvSpecifiers[ArgCount - 1] == 'u' || |
200 | OpConvSpecifiers[ArgCount - 1] == 'o') |
201 | Arg = Builder.CreateZExt(V: Arg, DestTy: ResType); |
202 | else |
203 | Arg = Builder.CreateSExt(V: Arg, DestTy: ResType); |
204 | ArgType = Arg->getType(); |
205 | ArgSize = TD->getTypeAllocSize(Ty: ArgType); |
206 | CI->setOperand(i_nocapture: ArgCount, Val_nocapture: Arg); |
207 | } |
208 | if (OpConvSpecifiers[ArgCount - 1] == 'f') { |
209 | ConstantFP *FpCons = dyn_cast<ConstantFP>(Val: Arg); |
210 | if (FpCons) |
211 | ArgSize = 4; |
212 | else { |
213 | FPExtInst *FpExt = dyn_cast<FPExtInst>(Val: Arg); |
214 | if (FpExt && FpExt->getType()->isDoubleTy() && |
215 | FpExt->getOperand(i_nocapture: 0)->getType()->isFloatTy()) |
216 | ArgSize = 4; |
217 | } |
218 | } |
219 | if (shouldPrintAsStr(Specifier: OpConvSpecifiers[ArgCount - 1], OpType: ArgType)) |
220 | ArgSize = alignTo(Value: getAsConstantStr(V: Arg).size() + 1, Align: 4); |
221 | |
222 | LLVM_DEBUG(dbgs() << "Printf ArgSize (in buffer) = " << ArgSize |
223 | << " for type: " << *ArgType << '\n'); |
224 | Sizes << ArgSize << ':'; |
225 | Sum += ArgSize; |
226 | } |
227 | LLVM_DEBUG(dbgs() << "Printf format string in source = " << FormatStr |
228 | << '\n'); |
229 | for (char C : FormatStr) { |
230 | // Rest of the C escape sequences (e.g. \') are handled correctly |
231 | // by the MDParser |
232 | switch (C) { |
233 | case '\a': |
234 | Sizes << "\\a" ; |
235 | break; |
236 | case '\b': |
237 | Sizes << "\\b" ; |
238 | break; |
239 | case '\f': |
240 | Sizes << "\\f" ; |
241 | break; |
242 | case '\n': |
243 | Sizes << "\\n" ; |
244 | break; |
245 | case '\r': |
246 | Sizes << "\\r" ; |
247 | break; |
248 | case '\v': |
249 | Sizes << "\\v" ; |
250 | break; |
251 | case ':': |
252 | // ':' cannot be scanned by Flex, as it is defined as a delimiter |
253 | // Replace it with it's octal representation \72 |
254 | Sizes << "\\72" ; |
255 | break; |
256 | default: |
257 | Sizes << C; |
258 | break; |
259 | } |
260 | } |
261 | |
262 | // Insert the printf_alloc call |
263 | Builder.SetInsertPoint(CI); |
264 | Builder.SetCurrentDebugLocation(CI->getDebugLoc()); |
265 | |
266 | AttributeList Attr = AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex, |
267 | Kinds: Attribute::NoUnwind); |
268 | |
269 | Type *SizetTy = Type::getInt32Ty(C&: Ctx); |
270 | |
271 | Type *Tys_alloc[1] = {SizetTy}; |
272 | Type *I8Ty = Type::getInt8Ty(C&: Ctx); |
273 | Type *I8Ptr = PointerType::get(C&: Ctx, AddressSpace: 1); |
274 | FunctionType *FTy_alloc = FunctionType::get(Result: I8Ptr, Params: Tys_alloc, isVarArg: false); |
275 | FunctionCallee PrintfAllocFn = |
276 | M.getOrInsertFunction(Name: StringRef("__printf_alloc" ), T: FTy_alloc, AttributeList: Attr); |
277 | |
278 | LLVM_DEBUG(dbgs() << "Printf metadata = " << Sizes.str() << '\n'); |
279 | std::string fmtstr = itostr(X: ++UniqID) + ":" + Sizes.str(); |
280 | MDString *fmtStrArray = MDString::get(Context&: Ctx, Str: fmtstr); |
281 | |
282 | MDNode *myMD = MDNode::get(Context&: Ctx, MDs: fmtStrArray); |
283 | metaD->addOperand(M: myMD); |
284 | Value *sumC = ConstantInt::get(Ty: SizetTy, V: Sum, IsSigned: false); |
285 | SmallVector<Value *, 1> alloc_args; |
286 | alloc_args.push_back(Elt: sumC); |
287 | CallInst *pcall = CallInst::Create(Func: PrintfAllocFn, Args: alloc_args, |
288 | NameStr: "printf_alloc_fn" , InsertBefore: CI->getIterator()); |
289 | |
290 | // |
291 | // Insert code to split basicblock with a |
292 | // piece of hammock code. |
293 | // basicblock splits after buffer overflow check |
294 | // |
295 | ConstantPointerNull *zeroIntPtr = |
296 | ConstantPointerNull::get(T: PointerType::get(C&: Ctx, AddressSpace: 1)); |
297 | auto *cmp = cast<ICmpInst>(Val: Builder.CreateICmpNE(LHS: pcall, RHS: zeroIntPtr, Name: "" )); |
298 | if (!CI->use_empty()) { |
299 | Value *result = |
300 | Builder.CreateSExt(V: Builder.CreateNot(V: cmp), DestTy: I32Ty, Name: "printf_res" ); |
301 | CI->replaceAllUsesWith(V: result); |
302 | } |
303 | SplitBlock(Old: CI->getParent(), SplitPt: cmp); |
304 | Instruction *Brnch = |
305 | SplitBlockAndInsertIfThen(Cond: cmp, SplitBefore: cmp->getNextNode(), Unreachable: false); |
306 | BasicBlock::iterator BrnchPoint = Brnch->getIterator(); |
307 | |
308 | Builder.SetInsertPoint(Brnch); |
309 | |
310 | // store unique printf id in the buffer |
311 | // |
312 | GetElementPtrInst *BufferIdx = GetElementPtrInst::Create( |
313 | PointeeType: I8Ty, Ptr: pcall, IdxList: ConstantInt::get(Context&: Ctx, V: APInt(32, 0)), NameStr: "PrintBuffID" , |
314 | InsertBefore: BrnchPoint); |
315 | |
316 | Type *idPointer = PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::GLOBAL_ADDRESS); |
317 | Value *id_gep_cast = |
318 | new BitCastInst(BufferIdx, idPointer, "PrintBuffIdCast" , BrnchPoint); |
319 | |
320 | new StoreInst(ConstantInt::get(Ty: I32Ty, V: UniqID), id_gep_cast, BrnchPoint); |
321 | |
322 | // 1st 4 bytes hold the printf_id |
323 | // the following GEP is the buffer pointer |
324 | BufferIdx = GetElementPtrInst::Create(PointeeType: I8Ty, Ptr: pcall, |
325 | IdxList: ConstantInt::get(Context&: Ctx, V: APInt(32, 4)), |
326 | NameStr: "PrintBuffGep" , InsertBefore: BrnchPoint); |
327 | |
328 | Type *Int32Ty = Type::getInt32Ty(C&: Ctx); |
329 | for (unsigned ArgCount = 1; |
330 | ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); |
331 | ArgCount++) { |
332 | Value *Arg = CI->getArgOperand(i: ArgCount); |
333 | Type *ArgType = Arg->getType(); |
334 | SmallVector<Value *, 32> WhatToStore; |
335 | if (ArgType->isFPOrFPVectorTy() && !isa<VectorType>(Val: ArgType)) { |
336 | if (OpConvSpecifiers[ArgCount - 1] == 'f') { |
337 | if (auto *FpCons = dyn_cast<ConstantFP>(Val: Arg)) { |
338 | APFloat Val(FpCons->getValueAPF()); |
339 | bool Lost = false; |
340 | Val.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmNearestTiesToEven, |
341 | losesInfo: &Lost); |
342 | Arg = ConstantFP::get(Context&: Ctx, V: Val); |
343 | } else if (auto *FpExt = dyn_cast<FPExtInst>(Val: Arg)) { |
344 | if (FpExt->getType()->isDoubleTy() && |
345 | FpExt->getOperand(i_nocapture: 0)->getType()->isFloatTy()) { |
346 | Arg = FpExt->getOperand(i_nocapture: 0); |
347 | } |
348 | } |
349 | } |
350 | WhatToStore.push_back(Elt: Arg); |
351 | } else if (isa<PointerType>(Val: ArgType)) { |
352 | if (shouldPrintAsStr(Specifier: OpConvSpecifiers[ArgCount - 1], OpType: ArgType)) { |
353 | StringRef S = getAsConstantStr(V: Arg); |
354 | if (!S.empty()) { |
355 | const uint64_t ReadSize = 4; |
356 | |
357 | DataExtractor (S, /*IsLittleEndian=*/true, 8); |
358 | DataExtractor::Cursor Offset(0); |
359 | while (Offset && Offset.tell() < S.size()) { |
360 | uint64_t ReadNow = std::min(a: ReadSize, b: S.size() - Offset.tell()); |
361 | uint64_t ReadBytes = 0; |
362 | switch (ReadNow) { |
363 | default: llvm_unreachable("min(4, X) > 4?" ); |
364 | case 1: |
365 | ReadBytes = Extractor.getU8(C&: Offset); |
366 | break; |
367 | case 2: |
368 | ReadBytes = Extractor.getU16(C&: Offset); |
369 | break; |
370 | case 3: |
371 | ReadBytes = Extractor.getU24(C&: Offset); |
372 | break; |
373 | case 4: |
374 | ReadBytes = Extractor.getU32(C&: Offset); |
375 | break; |
376 | } |
377 | |
378 | cantFail(Err: Offset.takeError(), |
379 | Msg: "failed to read bytes from constant array" ); |
380 | |
381 | APInt IntVal(8 * ReadSize, ReadBytes); |
382 | |
383 | // TODO: Should not bothering aligning up. |
384 | if (ReadNow < ReadSize) |
385 | IntVal = IntVal.zext(width: 8 * ReadSize); |
386 | |
387 | Type *IntTy = Type::getIntNTy(C&: Ctx, N: IntVal.getBitWidth()); |
388 | WhatToStore.push_back(Elt: ConstantInt::get(Ty: IntTy, V: IntVal)); |
389 | } |
390 | } else { |
391 | // Empty string, give a hint to RT it is no NULL |
392 | Value *ANumV = ConstantInt::get(Ty: Int32Ty, V: 0xFFFFFF00, IsSigned: false); |
393 | WhatToStore.push_back(Elt: ANumV); |
394 | } |
395 | } else { |
396 | WhatToStore.push_back(Elt: Arg); |
397 | } |
398 | } else { |
399 | WhatToStore.push_back(Elt: Arg); |
400 | } |
401 | for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) { |
402 | Value *TheBtCast = WhatToStore[I]; |
403 | unsigned ArgSize = TD->getTypeAllocSize(Ty: TheBtCast->getType()); |
404 | StoreInst *StBuff = new StoreInst(TheBtCast, BufferIdx, BrnchPoint); |
405 | LLVM_DEBUG(dbgs() << "inserting store to printf buffer:\n" |
406 | << *StBuff << '\n'); |
407 | (void)StBuff; |
408 | if (I + 1 == E && ArgCount + 1 == CI->arg_size()) |
409 | break; |
410 | BufferIdx = GetElementPtrInst::Create( |
411 | PointeeType: I8Ty, Ptr: BufferIdx, IdxList: {ConstantInt::get(Ty: I32Ty, V: ArgSize)}, |
412 | NameStr: "PrintBuffNextPtr" , InsertBefore: BrnchPoint); |
413 | LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:\n" |
414 | << *BufferIdx << '\n'); |
415 | } |
416 | } |
417 | } |
418 | |
419 | // erase the printf calls |
420 | for (auto *CI : Printfs) |
421 | CI->eraseFromParent(); |
422 | |
423 | Printfs.clear(); |
424 | return true; |
425 | } |
426 | |
427 | bool AMDGPUPrintfRuntimeBindingImpl::run(Module &M) { |
428 | Triple TT(M.getTargetTriple()); |
429 | if (TT.getArch() == Triple::r600) |
430 | return false; |
431 | |
432 | auto *PrintfFunction = M.getFunction(Name: "printf" ); |
433 | if (!PrintfFunction || !PrintfFunction->isDeclaration() || |
434 | M.getModuleFlag(Key: "openmp" )) |
435 | return false; |
436 | |
437 | for (auto &U : PrintfFunction->uses()) { |
438 | if (auto *CI = dyn_cast<CallInst>(Val: U.getUser())) { |
439 | if (CI->isCallee(U: &U) && !CI->isNoBuiltin()) |
440 | Printfs.push_back(Elt: CI); |
441 | } |
442 | } |
443 | |
444 | if (Printfs.empty()) |
445 | return false; |
446 | |
447 | TD = &M.getDataLayout(); |
448 | |
449 | return lowerPrintfForGpu(M); |
450 | } |
451 | |
452 | bool AMDGPUPrintfRuntimeBinding::runOnModule(Module &M) { |
453 | return AMDGPUPrintfRuntimeBindingImpl().run(M); |
454 | } |
455 | |
456 | PreservedAnalyses |
457 | AMDGPUPrintfRuntimeBindingPass::run(Module &M, ModuleAnalysisManager &AM) { |
458 | bool Changed = AMDGPUPrintfRuntimeBindingImpl().run(M); |
459 | return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
460 | } |
461 | |