1 | //===- AMDGPULibCalls.cpp -------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// \file |
10 | /// This file does AMD library function optimizations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "AMDGPU.h" |
15 | #include "AMDGPULibFunc.h" |
16 | #include "llvm/Analysis/AssumptionCache.h" |
17 | #include "llvm/Analysis/TargetLibraryInfo.h" |
18 | #include "llvm/Analysis/ValueTracking.h" |
19 | #include "llvm/IR/AttributeMask.h" |
20 | #include "llvm/IR/Dominators.h" |
21 | #include "llvm/IR/IRBuilder.h" |
22 | #include "llvm/IR/MDBuilder.h" |
23 | #include "llvm/IR/PatternMatch.h" |
24 | #include <cmath> |
25 | |
26 | #define DEBUG_TYPE "amdgpu-simplifylib" |
27 | |
28 | using namespace llvm; |
29 | using namespace llvm::PatternMatch; |
30 | |
// Command line flag enabling pre-link mode: library functions may only be
// declared (not yet defined), so missing prototypes can be created on demand.
static cl::opt<bool> EnablePreLink("amdgpu-prelink",
  cl::desc("Enable pre-link mode optimizations"),
  cl::init(false),
  cl::Hidden);

// -mllvm -amdgpu-use-native[=f1,f2,...]: substitute the listed math library
// functions (or all of them, when the value is empty or "all") with their
// faster, lower-accuracy native_* counterparts.
static cl::list<std::string> UseNative("amdgpu-use-native",
  cl::desc("Comma separated list of functions to replace with native, or all"),
  cl::CommaSeparated, cl::ValueOptional,
  cl::Hidden);

// Math constants used by the table-driven constant folding below.
#define MATH_PI numbers::pi
#define MATH_E numbers::e
#define MATH_SQRT2 numbers::sqrt2
#define MATH_SQRT1_2 numbers::inv_sqrt2
45 | |
namespace llvm {

/// Performs AMD device math-library call optimizations: table-driven
/// constant folding, pow/rootn/sincos simplification, native-function
/// substitution, and replacement of known libcalls with LLVM intrinsics.
class AMDGPULibCalls {
private:
  // Analyses cached per function in initFunction().
  const TargetLibraryInfo *TLInfo = nullptr;
  AssumptionCache *AC = nullptr;
  DominatorTree *DT = nullptr;

  using FuncInfo = llvm::AMDGPULibFunc;

  // Mirrors the function's "unsafe-fp-math" attribute; set in initFunction().
  bool UnsafeFPMath = false;

  // -fuse-native.
  bool AllNative = false;

  bool useNativeFunc(const StringRef F) const;

  // Return a pointer (pointer expr) to the function if function definition with
  // "FuncName" exists. It may create a new function prototype in pre-link mode.
  FunctionCallee getFunction(Module *M, const FuncInfo &fInfo);

  // Decode a mangled library function name into \p FInfo; false on failure.
  bool parseFunctionName(const StringRef &FMangledName, FuncInfo &FInfo);

  // Table-driven constant folding of calls with known constant arguments.
  bool TDOFold(CallInst *CI, const FuncInfo &FInfo);

  /* Specialized optimizations */

  // pow/powr/pown
  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // rootn
  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // -fuse-native for sincos
  bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);

  // evaluate calls if calls' arguments are constants.
  bool evaluateScalarMathFunc(const FuncInfo &FInfo, double &Res0, double &Res1,
                              Constant *copr0, Constant *copr1);
  bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);

  /// Insert a value to sincos function \p Fsincos. Returns (value of sin, value
  /// of cos, sincos call).
  std::tuple<Value *, Value *, Value *> insertSinCos(Value *Arg,
                                                     FastMathFlags FMF,
                                                     IRBuilder<> &B,
                                                     FunctionCallee Fsincos);

  // sin/cos
  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // __read_pipe/__write_pipe
  bool fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                            const FuncInfo &FInfo);

  // Get a scalar native builtin single argument FP function
  FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);

  /// Substitute a call to a known libcall with an intrinsic call. If \p
  /// AllowMinSize is true, allow the replacement in a minsize function.
  bool shouldReplaceLibcallWithIntrinsic(const CallInst *CI,
                                         bool AllowMinSizeF32 = false,
                                         bool AllowF64 = false,
                                         bool AllowStrictFP = false);
  void replaceLibCallWithSimpleIntrinsic(IRBuilder<> &B, CallInst *CI,
                                         Intrinsic::ID IntrID);

  bool tryReplaceLibcallWithSimpleIntrinsic(IRBuilder<> &B, CallInst *CI,
                                            Intrinsic::ID IntrID,
                                            bool AllowMinSizeF32 = false,
                                            bool AllowF64 = false,
                                            bool AllowStrictFP = false);

protected:
  bool isUnsafeMath(const FPMathOperator *FPOp) const;
  bool isUnsafeFiniteOnlyMath(const FPMathOperator *FPOp) const;

  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

  // Replace all uses of call \p I with \p With and erase the call.
  static void replaceCall(Instruction *I, Value *With) {
    I->replaceAllUsesWith(With);
    I->eraseFromParent();
  }

  static void replaceCall(FPMathOperator *I, Value *With) {
    replaceCall(cast<Instruction>(I), With);
  }

public:
  AMDGPULibCalls() = default;

  // Try to fold/simplify a library call; returns true if the IR changed.
  bool fold(CallInst *CI);

  void initFunction(Function &F, FunctionAnalysisManager &FAM);
  void initNativeFuncs();

  // Replace a normal math function call with that native version
  bool useNative(CallInst *CI);
};

} // end namespace llvm
147 | |
148 | template <typename IRB> |
149 | static CallInst *CreateCallEx(IRB &B, FunctionCallee Callee, Value *Arg, |
150 | const Twine &Name = "" ) { |
151 | CallInst *R = B.CreateCall(Callee, Arg, Name); |
152 | if (Function *F = dyn_cast<Function>(Val: Callee.getCallee())) |
153 | R->setCallingConv(F->getCallingConv()); |
154 | return R; |
155 | } |
156 | |
157 | template <typename IRB> |
158 | static CallInst *CreateCallEx2(IRB &B, FunctionCallee Callee, Value *Arg1, |
159 | Value *Arg2, const Twine &Name = "" ) { |
160 | CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name); |
161 | if (Function *F = dyn_cast<Function>(Val: Callee.getCallee())) |
162 | R->setCallingConv(F->getCallingConv()); |
163 | return R; |
164 | } |
165 | |
166 | static FunctionType *getPownType(FunctionType *FT) { |
167 | Type *PowNExpTy = Type::getInt32Ty(C&: FT->getContext()); |
168 | if (VectorType *VecTy = dyn_cast<VectorType>(Val: FT->getReturnType())) |
169 | PowNExpTy = VectorType::get(ElementType: PowNExpTy, EC: VecTy->getElementCount()); |
170 | |
171 | return FunctionType::get(Result: FT->getReturnType(), |
172 | Params: {FT->getParamType(i: 0), PowNExpTy}, isVarArg: false); |
173 | } |
174 | |
// Data structures for table-driven optimizations.
// FuncTbl works for both f32 and f64 functions with 1 input argument

// One {result, input} pair: a call with a constant argument equal to `input`
// folds to the constant `result`.
struct TableEntry {
  double result;
  double input;
};

/* a list of {result, input} */
static const TableEntry tbl_acos[] = {
  {MATH_PI / 2.0, 0.0},
  {MATH_PI / 2.0, -0.0},
  {0.0, 1.0},
  {MATH_PI, -1.0}
};
static const TableEntry tbl_acosh[] = {
  {0.0, 1.0}
};
static const TableEntry tbl_acospi[] = {
  {0.5, 0.0},
  {0.5, -0.0},
  {0.0, 1.0},
  {1.0, -1.0}
};
static const TableEntry tbl_asin[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 2.0, 1.0},
  {-MATH_PI / 2.0, -1.0}
};
static const TableEntry tbl_asinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_asinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.5, 1.0},
  {-0.5, -1.0}
};
static const TableEntry tbl_atan[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 4.0, 1.0},
  {-MATH_PI / 4.0, -1.0}
};
static const TableEntry tbl_atanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_atanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.25, 1.0},
  {-0.25, -1.0}
};
static const TableEntry tbl_cbrt[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {1.0, 1.0},
  {-1.0, -1.0},
};
static const TableEntry tbl_cos[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cosh[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cospi[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erfc[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erf[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_exp[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {MATH_E, 1.0}
};
static const TableEntry tbl_exp2[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {2.0, 1.0}
};
static const TableEntry tbl_exp10[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {10.0, 1.0}
};
static const TableEntry tbl_expm1[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_log[] = {
  {0.0, 1.0},
  {1.0, MATH_E}
};
static const TableEntry tbl_log2[] = {
  {0.0, 1.0},
  {1.0, 2.0}
};
static const TableEntry tbl_log10[] = {
  {0.0, 1.0},
  {1.0, 10.0}
};
static const TableEntry tbl_rsqrt[] = {
  {1.0, 1.0},
  {MATH_SQRT1_2, 2.0}
};
static const TableEntry tbl_sin[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sqrt[] = {
  {0.0, 0.0},
  {1.0, 1.0},
  {MATH_SQRT2, 2.0}
};
static const TableEntry tbl_tan[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
// tgamma(n) == (n-1)! at small positive integers.
static const TableEntry tbl_tgamma[] = {
  {1.0, 1.0},
  {1.0, 2.0},
  {2.0, 3.0},
  {6.0, 4.0}
};
327 | |
328 | static bool HasNative(AMDGPULibFunc::EFuncId id) { |
329 | switch(id) { |
330 | case AMDGPULibFunc::EI_DIVIDE: |
331 | case AMDGPULibFunc::EI_COS: |
332 | case AMDGPULibFunc::EI_EXP: |
333 | case AMDGPULibFunc::EI_EXP2: |
334 | case AMDGPULibFunc::EI_EXP10: |
335 | case AMDGPULibFunc::EI_LOG: |
336 | case AMDGPULibFunc::EI_LOG2: |
337 | case AMDGPULibFunc::EI_LOG10: |
338 | case AMDGPULibFunc::EI_POWR: |
339 | case AMDGPULibFunc::EI_RECIP: |
340 | case AMDGPULibFunc::EI_RSQRT: |
341 | case AMDGPULibFunc::EI_SIN: |
342 | case AMDGPULibFunc::EI_SINCOS: |
343 | case AMDGPULibFunc::EI_SQRT: |
344 | case AMDGPULibFunc::EI_TAN: |
345 | return true; |
346 | default:; |
347 | } |
348 | return false; |
349 | } |
350 | |
351 | using TableRef = ArrayRef<TableEntry>; |
352 | |
353 | static TableRef getOptTable(AMDGPULibFunc::EFuncId id) { |
354 | switch(id) { |
355 | case AMDGPULibFunc::EI_ACOS: return TableRef(tbl_acos); |
356 | case AMDGPULibFunc::EI_ACOSH: return TableRef(tbl_acosh); |
357 | case AMDGPULibFunc::EI_ACOSPI: return TableRef(tbl_acospi); |
358 | case AMDGPULibFunc::EI_ASIN: return TableRef(tbl_asin); |
359 | case AMDGPULibFunc::EI_ASINH: return TableRef(tbl_asinh); |
360 | case AMDGPULibFunc::EI_ASINPI: return TableRef(tbl_asinpi); |
361 | case AMDGPULibFunc::EI_ATAN: return TableRef(tbl_atan); |
362 | case AMDGPULibFunc::EI_ATANH: return TableRef(tbl_atanh); |
363 | case AMDGPULibFunc::EI_ATANPI: return TableRef(tbl_atanpi); |
364 | case AMDGPULibFunc::EI_CBRT: return TableRef(tbl_cbrt); |
365 | case AMDGPULibFunc::EI_NCOS: |
366 | case AMDGPULibFunc::EI_COS: return TableRef(tbl_cos); |
367 | case AMDGPULibFunc::EI_COSH: return TableRef(tbl_cosh); |
368 | case AMDGPULibFunc::EI_COSPI: return TableRef(tbl_cospi); |
369 | case AMDGPULibFunc::EI_ERFC: return TableRef(tbl_erfc); |
370 | case AMDGPULibFunc::EI_ERF: return TableRef(tbl_erf); |
371 | case AMDGPULibFunc::EI_EXP: return TableRef(tbl_exp); |
372 | case AMDGPULibFunc::EI_NEXP2: |
373 | case AMDGPULibFunc::EI_EXP2: return TableRef(tbl_exp2); |
374 | case AMDGPULibFunc::EI_EXP10: return TableRef(tbl_exp10); |
375 | case AMDGPULibFunc::EI_EXPM1: return TableRef(tbl_expm1); |
376 | case AMDGPULibFunc::EI_LOG: return TableRef(tbl_log); |
377 | case AMDGPULibFunc::EI_NLOG2: |
378 | case AMDGPULibFunc::EI_LOG2: return TableRef(tbl_log2); |
379 | case AMDGPULibFunc::EI_LOG10: return TableRef(tbl_log10); |
380 | case AMDGPULibFunc::EI_NRSQRT: |
381 | case AMDGPULibFunc::EI_RSQRT: return TableRef(tbl_rsqrt); |
382 | case AMDGPULibFunc::EI_NSIN: |
383 | case AMDGPULibFunc::EI_SIN: return TableRef(tbl_sin); |
384 | case AMDGPULibFunc::EI_SINH: return TableRef(tbl_sinh); |
385 | case AMDGPULibFunc::EI_SINPI: return TableRef(tbl_sinpi); |
386 | case AMDGPULibFunc::EI_NSQRT: |
387 | case AMDGPULibFunc::EI_SQRT: return TableRef(tbl_sqrt); |
388 | case AMDGPULibFunc::EI_TAN: return TableRef(tbl_tan); |
389 | case AMDGPULibFunc::EI_TANH: return TableRef(tbl_tanh); |
390 | case AMDGPULibFunc::EI_TANPI: return TableRef(tbl_tanpi); |
391 | case AMDGPULibFunc::EI_TGAMMA: return TableRef(tbl_tgamma); |
392 | default:; |
393 | } |
394 | return TableRef(); |
395 | } |
396 | |
397 | static inline int getVecSize(const AMDGPULibFunc& FInfo) { |
398 | return FInfo.getLeads()[0].VectorSize; |
399 | } |
400 | |
401 | static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) { |
402 | return (AMDGPULibFunc::EType)FInfo.getLeads()[0].ArgType; |
403 | } |
404 | |
405 | FunctionCallee AMDGPULibCalls::getFunction(Module *M, const FuncInfo &fInfo) { |
406 | // If we are doing PreLinkOpt, the function is external. So it is safe to |
407 | // use getOrInsertFunction() at this stage. |
408 | |
409 | return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo) |
410 | : AMDGPULibFunc::getFunction(M, fInfo); |
411 | } |
412 | |
// Decode an OCML-style mangled library function name into \p FInfo.
// Returns false if the name is not a recognized library function.
bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
                                       FuncInfo &FInfo) {
  return AMDGPULibFunc::parse(FMangledName, FInfo);
}
417 | |
418 | bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const { |
419 | return UnsafeFPMath || FPOp->isFast(); |
420 | } |
421 | |
422 | bool AMDGPULibCalls::isUnsafeFiniteOnlyMath(const FPMathOperator *FPOp) const { |
423 | return UnsafeFPMath || |
424 | (FPOp->hasApproxFunc() && FPOp->hasNoNaNs() && FPOp->hasNoInfs()); |
425 | } |
426 | |
// Whether a call may be constant folded using the host's higher-precision
// double arithmetic (currently gated on unsafe math only).
bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
    const FPMathOperator *FPOp) const {
  // TODO: Refine to approxFunc or contract
  return isUnsafeMath(FPOp);
}
432 | |
// Cache per-function state and the analyses used by the folds below.
void AMDGPULibCalls::initFunction(Function &F, FunctionAnalysisManager &FAM) {
  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
  AC = &FAM.getResult<AssumptionAnalysis>(F);
  TLInfo = &FAM.getResult<TargetLibraryAnalysis>(F);
  // The dominator tree is optional: only use it if already computed.
  DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
}
439 | |
440 | bool AMDGPULibCalls::useNativeFunc(const StringRef F) const { |
441 | return AllNative || llvm::is_contained(Range&: UseNative, Element: F); |
442 | } |
443 | |
444 | void AMDGPULibCalls::initNativeFuncs() { |
445 | AllNative = useNativeFunc(F: "all" ) || |
446 | (UseNative.getNumOccurrences() && UseNative.size() == 1 && |
447 | UseNative.begin()->empty()); |
448 | } |
449 | |
// Replace a sincos() call with separate native_sin and native_cos calls when
// the user requested native versions of both sin and cos. The sin value
// replaces the call's result and the cos value is stored through the second
// (pointer) argument, matching sincos' contract.
bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
  bool native_sin = useNativeFunc("sin");
  bool native_cos = useNativeFunc("cos");

  if (native_sin && native_cos) {
    Module *M = aCI->getModule();
    Value *opr0 = aCI->getArgOperand(0);

    AMDGPULibFunc nf;
    // Keep the element type and vector width of the original sincos call.
    nf.getLeads()[0].ArgType = FInfo.getLeads()[0].ArgType;
    nf.getLeads()[0].VectorSize = FInfo.getLeads()[0].VectorSize;

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_SIN);
    FunctionCallee sinExpr = getFunction(M, nf);

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_COS);
    FunctionCallee cosExpr = getFunction(M, nf);
    // Only transform if both native functions are available.
    if (sinExpr && cosExpr) {
      Value *sinval =
          CallInst::Create(sinExpr, opr0, "splitsin", aCI->getIterator());
      Value *cosval =
          CallInst::Create(cosExpr, opr0, "splitcos", aCI->getIterator());
      new StoreInst(cosval, aCI->getArgOperand(1), aCI->getIterator());

      DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                          << " with native version of sin/cos");

      replaceCall(aCI, sinval);
      return true;
    }
  }
  return false;
}
485 | |
486 | bool AMDGPULibCalls::useNative(CallInst *aCI) { |
487 | Function *Callee = aCI->getCalledFunction(); |
488 | if (!Callee || aCI->isNoBuiltin()) |
489 | return false; |
490 | |
491 | FuncInfo FInfo; |
492 | if (!parseFunctionName(FMangledName: Callee->getName(), FInfo) || !FInfo.isMangled() || |
493 | FInfo.getPrefix() != AMDGPULibFunc::NOPFX || |
494 | getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(id: FInfo.getId()) || |
495 | !(AllNative || useNativeFunc(F: FInfo.getName()))) { |
496 | return false; |
497 | } |
498 | |
499 | if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS) |
500 | return sincosUseNative(aCI, FInfo); |
501 | |
502 | FInfo.setPrefix(AMDGPULibFunc::NATIVE); |
503 | FunctionCallee F = getFunction(M: aCI->getModule(), fInfo: FInfo); |
504 | if (!F) |
505 | return false; |
506 | |
507 | aCI->setCalledFunction(F); |
508 | DEBUG_WITH_TYPE("usenative" , dbgs() << "<useNative> replace " << *aCI |
509 | << " with native version" ); |
510 | return true; |
511 | } |
512 | |
// Clang emits call of __read_pipe_2 or __read_pipe_4 for OpenCL read_pipe
// builtin, with appended type size and alignment arguments, where 2 or 4
// indicates the original number of arguments. The library has optimized version
// of __read_pipe_2/__read_pipe_4 when the type size and alignment has the same
// power of 2 value. This function transforms __read_pipe_2 to __read_pipe_2_N
// for such cases where N is the size in bytes of the type (N = 1, 2, 4, 8, ...,
// 128). The same for __read_pipe_4, write_pipe_2, and write_pipe_4.
bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                                          const FuncInfo &FInfo) {
  auto *Callee = CI->getCalledFunction();
  if (!Callee->isDeclaration())
    return false;

  assert(Callee->hasName() && "Invalid read_pipe/write_pipe function");
  auto *M = Callee->getParent();
  std::string Name = std::string(Callee->getName());
  auto NumArg = CI->arg_size();
  // _2 variants carry 4 args, _4 variants carry 6 (original arguments plus
  // the appended size and alignment).
  if (NumArg != 4 && NumArg != 6)
    return false;
  ConstantInt *PacketSize =
      dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 2));
  ConstantInt *PacketAlign =
      dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 1));
  if (!PacketSize || !PacketAlign)
    return false;

  unsigned Size = PacketSize->getZExtValue();
  Align Alignment = PacketAlign->getAlignValue();
  // Specialized variants only exist when the packet is naturally aligned,
  // i.e. alignment equals the (power-of-two) size.
  if (Alignment != Size)
    return false;

  // The packet pointer is the third argument from the end.
  unsigned PtrArgLoc = CI->arg_size() - 3;
  Value *PtrArg = CI->getArgOperand(PtrArgLoc);
  Type *PtrTy = PtrArg->getType();

  // Build the specialized function's type: the leading arguments plus the
  // packet pointer, dropping the trailing size/alignment arguments.
  SmallVector<llvm::Type *, 6> ArgTys;
  for (unsigned I = 0; I != PtrArgLoc; ++I)
    ArgTys.push_back(CI->getArgOperand(I)->getType());
  ArgTys.push_back(PtrTy);

  Name = Name + "_" + std::to_string(Size);
  auto *FTy = FunctionType::get(Callee->getReturnType(),
                                ArrayRef<Type *>(ArgTys), false);
  AMDGPULibFunc NewLibFunc(Name, FTy);
  FunctionCallee F = AMDGPULibFunc::getOrInsertFunction(M, NewLibFunc);
  if (!F)
    return false;

  SmallVector<Value *, 6> Args;
  for (unsigned I = 0; I != PtrArgLoc; ++I)
    Args.push_back(CI->getArgOperand(I));
  Args.push_back(PtrArg);

  auto *NCI = B.CreateCall(F, Args);
  // Preserve the original call-site attributes on the replacement call.
  NCI->setAttributes(CI->getAttributes());
  CI->replaceAllUsesWith(NCI);
  CI->dropAllReferences();
  CI->eraseFromParent();

  return true;
}
574 | |
// Return true if \p V is known to hold only whole-number FP values (so that,
// e.g., pow(x, V) can be rewritten as pown). Conservatively returns false
// when this cannot be proven.
static bool isKnownIntegral(const Value *V, const DataLayout &DL,
                            FastMathFlags FMF) {
  if (isa<PoisonValue>(V))
    return true;
  if (isa<UndefValue>(V))
    return false;

  if (const ConstantFP *CF = dyn_cast<ConstantFP>(V))
    return CF->getValueAPF().isInteger();

  // For constant vectors every element must be a whole-number FP constant;
  // poison elements are allowed.
  auto *VFVTy = dyn_cast<FixedVectorType>(V->getType());
  const Constant *CV = dyn_cast<Constant>(V);
  if (VFVTy && CV) {
    unsigned NumElts = VFVTy->getNumElements();
    for (unsigned i = 0; i != NumElts; ++i) {
      Constant *Elt = CV->getAggregateElement(i);
      if (!Elt)
        return false;
      if (isa<PoisonValue>(Elt))
        continue;

      const ConstantFP *CFP = dyn_cast<ConstantFP>(Elt);
      if (!CFP || !CFP->getValue().isInteger())
        return false;
    }

    return true;
  }

  // Non-constant values: recognize instructions whose result is integral
  // by construction (int-to-FP casts and rounding intrinsics), as long as
  // they cannot produce inf/nan.
  const Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  switch (I->getOpcode()) {
  case Instruction::SIToFP:
  case Instruction::UIToFP:
    // TODO: Could check nofpclass(inf) on incoming argument
    if (FMF.noInfs())
      return true;

    // Need to check int size cannot produce infinity, which computeKnownFPClass
    // knows how to do already.
    return isKnownNeverInfinity(I, SimplifyQuery(DL));
  case Instruction::Call: {
    const CallInst *CI = cast<CallInst>(I);
    switch (CI->getIntrinsicID()) {
    case Intrinsic::trunc:
    case Intrinsic::floor:
    case Intrinsic::ceil:
    case Intrinsic::rint:
    case Intrinsic::nearbyint:
    case Intrinsic::round:
    case Intrinsic::roundeven:
      // Rounding intrinsics yield integers for finite inputs, but pass
      // inf/nan through; require the flags or a proof to exclude those.
      return (FMF.noInfs() && FMF.noNaNs()) ||
             isKnownNeverInfOrNaN(I, SimplifyQuery(DL));
    default:
      break;
    }

    break;
  }
  default:
    break;
  }

  return false;
}
642 | |
// This function returns false if no change; return true otherwise.
//
// Top-level dispatch: try table-driven constant folding first, then
// constant evaluation under unsafe math, then per-function specialized
// folds (intrinsic replacement, pow-family rewrites, sincos, pipes).
bool AMDGPULibCalls::fold(CallInst *CI) {
  Function *Callee = CI->getCalledFunction();
  // Ignore indirect calls.
  if (!Callee || Callee->isIntrinsic() || CI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  if (!parseFunctionName(Callee->getName(), FInfo))
    return false;

  // Further check the number of arguments to see if they match.
  // TODO: Check calling convention matches too
  if (!FInfo.isCompatibleSignature(*Callee->getParent(), CI->getFunctionType()))
    return false;

  LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');

  if (TDOFold(CI, FInfo))
    return true;

  IRBuilder<> B(CI);
  if (CI->isStrictFP())
    B.setIsFPConstrained(true);

  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
    // Under unsafe-math, evaluate calls if possible.
    // According to Brian Sumner, we can do this for all f32 function calls
    // using host's double function calls.
    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
      return true;

    // Copy fast flags from the original call.
    FastMathFlags FMF = FPOp->getFastMathFlags();
    B.setFastMathFlags(FMF);

    // Specialized optimizations for each function call.
    //
    // TODO: Handle native functions
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_EXP:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::exp,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_EXP2:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::exp2,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_LOG:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::log,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_LOG2:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::log2,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_LOG10:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::log10,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_FMIN:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::minnum,
                                                  true, true);
    case AMDGPULibFunc::EI_FMAX:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::maxnum,
                                                  true, true);
    case AMDGPULibFunc::EI_FMA:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::fma, true,
                                                  true);
    case AMDGPULibFunc::EI_MAD:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::fmuladd,
                                                  true, true);
    case AMDGPULibFunc::EI_FABS:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::fabs, true,
                                                  true, true);
    case AMDGPULibFunc::EI_COPYSIGN:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::copysign,
                                                  true, true, true);
    case AMDGPULibFunc::EI_FLOOR:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::floor, true,
                                                  true);
    case AMDGPULibFunc::EI_CEIL:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::ceil, true,
                                                  true);
    case AMDGPULibFunc::EI_TRUNC:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::trunc, true,
                                                  true);
    case AMDGPULibFunc::EI_RINT:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::rint, true,
                                                  true);
    case AMDGPULibFunc::EI_ROUND:
      return tryReplaceLibcallWithSimpleIntrinsic(B, CI, Intrinsic::round, true,
                                                  true);
    case AMDGPULibFunc::EI_LDEXP: {
      if (!shouldReplaceLibcallWithIntrinsic(CI, true, true))
        return false;

      // llvm.ldexp with a vector first operand requires a vector exponent;
      // splat a scalar exponent to match.
      Value *Arg1 = CI->getArgOperand(1);
      if (VectorType *VecTy = dyn_cast<VectorType>(CI->getType());
          VecTy && !isa<VectorType>(Arg1->getType())) {
        Value *SplatArg1 = B.CreateVectorSplat(VecTy->getElementCount(), Arg1);
        CI->setArgOperand(1, SplatArg1);
      }

      CI->setCalledFunction(Intrinsic::getOrInsertDeclaration(
          CI->getModule(), Intrinsic::ldexp,
          {CI->getType(), CI->getArgOperand(1)->getType()}));
      return true;
    }
    case AMDGPULibFunc::EI_POW: {
      Module *M = Callee->getParent();
      AMDGPULibFunc PowrInfo(AMDGPULibFunc::EI_POWR, FInfo);
      FunctionCallee PowrFunc = getFunction(M, PowrInfo);
      CallInst *Call = cast<CallInst>(FPOp);

      // pow(x, y) -> powr(x, y) for x >= -0.0
      // TODO: Account for flags on current call
      if (PowrFunc &&
          cannotBeOrderedLessThanZero(
              FPOp->getOperand(0),
              SimplifyQuery(M->getDataLayout(), TLInfo, DT, AC, Call))) {
        Call->setCalledFunction(PowrFunc);
        return fold_pow(FPOp, B, PowrInfo) || true;
      }

      // pow(x, y) -> pown(x, y) for known integral y
      if (isKnownIntegral(FPOp->getOperand(1), M->getDataLayout(),
                          FPOp->getFastMathFlags())) {
        FunctionType *PownType = getPownType(CI->getFunctionType());
        AMDGPULibFunc PownInfo(AMDGPULibFunc::EI_POWN, PownType, true);
        FunctionCallee PownFunc = getFunction(M, PownInfo);
        if (PownFunc) {
          // TODO: If the incoming integral value is an sitofp/uitofp, it won't
          // fold out without a known range. We can probably take the source
          // value directly.
          Value *CastedArg =
              B.CreateFPToSI(FPOp->getOperand(1), PownType->getParamType(1));
          // Have to drop any nofpclass attributes on the original call site.
          Call->removeParamAttrs(
              1, AttributeFuncs::typeIncompatible(CastedArg->getType(),
                                                  Call->getParamAttributes(1)));
          Call->setCalledFunction(PownFunc);
          Call->setArgOperand(1, CastedArg);
          return fold_pow(FPOp, B, PownInfo) || true;
        }
      }

      return fold_pow(FPOp, B, FInfo);
    }
    case AMDGPULibFunc::EI_POWR:
    case AMDGPULibFunc::EI_POWN:
      return fold_pow(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_ROOTN:
      return fold_rootn(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_SQRT:
      // TODO: Allow with strictfp + constrained intrinsic
      return tryReplaceLibcallWithSimpleIntrinsic(
          B, CI, Intrinsic::sqrt, true, true, /*AllowStrictFP=*/false);
    case AMDGPULibFunc::EI_COS:
    case AMDGPULibFunc::EI_SIN:
      return fold_sincos(FPOp, B, FInfo);
    default:
      break;
    }
  } else {
    // Specialized optimizations for each function call
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_READ_PIPE_2:
    case AMDGPULibFunc::EI_READ_PIPE_4:
    case AMDGPULibFunc::EI_WRITE_PIPE_2:
    case AMDGPULibFunc::EI_WRITE_PIPE_4:
      return fold_read_write_pipe(CI, B, FInfo);
    default:
      break;
    }
  }

  return false;
}
827 | |
828 | bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) { |
829 | // Table-Driven optimization |
830 | const TableRef tr = getOptTable(id: FInfo.getId()); |
831 | if (tr.empty()) |
832 | return false; |
833 | |
834 | int const sz = (int)tr.size(); |
835 | Value *opr0 = CI->getArgOperand(i: 0); |
836 | |
837 | if (getVecSize(FInfo) > 1) { |
838 | if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Val: opr0)) { |
839 | SmallVector<double, 0> DVal; |
840 | for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) { |
841 | ConstantFP *eltval = dyn_cast<ConstantFP>( |
842 | Val: CV->getElementAsConstant(i: (unsigned)eltNo)); |
843 | assert(eltval && "Non-FP arguments in math function!" ); |
844 | bool found = false; |
845 | for (int i=0; i < sz; ++i) { |
846 | if (eltval->isExactlyValue(V: tr[i].input)) { |
847 | DVal.push_back(Elt: tr[i].result); |
848 | found = true; |
849 | break; |
850 | } |
851 | } |
852 | if (!found) { |
853 | // This vector constants not handled yet. |
854 | return false; |
855 | } |
856 | } |
857 | LLVMContext &context = CI->getParent()->getParent()->getContext(); |
858 | Constant *nval; |
859 | if (getArgType(FInfo) == AMDGPULibFunc::F32) { |
860 | SmallVector<float, 0> FVal; |
861 | for (double D : DVal) |
862 | FVal.push_back(Elt: (float)D); |
863 | ArrayRef<float> tmp(FVal); |
864 | nval = ConstantDataVector::get(Context&: context, Elts: tmp); |
865 | } else { // F64 |
866 | ArrayRef<double> tmp(DVal); |
867 | nval = ConstantDataVector::get(Context&: context, Elts: tmp); |
868 | } |
869 | LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n" ); |
870 | replaceCall(I: CI, With: nval); |
871 | return true; |
872 | } |
873 | } else { |
874 | // Scalar version |
875 | if (ConstantFP *CF = dyn_cast<ConstantFP>(Val: opr0)) { |
876 | for (int i = 0; i < sz; ++i) { |
877 | if (CF->isExactlyValue(V: tr[i].input)) { |
878 | Value *nval = ConstantFP::get(Ty: CF->getType(), V: tr[i].result); |
879 | LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n" ); |
880 | replaceCall(I: CI, With: nval); |
881 | return true; |
882 | } |
883 | } |
884 | } |
885 | } |
886 | |
887 | return false; |
888 | } |
889 | |
namespace llvm {
// Host-side helper used while constant-folding pow()-family calls: compute
// log2 of a double. Prefer the C99/POSIX ::log2 when the platform feature
// macros advertise it; otherwise derive it from the natural log and ln(2).
static double log2(double V) {
#if _XOPEN_SOURCE >= 600 || defined(_ISOC99_SOURCE) || _POSIX_C_SOURCE >= 200112L
  return ::log2(x: V);
#else
  return log(V) / numbers::ln2;
#endif
}
} // namespace llvm
899 | |
// Fold pow/powr/pown calls. Small constant exponents (0, 1, 2, -1, +-0.5, and
// |n| <= 12) are expanded directly; otherwise, under unsafe finite-only math,
// the call is rewritten as exp2(y * log2(x)) with a sign fixup for pow/pown of
// a possibly-negative base.
bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                              const FuncInfo &FInfo) {
  assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
          FInfo.getId() == AMDGPULibFunc::EI_POWR ||
          FInfo.getId() == AMDGPULibFunc::EI_POWN) &&
         "fold_pow: encounter a wrong function call" );

  Module *M = B.GetInsertBlock()->getModule();
  Type *eltType = FPOp->getType()->getScalarType();
  Value *opr0 = FPOp->getOperand(i: 0);
  Value *opr1 = FPOp->getOperand(i: 1);

  // Capture the exponent if it is a (splat) FP or integer constant.
  const APFloat *CF = nullptr;
  const APInt *CINT = nullptr;
  if (!match(V: opr1, P: m_APFloatAllowPoison(Res&: CF)))
    match(V: opr1, P: m_APIntAllowPoison(Res&: CINT));

  // 0x1111111 means that we don't do anything for this call.
  int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);

  if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0)) {
    // pow/powr/pown(x, 0) == 1
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n" );
    Constant *cnval = ConstantFP::get(Ty: eltType, V: 1.0);
    if (getVecSize(FInfo) > 1) {
      cnval = ConstantDataVector::getSplat(NumElts: getVecSize(FInfo), Elt: cnval);
    }
    replaceCall(I: FPOp, With: cnval);
    return true;
  }
  if ((CF && CF->isExactlyValue(V: 1.0)) || (CINT && ci_opr1 == 1)) {
    // pow/powr/pown(x, 1.0) = x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n" );
    replaceCall(I: FPOp, With: opr0);
    return true;
  }
  if ((CF && CF->isExactlyValue(V: 2.0)) || (CINT && ci_opr1 == 2)) {
    // pow/powr/pown(x, 2.0) = x*x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
                      << *opr0 << "\n" );
    Value *nval = B.CreateFMul(L: opr0, R: opr0, Name: "__pow2" );
    replaceCall(I: FPOp, With: nval);
    return true;
  }
  if ((CF && CF->isExactlyValue(V: -1.0)) || (CINT && ci_opr1 == -1)) {
    // pow/powr/pown(x, -1.0) = 1.0/x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n" );
    Constant *cnval = ConstantFP::get(Ty: eltType, V: 1.0);
    if (getVecSize(FInfo) > 1) {
      cnval = ConstantDataVector::getSplat(NumElts: getVecSize(FInfo), Elt: cnval);
    }
    Value *nval = B.CreateFDiv(L: cnval, R: opr0, Name: "__powrecip" );
    replaceCall(I: FPOp, With: nval);
    return true;
  }

  if (CF && (CF->isExactlyValue(V: 0.5) || CF->isExactlyValue(V: -0.5))) {
    // pow[r](x, [-]0.5) = sqrt(x)
    bool issqrt = CF->isExactlyValue(V: 0.5);
    if (FunctionCallee FPExpr =
            getFunction(M, fInfo: AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
                                                   : AMDGPULibFunc::EI_RSQRT,
                                            FInfo))) {
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
                        << '(' << *opr0 << ")\n" );
      Value *nval = CreateCallEx(B,Callee: FPExpr, Arg: opr0, Name: issqrt ? "__pow2sqrt"
                                                                  : "__pow2rsqrt" );
      replaceCall(I: FPOp, With: nval);
      return true;
    }
  }

  // Everything below changes numerical results, so it is only legal under
  // unsafe, finite-only math.
  if (!isUnsafeFiniteOnlyMath(FPOp))
    return false;

  // Unsafe Math optimization

  // Remember that ci_opr1 is set if opr1 is integral
  if (CF) {
    // If the FP exponent is exactly an integer, treat it as the integer so the
    // repeated-multiply expansion below can apply.
    double dval = (getArgType(FInfo) == AMDGPULibFunc::F32)
                      ? (double)CF->convertToFloat()
                      : CF->convertToDouble();
    int ival = (int)dval;
    if ((double)ival == dval) {
      ci_opr1 = ival;
    } else
      ci_opr1 = 0x11111111;
  }

  // pow/powr/pown(x, c) = [1/](x*x*..x); where
  //   trunc(c) == c && the number of x == c && |c| <= 12
  unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1;
  if (abs_opr1 <= 12) {
    Constant *cnval;
    Value *nval;
    if (abs_opr1 == 0) {
      cnval = ConstantFP::get(Ty: eltType, V: 1.0);
      if (getVecSize(FInfo) > 1) {
        cnval = ConstantDataVector::getSplat(NumElts: getVecSize(FInfo), Elt: cnval);
      }
      nval = cnval;
    } else {
      // Square-and-multiply expansion of x^|c|: valx2 holds x^(2^i), and each
      // set bit of |c| multiplies it into the running product.
      Value *valx2 = nullptr;
      nval = nullptr;
      while (abs_opr1 > 0) {
        valx2 = valx2 ? B.CreateFMul(L: valx2, R: valx2, Name: "__powx2" ) : opr0;
        if (abs_opr1 & 1) {
          nval = nval ? B.CreateFMul(L: nval, R: valx2, Name: "__powprod" ) : valx2;
        }
        abs_opr1 >>= 1;
      }
    }

    if (ci_opr1 < 0) {
      // Negative exponent: take the reciprocal of the product.
      cnval = ConstantFP::get(Ty: eltType, V: 1.0);
      if (getVecSize(FInfo) > 1) {
        cnval = ConstantDataVector::getSplat(NumElts: getVecSize(FInfo), Elt: cnval);
      }
      nval = B.CreateFDiv(L: cnval, R: nval, Name: "__1powprod" );
    }
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                      << ((ci_opr1 < 0) ? "1/prod(" : "prod(" ) << *opr0
                      << ")\n" );
    replaceCall(I: FPOp, With: nval);
    return true;
  }

  // If we should use the generic intrinsic instead of emitting a libcall
  const bool ShouldUseIntrinsic = eltType->isFloatTy() || eltType->isHalfTy();

  // powr ---> exp2(y * log2(x))
  // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31))
  FunctionCallee ExpExpr;
  if (ShouldUseIntrinsic)
    ExpExpr = Intrinsic::getOrInsertDeclaration(M, id: Intrinsic::exp2,
                                                Tys: {FPOp->getType()});
  else {
    ExpExpr = getFunction(M, fInfo: AMDGPULibFunc(AMDGPULibFunc::EI_EXP2, FInfo));
    if (!ExpExpr)
      return false;
  }

  // Work out what must be materialized: a runtime log2 of the base, a fabs of
  // the base, and/or a copy of the base's sign onto the result.
  bool needlog = false;
  bool needabs = false;
  bool needcopysign = false;
  Constant *cnval = nullptr;
  if (getVecSize(FInfo) == 1) {
    CF = nullptr;
    match(V: opr0, P: m_APFloatAllowPoison(Res&: CF));

    if (CF) {
      // Constant base: fold log2(|x|) on the host.
      double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
                     ? (double)CF->convertToFloat()
                     : CF->convertToDouble();

      V = log2(V: std::abs(x: V));
      cnval = ConstantFP::get(Ty: eltType, V);
      needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) &&
                     CF->isNegative();
    } else {
      needlog = true;
      needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
    }
  } else {
    ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(Val: opr0);

    if (!CDV) {
      needlog = true;
      needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
    } else {
      assert ((int)CDV->getNumElements() == getVecSize(FInfo) &&
              "Wrong vector size detected" );

      // Constant vector base: fold log2(|x|) per lane on the host.
      SmallVector<double, 0> DVal;
      for (int i=0; i < getVecSize(FInfo); ++i) {
        double V = CDV->getElementAsAPFloat(i).convertToDouble();
        if (V < 0.0) needcopysign = true;
        V = log2(V: std::abs(x: V));
        DVal.push_back(Elt: V);
      }
      if (getArgType(FInfo) == AMDGPULibFunc::F32) {
        SmallVector<float, 0> FVal;
        for (double D : DVal)
          FVal.push_back(Elt: (float)D);
        ArrayRef<float> tmp(FVal);
        cnval = ConstantDataVector::get(Context&: M->getContext(), Elts: tmp);
      } else {
        ArrayRef<double> tmp(DVal);
        cnval = ConstantDataVector::get(Context&: M->getContext(), Elts: tmp);
      }
    }
  }

  if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
    // We cannot handle corner cases for a general pow() function, give up
    // unless y is a constant integral value. Then proceed as if it were pown.
    if (!isKnownIntegral(V: opr1, DL: M->getDataLayout(), FMF: FPOp->getFastMathFlags()))
      return false;
  }

  Value *nval;
  if (needabs) {
    nval = B.CreateUnaryIntrinsic(ID: Intrinsic::fabs, V: opr0, FMFSource: nullptr, Name: "__fabs" );
  } else {
    nval = cnval ? cnval : opr0;
  }
  if (needlog) {
    FunctionCallee LogExpr;
    if (ShouldUseIntrinsic) {
      LogExpr = Intrinsic::getOrInsertDeclaration(M, id: Intrinsic::log2,
                                                  Tys: {FPOp->getType()});
    } else {
      LogExpr = getFunction(M, fInfo: AMDGPULibFunc(AMDGPULibFunc::EI_LOG2, FInfo));
      if (!LogExpr)
        return false;
    }

    nval = CreateCallEx(B,Callee: LogExpr, Arg: nval, Name: "__log2" );
  }

  if (FInfo.getId() == AMDGPULibFunc::EI_POWN) {
    // convert int(32) to fp(f32 or f64)
    opr1 = B.CreateSIToFP(V: opr1, DestTy: nval->getType(), Name: "pownI2F" );
  }
  nval = B.CreateFMul(L: opr1, R: nval, Name: "__ylogx" );
  nval = CreateCallEx(B,Callee: ExpExpr, Arg: nval, Name: "__exp2" );

  if (needcopysign) {
    // Recreate the sign of the result: shifting y's low bit into the sign
    // position yields a set sign bit only for odd y; AND with x's bits keeps
    // it only when x is negative, and OR merges it into the magnitude.
    Type* nTyS = B.getIntNTy(N: eltType->getPrimitiveSizeInBits());
    Type *nTy = FPOp->getType()->getWithNewType(EltTy: nTyS);
    unsigned size = nTy->getScalarSizeInBits();
    Value *opr_n = FPOp->getOperand(i: 1);
    if (opr_n->getType()->getScalarType()->isIntegerTy())
      opr_n = B.CreateZExtOrTrunc(V: opr_n, DestTy: nTy, Name: "__ytou" );
    else
      opr_n = B.CreateFPToSI(V: opr1, DestTy: nTy, Name: "__ytou" );

    Value *sign = B.CreateShl(LHS: opr_n, RHS: size-1, Name: "__yeven" );
    sign = B.CreateAnd(LHS: B.CreateBitCast(V: opr0, DestTy: nTy), RHS: sign, Name: "__pow_sign" );
    nval = B.CreateOr(LHS: B.CreateBitCast(V: nval, DestTy: nTy), RHS: sign);
    nval = B.CreateBitCast(V: nval, DestTy: opr0->getType());
  }

  LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                    << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n" );
  replaceCall(I: FPOp, With: nval);

  return true;
}
1149 | |
// Fold rootn(x, n) for special constant n:
//   n == 1  -> x
//   n == 2  -> llvm.sqrt(x)
//   n == 3  -> cbrt(x)
//   n == -1 -> 1.0 / x
//   n == -2 -> 1.0 / llvm.sqrt(x)
bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
  Value *opr0 = FPOp->getOperand(i: 0);
  Value *opr1 = FPOp->getOperand(i: 1);

  // Only (splat) integer-constant exponents are handled.
  const APInt *CINT = nullptr;
  if (!match(V: opr1, P: m_APIntAllowPoison(Res&: CINT)))
    return false;

  Function *Parent = B.GetInsertBlock()->getParent();

  int ci_opr1 = (int)CINT->getSExtValue();
  if (ci_opr1 == 1 && !Parent->hasFnAttribute(Kind: Attribute::StrictFP)) {
    // rootn(x, 1) = x
    //
    // TODO: Insert constrained canonicalize for strictfp case.
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << '\n');
    replaceCall(I: FPOp, With: opr0);
    return true;
  }

  Module *M = B.GetInsertBlock()->getModule();

  CallInst *CI = cast<CallInst>(Val: FPOp);
  if (ci_opr1 == 2 &&
      shouldReplaceLibcallWithIntrinsic(CI,
                                        /*AllowMinSizeF32=*/true,
                                        /*AllowF64=*/true)) {
    // rootn(x, 2) = sqrt(x)
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0 << ")\n" );

    CallInst *NewCall = B.CreateUnaryIntrinsic(ID: Intrinsic::sqrt, V: opr0, FMFSource: CI);
    NewCall->takeName(V: CI);

    // OpenCL rootn has a looser ulp of 2 requirement than sqrt, so add some
    // metadata.
    MDBuilder MDHelper(M->getContext());
    MDNode *FPMD = MDHelper.createFPMath(Accuracy: std::max(a: FPOp->getFPAccuracy(), b: 2.0f));
    NewCall->setMetadata(KindID: LLVMContext::MD_fpmath, Node: FPMD);

    replaceCall(I: CI, With: NewCall);
    return true;
  }

  if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
    if (FunctionCallee FPExpr =
            getFunction(M, fInfo: AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
                        << ")\n" );
      Value *nval = CreateCallEx(B,Callee: FPExpr, Arg: opr0, Name: "__rootn2cbrt" );
      replaceCall(I: FPOp, With: nval);
      return true;
    }
  } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n" );
    Value *nval = B.CreateFDiv(L: ConstantFP::get(Ty: opr0->getType(), V: 1.0),
                               R: opr0,
                               Name: "__rootn2div" );
    replaceCall(I: FPOp, With: nval);
    return true;
  }

  if (ci_opr1 == -2 &&
      shouldReplaceLibcallWithIntrinsic(CI,
                                        /*AllowMinSizeF32=*/true,
                                        /*AllowF64=*/true)) {
    // rootn(x, -2) = rsqrt(x)

    // The original rootn had looser ulp requirements than the resultant sqrt
    // and fdiv.
    MDBuilder MDHelper(M->getContext());
    MDNode *FPMD = MDHelper.createFPMath(Accuracy: std::max(a: FPOp->getFPAccuracy(), b: 2.0f));

    // TODO: Could handle strictfp but need to fix strict sqrt emission
    FastMathFlags FMF = FPOp->getFastMathFlags();
    FMF.setAllowContract(true);

    CallInst *Sqrt = B.CreateUnaryIntrinsic(ID: Intrinsic::sqrt, V: opr0, FMFSource: CI);
    Instruction *RSqrt = cast<Instruction>(
        Val: B.CreateFDiv(L: ConstantFP::get(Ty: opr0->getType(), V: 1.0), R: Sqrt));
    Sqrt->setFastMathFlags(FMF);
    RSqrt->setFastMathFlags(FMF);
    RSqrt->setMetadata(KindID: LLVMContext::MD_fpmath, Node: FPMD);

    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
                      << ")\n" );
    replaceCall(I: CI, With: RSqrt);
    return true;
  }

  return false;
}
1242 | |
1243 | // Get a scalar native builtin single argument FP function |
1244 | FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M, |
1245 | const FuncInfo &FInfo) { |
1246 | if (getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(id: FInfo.getId())) |
1247 | return nullptr; |
1248 | FuncInfo nf = FInfo; |
1249 | nf.setPrefix(AMDGPULibFunc::NATIVE); |
1250 | return getFunction(M, fInfo: nf); |
1251 | } |
1252 | |
1253 | // Some library calls are just wrappers around llvm intrinsics, but compiled |
1254 | // conservatively. Preserve the flags from the original call site by |
1255 | // substituting them with direct calls with all the flags. |
1256 | bool AMDGPULibCalls::shouldReplaceLibcallWithIntrinsic(const CallInst *CI, |
1257 | bool AllowMinSizeF32, |
1258 | bool AllowF64, |
1259 | bool AllowStrictFP) { |
1260 | Type *FltTy = CI->getType()->getScalarType(); |
1261 | const bool IsF32 = FltTy->isFloatTy(); |
1262 | |
1263 | // f64 intrinsics aren't implemented for most operations. |
1264 | if (!IsF32 && !FltTy->isHalfTy() && (!AllowF64 || !FltTy->isDoubleTy())) |
1265 | return false; |
1266 | |
1267 | // We're implicitly inlining by replacing the libcall with the intrinsic, so |
1268 | // don't do it for noinline call sites. |
1269 | if (CI->isNoInline()) |
1270 | return false; |
1271 | |
1272 | const Function *ParentF = CI->getFunction(); |
1273 | // TODO: Handle strictfp |
1274 | if (!AllowStrictFP && ParentF->hasFnAttribute(Kind: Attribute::StrictFP)) |
1275 | return false; |
1276 | |
1277 | if (IsF32 && !AllowMinSizeF32 && ParentF->hasMinSize()) |
1278 | return false; |
1279 | return true; |
1280 | } |
1281 | |
1282 | void AMDGPULibCalls::replaceLibCallWithSimpleIntrinsic(IRBuilder<> &B, |
1283 | CallInst *CI, |
1284 | Intrinsic::ID IntrID) { |
1285 | if (CI->arg_size() == 2) { |
1286 | Value *Arg0 = CI->getArgOperand(i: 0); |
1287 | Value *Arg1 = CI->getArgOperand(i: 1); |
1288 | VectorType *Arg0VecTy = dyn_cast<VectorType>(Val: Arg0->getType()); |
1289 | VectorType *Arg1VecTy = dyn_cast<VectorType>(Val: Arg1->getType()); |
1290 | if (Arg0VecTy && !Arg1VecTy) { |
1291 | Value *SplatRHS = B.CreateVectorSplat(EC: Arg0VecTy->getElementCount(), V: Arg1); |
1292 | CI->setArgOperand(i: 1, v: SplatRHS); |
1293 | } else if (!Arg0VecTy && Arg1VecTy) { |
1294 | Value *SplatLHS = B.CreateVectorSplat(EC: Arg1VecTy->getElementCount(), V: Arg0); |
1295 | CI->setArgOperand(i: 0, v: SplatLHS); |
1296 | } |
1297 | } |
1298 | |
1299 | CI->setCalledFunction(Intrinsic::getOrInsertDeclaration( |
1300 | M: CI->getModule(), id: IntrID, Tys: {CI->getType()})); |
1301 | } |
1302 | |
1303 | bool AMDGPULibCalls::tryReplaceLibcallWithSimpleIntrinsic( |
1304 | IRBuilder<> &B, CallInst *CI, Intrinsic::ID IntrID, bool AllowMinSizeF32, |
1305 | bool AllowF64, bool AllowStrictFP) { |
1306 | if (!shouldReplaceLibcallWithIntrinsic(CI, AllowMinSizeF32, AllowF64, |
1307 | AllowStrictFP)) |
1308 | return false; |
1309 | replaceLibCallWithSimpleIntrinsic(B, CI, IntrID); |
1310 | return true; |
1311 | } |
1312 | |
// Materialize a single sincos(Arg, &cosOut) call and return the values to use
// for {sin, cos, sincos} replacements: the sin value is the call's return, and
// the cos value is loaded back from the out-pointer alloca.
std::tuple<Value *, Value *, Value *>
AMDGPULibCalls::insertSinCos(Value *Arg, FastMathFlags FMF, IRBuilder<> &B,
                             FunctionCallee Fsincos) {
  DebugLoc DL = B.getCurrentDebugLocation();
  Function *F = B.GetInsertBlock()->getParent();
  // Place the cos out-argument alloca at function entry so it dominates every
  // replaced use.
  B.SetInsertPointPastAllocas(F);

  AllocaInst *Alloc = B.CreateAlloca(Ty: Arg->getType(), ArraySize: nullptr, Name: "__sincos_" );

  if (Instruction *ArgInst = dyn_cast<Instruction>(Val: Arg)) {
    // If the argument is an instruction, it must dominate all uses so put our
    // sincos call there. Otherwise, right after the allocas works well enough
    // if it's an argument or constant.

    B.SetInsertPoint(TheBB: ArgInst->getParent(), IP: ++ArgInst->getIterator());

    // SetInsertPoint unwelcomely always tries to set the debug loc.
    B.SetCurrentDebugLocation(DL);
  }

  Type *CosPtrTy = Fsincos.getFunctionType()->getParamType(i: 1);

  // The allocaInst allocates the memory in private address space. This need
  // to be addrspacecasted to point to the address space of cos pointer type.
  // In OpenCL 2.0 this is generic, while in 1.2 that is private.
  Value *CastAlloc = B.CreateAddrSpaceCast(V: Alloc, DestTy: CosPtrTy);

  CallInst *SinCos = CreateCallEx2(B, Callee: Fsincos, Arg1: Arg, Arg2: CastAlloc);

  // TODO: Is it worth trying to preserve the location for the cos calls for the
  // load?

  // Load the cos result from the (uncasted, private) alloca.
  LoadInst *LoadCos = B.CreateLoad(Ty: Alloc->getAllocatedType(), Ptr: Alloc);
  return {SinCos, LoadCos, SinCos};
}
1348 | |
// fold sin, cos -> sincos.
// Scan all uses of the shared argument for matching sin/cos/sincos calls in
// this function, and when both a sin and a cos exist, replace them all with a
// single sincos call (merging fast-math flags, fpmath metadata, and debug
// locations from every replaced call).
bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &fInfo) {
  assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
         fInfo.getId() == AMDGPULibFunc::EI_COS);

  if ((getArgType(FInfo: fInfo) != AMDGPULibFunc::F32 &&
       getArgType(FInfo: fInfo) != AMDGPULibFunc::F64) ||
      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
    return false;

  bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

  Value *CArgVal = FPOp->getOperand(i: 0);

  // TODO: Constant fold the call
  if (isa<ConstantData>(Val: CArgVal))
    return false;

  CallInst *CI = cast<CallInst>(Val: FPOp);

  Function *F = B.GetInsertBlock()->getParent();
  Module *M = F->getParent();

  // Merge the sin and cos. For OpenCL 2.0, there may only be a generic pointer
  // implementation. Prefer the private form if available.
  AMDGPULibFunc SinCosLibFuncPrivate(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFuncPrivate.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AS: AMDGPUAS::PRIVATE_ADDRESS);

  AMDGPULibFunc SinCosLibFuncGeneric(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFuncGeneric.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AS: AMDGPUAS::FLAT_ADDRESS);

  FunctionCallee FSinCosPrivate = getFunction(M, fInfo: SinCosLibFuncPrivate);
  FunctionCallee FSinCosGeneric = getFunction(M, fInfo: SinCosLibFuncGeneric);
  FunctionCallee FSinCos = FSinCosPrivate ? FSinCosPrivate : FSinCosGeneric;
  if (!FSinCos)
    return false;

  SmallVector<CallInst *> SinCalls;
  SmallVector<CallInst *> CosCalls;
  SmallVector<CallInst *> SinCosCalls;
  // The "partner" is the opposite trig function with the same signature.
  FuncInfo PartnerInfo(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN,
                       fInfo);
  const std::string PairName = PartnerInfo.mangle();

  StringRef SinName = isSin ? CI->getCalledFunction()->getName() : PairName;
  StringRef CosName = isSin ? PairName : CI->getCalledFunction()->getName();
  const std::string SinCosPrivateName = SinCosLibFuncPrivate.mangle();
  const std::string SinCosGenericName = SinCosLibFuncGeneric.mangle();

  // Intersect the two sets of flags.
  FastMathFlags FMF = FPOp->getFastMathFlags();
  MDNode *FPMath = CI->getMetadata(KindID: LLVMContext::MD_fpmath);

  SmallVector<DILocation *> MergeDbgLocs = {CI->getDebugLoc()};

  // Collect every sin/cos/sincos call on the same argument in this function.
  for (User* U : CArgVal->users()) {
    CallInst *XI = dyn_cast<CallInst>(Val: U);
    if (!XI || XI->getFunction() != F || XI->isNoBuiltin())
      continue;

    Function *UCallee = XI->getCalledFunction();
    if (!UCallee)
      continue;

    bool Handled = true;

    if (UCallee->getName() == SinName)
      SinCalls.push_back(Elt: XI);
    else if (UCallee->getName() == CosName)
      CosCalls.push_back(Elt: XI);
    else if (UCallee->getName() == SinCosPrivateName ||
             UCallee->getName() == SinCosGenericName)
      SinCosCalls.push_back(Elt: XI);
    else
      Handled = false;

    if (Handled) {
      // Conservatively intersect flags/metadata across all merged calls.
      MergeDbgLocs.push_back(Elt: XI->getDebugLoc());
      auto *OtherOp = cast<FPMathOperator>(Val: XI);
      FMF &= OtherOp->getFastMathFlags();
      FPMath = MDNode::getMostGenericFPMath(
          A: FPMath, B: XI->getMetadata(KindID: LLVMContext::MD_fpmath));
    }
  }

  // Only profitable when both a sin and a cos of the argument exist.
  if (SinCalls.empty() || CosCalls.empty())
    return false;

  B.setFastMathFlags(FMF);
  B.setDefaultFPMathTag(FPMath);
  DILocation *DbgLoc = DILocation::getMergedLocations(Locs: MergeDbgLocs);
  B.SetCurrentDebugLocation(DbgLoc);

  auto [Sin, Cos, SinCos] = insertSinCos(Arg: CArgVal, FMF, B, Fsincos: FSinCos);

  auto replaceTrigInsts = [](ArrayRef<CallInst *> Calls, Value *Res) {
    for (CallInst *C : Calls)
      C->replaceAllUsesWith(V: Res);

    // Leave the other dead instructions to avoid clobbering iterators.
  };

  replaceTrigInsts(SinCalls, Sin);
  replaceTrigInsts(CosCalls, Cos);
  replaceTrigInsts(SinCosCalls, SinCos);

  // It's safe to delete the original now.
  CI->eraseFromParent();
  return true;
}
1462 | |
// Host-side constant evaluation of one scalar math libcall. copr0/copr1 are
// the (possibly null) constant operands; on success the result is written to
// Res0 (and Res1 as well for sincos). Returns false for functions that cannot
// be evaluated here or for non-constant integer operands.
bool AMDGPULibCalls::evaluateScalarMathFunc(const FuncInfo &FInfo, double &Res0,
                                            double &Res1, Constant *copr0,
                                            Constant *copr1) {
  // By default, opr0/opr1 hold the operands widened to double. Functions
  // whose operands are not float/double (e.g. pown's integer exponent) fetch
  // their operand separately below.
  double opr0 = 0.0, opr1 = 0.0;
  ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(Val: copr0);
  ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(Val: copr1);
  if (fpopr0) {
    opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
               ? fpopr0->getValueAPF().convertToDouble()
               : (double)fpopr0->getValueAPF().convertToFloat();
  }

  if (fpopr1) {
    opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
               ? fpopr1->getValueAPF().convertToDouble()
               : (double)fpopr1->getValueAPF().convertToFloat();
  }

  switch (FInfo.getId()) {
  default : return false;

  case AMDGPULibFunc::EI_ACOS:
    Res0 = acos(x: opr0);
    return true;

  case AMDGPULibFunc::EI_ACOSH:
    // acosh(x) == log(x + sqrt(x*x - 1))
    Res0 = log(x: opr0 + sqrt(x: opr0*opr0 - 1.0));
    return true;

  case AMDGPULibFunc::EI_ACOSPI:
    Res0 = acos(x: opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ASIN:
    Res0 = asin(x: opr0);
    return true;

  case AMDGPULibFunc::EI_ASINH:
    // asinh(x) == log(x + sqrt(x*x + 1))
    Res0 = log(x: opr0 + sqrt(x: opr0*opr0 + 1.0));
    return true;

  case AMDGPULibFunc::EI_ASINPI:
    Res0 = asin(x: opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ATAN:
    Res0 = atan(x: opr0);
    return true;

  case AMDGPULibFunc::EI_ATANH:
    // atanh(x) == (log(x+1) - log(x-1))/2;
    Res0 = (log(x: opr0 + 1.0) - log(x: opr0 - 1.0))/2.0;
    return true;

  case AMDGPULibFunc::EI_ATANPI:
    Res0 = atan(x: opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_CBRT:
    // Sign-symmetric cube root via pow on the absolute value.
    Res0 = (opr0 < 0.0) ? -pow(x: -opr0, y: 1.0/3.0) : pow(x: opr0, y: 1.0/3.0);
    return true;

  case AMDGPULibFunc::EI_COS:
    Res0 = cos(x: opr0);
    return true;

  case AMDGPULibFunc::EI_COSH:
    Res0 = cosh(x: opr0);
    return true;

  case AMDGPULibFunc::EI_COSPI:
    Res0 = cos(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_EXP:
    Res0 = exp(x: opr0);
    return true;

  case AMDGPULibFunc::EI_EXP2:
    Res0 = pow(x: 2.0, y: opr0);
    return true;

  case AMDGPULibFunc::EI_EXP10:
    Res0 = pow(x: 10.0, y: opr0);
    return true;

  case AMDGPULibFunc::EI_LOG:
    Res0 = log(x: opr0);
    return true;

  case AMDGPULibFunc::EI_LOG2:
    Res0 = log(x: opr0) / log(x: 2.0);
    return true;

  case AMDGPULibFunc::EI_LOG10:
    Res0 = log(x: opr0) / log(x: 10.0);
    return true;

  case AMDGPULibFunc::EI_RSQRT:
    Res0 = 1.0 / sqrt(x: opr0);
    return true;

  case AMDGPULibFunc::EI_SIN:
    Res0 = sin(x: opr0);
    return true;

  case AMDGPULibFunc::EI_SINH:
    Res0 = sinh(x: opr0);
    return true;

  case AMDGPULibFunc::EI_SINPI:
    Res0 = sin(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_TAN:
    Res0 = tan(x: opr0);
    return true;

  case AMDGPULibFunc::EI_TANH:
    Res0 = tanh(x: opr0);
    return true;

  case AMDGPULibFunc::EI_TANPI:
    Res0 = tan(MATH_PI * opr0);
    return true;

  // two-arg functions
  case AMDGPULibFunc::EI_POW:
  case AMDGPULibFunc::EI_POWR:
    Res0 = pow(x: opr0, y: opr1);
    return true;

  case AMDGPULibFunc::EI_POWN: {
    // pown takes an integer exponent; evaluate as pow(x, (double)n).
    if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(Val: copr1)) {
      double val = (double)iopr1->getSExtValue();
      Res0 = pow(x: opr0, y: val);
      return true;
    }
    return false;
  }

  case AMDGPULibFunc::EI_ROOTN: {
    // rootn takes an integer root; evaluate as pow(x, 1/n).
    if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(Val: copr1)) {
      double val = (double)iopr1->getSExtValue();
      Res0 = pow(x: opr0, y: 1.0 / val);
      return true;
    }
    return false;
  }

  // with ptr arg
  case AMDGPULibFunc::EI_SINCOS:
    // Two results: sin in Res0, cos in Res1.
    Res0 = sin(x: opr0);
    Res1 = cos(x: opr0);
    return true;
  }

  return false;
}
1627 | |
1628 | bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) { |
1629 | int numArgs = (int)aCI->arg_size(); |
1630 | if (numArgs > 3) |
1631 | return false; |
1632 | |
1633 | Constant *copr0 = nullptr; |
1634 | Constant *copr1 = nullptr; |
1635 | if (numArgs > 0) { |
1636 | if ((copr0 = dyn_cast<Constant>(Val: aCI->getArgOperand(i: 0))) == nullptr) |
1637 | return false; |
1638 | } |
1639 | |
1640 | if (numArgs > 1) { |
1641 | if ((copr1 = dyn_cast<Constant>(Val: aCI->getArgOperand(i: 1))) == nullptr) { |
1642 | if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS) |
1643 | return false; |
1644 | } |
1645 | } |
1646 | |
1647 | // At this point, all arguments to aCI are constants. |
1648 | |
1649 | // max vector size is 16, and sincos will generate two results. |
1650 | double DVal0[16], DVal1[16]; |
1651 | int FuncVecSize = getVecSize(FInfo); |
1652 | bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS); |
1653 | if (FuncVecSize == 1) { |
1654 | if (!evaluateScalarMathFunc(FInfo, Res0&: DVal0[0], Res1&: DVal1[0], copr0, copr1)) { |
1655 | return false; |
1656 | } |
1657 | } else { |
1658 | ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(Val: copr0); |
1659 | ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(Val: copr1); |
1660 | for (int i = 0; i < FuncVecSize; ++i) { |
1661 | Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr; |
1662 | Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr; |
1663 | if (!evaluateScalarMathFunc(FInfo, Res0&: DVal0[i], Res1&: DVal1[i], copr0: celt0, copr1: celt1)) { |
1664 | return false; |
1665 | } |
1666 | } |
1667 | } |
1668 | |
1669 | LLVMContext &context = aCI->getContext(); |
1670 | Constant *nval0, *nval1; |
1671 | if (FuncVecSize == 1) { |
1672 | nval0 = ConstantFP::get(Ty: aCI->getType(), V: DVal0[0]); |
1673 | if (hasTwoResults) |
1674 | nval1 = ConstantFP::get(Ty: aCI->getType(), V: DVal1[0]); |
1675 | } else { |
1676 | if (getArgType(FInfo) == AMDGPULibFunc::F32) { |
1677 | SmallVector <float, 0> FVal0, FVal1; |
1678 | for (int i = 0; i < FuncVecSize; ++i) |
1679 | FVal0.push_back(Elt: (float)DVal0[i]); |
1680 | ArrayRef<float> tmp0(FVal0); |
1681 | nval0 = ConstantDataVector::get(Context&: context, Elts: tmp0); |
1682 | if (hasTwoResults) { |
1683 | for (int i = 0; i < FuncVecSize; ++i) |
1684 | FVal1.push_back(Elt: (float)DVal1[i]); |
1685 | ArrayRef<float> tmp1(FVal1); |
1686 | nval1 = ConstantDataVector::get(Context&: context, Elts: tmp1); |
1687 | } |
1688 | } else { |
1689 | ArrayRef<double> tmp0(DVal0); |
1690 | nval0 = ConstantDataVector::get(Context&: context, Elts: tmp0); |
1691 | if (hasTwoResults) { |
1692 | ArrayRef<double> tmp1(DVal1); |
1693 | nval1 = ConstantDataVector::get(Context&: context, Elts: tmp1); |
1694 | } |
1695 | } |
1696 | } |
1697 | |
1698 | if (hasTwoResults) { |
1699 | // sincos |
1700 | assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS && |
1701 | "math function with ptr arg not supported yet" ); |
1702 | new StoreInst(nval1, aCI->getArgOperand(i: 1), aCI->getIterator()); |
1703 | } |
1704 | |
1705 | replaceCall(I: aCI, With: nval0); |
1706 | return true; |
1707 | } |
1708 | |
1709 | PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F, |
1710 | FunctionAnalysisManager &AM) { |
1711 | AMDGPULibCalls Simplifier; |
1712 | Simplifier.initNativeFuncs(); |
1713 | Simplifier.initFunction(F, FAM&: AM); |
1714 | |
1715 | bool Changed = false; |
1716 | |
1717 | LLVM_DEBUG(dbgs() << "AMDIC: process function " ; |
1718 | F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';); |
1719 | |
1720 | for (auto &BB : F) { |
1721 | for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) { |
1722 | // Ignore non-calls. |
1723 | CallInst *CI = dyn_cast<CallInst>(Val&: I); |
1724 | ++I; |
1725 | |
1726 | if (CI) { |
1727 | if (Simplifier.fold(CI)) |
1728 | Changed = true; |
1729 | } |
1730 | } |
1731 | } |
1732 | return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
1733 | } |
1734 | |
1735 | PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F, |
1736 | FunctionAnalysisManager &AM) { |
1737 | if (UseNative.empty()) |
1738 | return PreservedAnalyses::all(); |
1739 | |
1740 | AMDGPULibCalls Simplifier; |
1741 | Simplifier.initNativeFuncs(); |
1742 | Simplifier.initFunction(F, FAM&: AM); |
1743 | |
1744 | bool Changed = false; |
1745 | for (auto &BB : F) { |
1746 | for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) { |
1747 | // Ignore non-calls. |
1748 | CallInst *CI = dyn_cast<CallInst>(Val&: I); |
1749 | ++I; |
1750 | if (CI && Simplifier.useNative(aCI: CI)) |
1751 | Changed = true; |
1752 | } |
1753 | } |
1754 | return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
1755 | } |
1756 | |