//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));

#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. whether inserting
/// instructions at position IP1 may change the meaning of IP2, or vice versa.
/// This is because an InsertPoint stores the instruction before which
/// something is inserted. For instance, if both point to the same
/// instruction, two IRBuilders alternately creating instructions will cause
/// the instructions to be interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  if (T.isSPIRV())
    return omp::SPIRVGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}
/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering nor monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic contradict each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // The monotonic modifier is used by default in the OpenMP runtime
      // library, so there is no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}
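
// A worked example of the composition above (a sketch of the flag algebra,
// not emitted code): `schedule(nonmonotonic: dynamic)` yields
// BaseDynamicChunked, which becomes UnorderedDynamicChunked after the
// ordering modifier is added (no `ordered` clause), and finally
// UnorderedDynamicChunked | ModifierNonmonotonic once the monotonicity
// modifier is applied.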

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///   the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch, DebugLoc DL) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch) {
    auto *NewBr = BranchInst::Create(New, Old);
    NewBr->setDebugLoc(DL);
  }
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch, DebugLoc);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          DebugLoc DL, llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch, DL);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}
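
// Usage sketch (hypothetical block name, for illustration): if the builder
// currently inserts into a block named "omp.region", then
//   splitBBWithSuffix(Builder, /*CreateBranch=*/true, ".split")
// creates a new block "omp.region.split" holding everything after the insert
// point, branches the old block to it, and leaves the builder positioned at
// the end of the old block (before the new branch).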

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true) {
  Builder.restoreIP(OuterAllocaIP);
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
  ToBeDeleted.push_back(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
    ToBeDeleted.push_back(FakeVal);
  }

  // Generate a fake use of this value
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
  } else {
    UseFakeVal =
        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
  }
  ToBeDeleted.push_back(UseFakeVal);
  return FakeVal;
}
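
// With AsPtr == true and Name == "tid" this emits, roughly (IR sketch for
// illustration only):
//   %tid.addr = alloca i32                       ; at OuterAllocaIP
//   %tid.use  = load i32, ptr %tid.addr          ; fake use at InnerAllocaIP
// Both instructions end up in ToBeDeleted and are expected to be erased by
// the caller once the outlined function's signature has been fixed up.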

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}
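
// Example: for `#pragma omp requires unified_shared_memory` only
// OMP_REQ_UNIFIED_SHARED_MEMORY (0x008) is set, so getRequiresFlags()
// returns 8; without any requires directive it returns OMP_REQ_NONE (0x001).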

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  constexpr const size_t MaxDim = 3;
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);

  assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());

  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0});
  for (unsigned I :
       seq<unsigned>(1, std::min(KernelArgs.NumTeams.size(), MaxDim)))
    NumTeams3D =
        Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
  for (unsigned I :
       seq<unsigned>(1, std::min(KernelArgs.NumThreads.size(), MaxDim)))
    NumThreads3D =
        Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumThreads[I], {I});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGGroupMem};
}
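
// For example, a one-dimensional `num_teams(8)` clause becomes the aggregate
// {i32 8, i32 0, i32 0}: the scalar is inserted into element 0 of the
// zero-initialized [3 x i32] array and the remaining dimensions stay 0.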

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets) \
  case Enum: \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet); \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false); \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo) \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]); \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs)); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...) \
  case Enum: \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__}, \
                             IsVarArg); \
    Fn = M.getFunction(Str); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...) \
  case Enum: \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
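
// The callback annotation above prints roughly as (abbreviated IR sketch;
// exact metadata numbering depends on the module):
//   declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
//   !0 = !{!1}
//   !1 = !{i64 2, i64 -1, i64 -1, i1 true}
// i.e. argument 2 is the callback callee, its first two parameters are
// unknown, and the variadic arguments are forwarded to it.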

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
       Block++) {
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
        Inst++;
        if (!isa<ConstantData>(AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate
    // struct for aggregate params in the device default alloca address space.
    // OpenMP runtime requires that the params of the extracted functions are
    // passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");
    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator()) {
          // Absorb any debug value that the terminator may have.
          if (OI.EntryBB->getTerminator())
            OI.EntryBB->getTerminator()->adoptDbgRecords(
                &ArtificialEntry, I.getIterator(), false);
          continue;
        }

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->hasNUses(1));

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embed user-written code into the target
  // region, which may inject allocas that need to be moved to the entry
  // block of our target; otherwise later passes risk producing malformed
  // optimisations. This is only relevant for the device pass, which appears
  // to be a little more delicate when it comes to optimisations (however, we
  // do not enforce that here; it is up to whoever inserts into the list to
  // do so). Notably, this has to occur after the OutlineInfo candidates have
  // been extracted, so the end product is not implicitly adversely affected
  // by any raises unless intentionally appended to the list.
  // NOTE: This only handles ConstantData. It could be extended to
  // ConstantExprs with further effort; however, those should largely be
  // folded by the time they get here. Extending it to runtime-defined or
  // read/writeable allocation sizes would be non-trivial (movement of any
  // stores to variables the allocation size depends on, as well as the usual
  // loads, would have to be factored in, otherwise it would yield the wrong
  // result after movement) and would likely be more suitable as an LLVM
  // optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed("llvm.compiler.used", LLVMCompilerUsed);
  }

  IsFinalized = true;
}

bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
  if (List.empty())
    return;

  // Convert List to what ConstantArray needs.
  SmallVector<Constant *, 8> UsedArray;
  UsedArray.resize(List.size());
  for (unsigned I = 0, E = List.size(); I != E; ++I)
    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        cast<Constant>(&*List[I]), Builder.getPtrTy());

  if (UsedArray.empty())
    return;
  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());

  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
                                ConstantArray::get(ATy, UsedArray), Name);

  GV->setSection("llvm.metadata");
}

GlobalVariable *
OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
                                         OMPTgtExecModeFlags Mode) {
  auto *Int8Ty = Builder.getInt8Ty();
  auto *GVMode = new GlobalVariable(
      M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
      ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
  GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
  return GVMode;
}
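
// The resulting global prints along the lines of (IR sketch; <Mode> stands
// for the numeric OMPTgtExecModeFlags value):
//   @<KernelName>_exec_mode = weak protected constant i8 <Mode>
// It records how the kernel was compiled (e.g. SPMD vs. generic mode) for
// consumers of the offloading image.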

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}
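
// A typical ident_t instance looks roughly like (IR sketch; the flags value
// assumes only OMP_IDENT_FLAG_KMPC is set, and @.str is the source-location
// string created below):
//   @0 = private unnamed_addr constant %struct.ident_t
//            { i32 0, i32 2, i32 0, i32 <SrcLocStrSize>, ptr @.str }, align 8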

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalString(LocStr, /* Name */ "",
                                           /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

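/// The encoded source location string has the shape
/// ";FileName;FunctionName;Line;Column;;", e.g. ";file.c;foo;3;10;;" for an
/// illustrative location in a function foo.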
Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  // __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    if (Error Err = emitCancelationCheckImpl(Result, OMPD_parallel))
      return Err;

  return Builder.saveIP();
}
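
// Depending on the cancellation state this emits one of the following calls
// (IR sketch, names abbreviated):
//   call void @__kmpc_barrier(ptr @loc, i32 %thread_id)
//   %res = call i32 @__kmpc_cancel_barrier(ptr @loc, i32 %thread_id)
// where the i32 result of the cancellable variant feeds the cancellation
// check emitted above.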

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities prefer blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
  case DirectiveEnum: \
    CancelKind = Builder.getInt32(Value); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                           omp::Directive::OMPD_unknown,
                           /* ForceSimpleCall */ false,
                           /* CheckCancelFlag */ false)
          .takeError();
    }
    return Error::success();
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
                                         omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities prefer blocks with terminators.
  auto *UI = Builder.CreateUnreachable();
  Builder.SetInsertPoint(UI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
  case DirectiveEnum: \
    CancelKind = Builder.getInt32(Value); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancellationpoint), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                           omp::Directive::OMPD_unknown,
                           /* ForceSimpleCall */ false,
                           /* CheckCancelFlag */ false)
          .takeError();
    }
    return Error::success();
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  Builder.restoreIP(Loc.IP);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Value *OutlinedFnID,
    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(Loc.IP);
  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(emitTargetKernel(
      Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams.front(),
      Args.NumThreads.front(), OutlinedFnID, ArgsVector));

  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Return);
  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(OffloadFailedBlock, CurFn);
  InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(*AfterIP);
  emitBranch(OffloadContBlock);
  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}

Error OpenMPIRBuilder::emitCancelationCheckImpl(
    Value *CancelFlag, omp::Directive CanceledDirective,
    FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    if (Error Err = ExitCB(Builder.saveIP()))
      return Err;
  auto &FI = FinalizationStack.back();
  if (Error Err = FI.FiniCB(Builder.saveIP()))
    return Err;

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
  return Error::success();
}
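
// The control flow produced here is, schematically (block names follow the
// code above; IR sketch for illustration):
//   %cmp = icmp eq i32 %cancel_flag, 0
//   br i1 %cmp, label %<bb>.cont, label %<bb>.cncl
// with finalization code emitted into the ".cncl" block before it branches
// to the post-finalization destination known to the FiniCB callback.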

// Callback used to create OpenMP runtime calls to support
// the omp parallel clause for the device.
// We need to use this callback to replace the call to the OutlinedFn in
// OuterFn with a call to the OpenMP DeviceRTL runtime function
// (kmpc_parallel_51).
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addParamAttr(0, Attribute::NoUndef);
  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(PtrTy);

  // Add alloca for kernel args
  OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add an address space cast if the array for storing arguments is not
  // allocated in address space 0.
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
  Builder.restoreIP(CurrentIP);

  // Store captured vars which are used by kmpc_parallel_51
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
    Builder.CreateStore(V, StoreAddress);
  }

  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                  : Builder.getInt32(1);

  // Build kmpc_parallel_51 call
  Value *Parallel51CallArgs[] = {
      /* identifier*/ Ident,
      /* global thread num*/ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
      /* Proc bind */ Builder.getInt32(-1),
      /* outlined function */ &OutlinedFn,
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function */ Args,
      /* number of arguments */ Builder.getInt64(NumCapturedVars)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);

  Builder.CreateCall(RTLFn, Parallel51CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
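
// The emitted replacement call has, schematically, the shape (IR sketch
// following the argument list above):
//   call void @__kmpc_parallel_51(ptr %ident, i32 %tid, i32 %if_cond,
//       i32 %num_threads, i32 -1 /*proc_bind*/, ptr @outlined.fn,
//       ptr null /*wrapper*/, ptr %args, i64 %nargs)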

// Callback used to create OpenMP runtime calls to support
// the omp parallel clause for the host.
// We need to use this callback to replace the call to the OutlinedFn in
// OuterFn with a call to the OpenMP host runtime function
// (__kmpc_fork_call[_if]).
1381static void
1382hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1383 Function *OuterFn, Value *Ident, Value *IfCondition,
1384 Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1385 const SmallVector<Instruction *, 4> &ToBeDeleted) {
1386 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1387 FunctionCallee RTLFn;
1388 if (IfCondition) {
1389 RTLFn =
1390 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call_if);
1391 } else {
1392 RTLFn =
1393 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call);
1394 }
1395 if (auto *F = dyn_cast<Function>(Val: RTLFn.getCallee())) {
1396 if (!F->hasMetadata(KindID: LLVMContext::MD_callback)) {
1397 LLVMContext &Ctx = F->getContext();
1398 MDBuilder MDB(Ctx);
1399 // Annotate the callback behavior of the __kmpc_fork_call:
1400 // - The callback callee is argument number 2 (microtask).
1401 // - The first two arguments of the callback callee are unknown (-1).
1402 // - All variadic arguments to the __kmpc_fork_call are passed to the
1403 // callback callee.
1404 F->addMetadata(KindID: LLVMContext::MD_callback,
1405 MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
1406 CalleeArgNo: 2, Arguments: {-1, -1},
1407 /* VarArgsArePassed */ true)}));
1408 }
1409 }
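// As a sketch, the resulting IR-level annotation looks roughly like this
// (metadata slot numbers are illustrative only):
//
// declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
// !0 = !{!1}
// !1 = !{i64 2, i64 -1, i64 -1, i1 true}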
1410 // Add some known attributes.
1411 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
1412 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
1413 OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);
1414
1415 assert(OutlinedFn.arg_size() >= 2 &&
1416 "Expected at least tid and bounded tid as arguments");
1417 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1418
1419 CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
1420 CI->getParent()->setName("omp_parallel");
1421 Builder.SetInsertPoint(CI);
1422
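// For reference, the host runtime entry points used below are declared
// roughly as follows (a sketch; see the OpenMP host runtime for the
// authoritative declarations):
//
// \code{c}
// void __kmpc_fork_call(ident_t *loc, kmp_int32 nargs,
//                       kmpc_micro microtask, ...);
// void __kmpc_fork_call_if(ident_t *loc, kmp_int32 nargs,
//                          kmpc_micro microtask, kmp_int32 cond, void *args);
// \endcode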
1423 // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1424 Value *ForkCallArgs[] = {Ident, Builder.getInt32(C: NumCapturedVars),
1425 &OutlinedFn};
1426
1427 SmallVector<Value *, 16> RealArgs;
1428 RealArgs.append(in_start: std::begin(arr&: ForkCallArgs), in_end: std::end(arr&: ForkCallArgs));
1429 if (IfCondition) {
1430 Value *Cond = Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32);
1431 RealArgs.push_back(Elt: Cond);
1432 }
1433 RealArgs.append(in_start: CI->arg_begin() + /* tid & bound tid */ 2, in_end: CI->arg_end());
1434
1435 // __kmpc_fork_call_if always expects a void ptr as the last argument.
1436 // If there are no arguments, pass a null pointer.
1437 auto PtrTy = OMPIRBuilder->VoidPtr;
1438 if (IfCondition && NumCapturedVars == 0) {
1439 Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
1440 RealArgs.push_back(Elt: NullPtrValue);
1441 }
1442
1443 Builder.CreateCall(Callee: RTLFn, Args: RealArgs);
1444
1445 LLVM_DEBUG(dbgs() << "With fork_call placed: "
1446 << *Builder.GetInsertBlock()->getParent() << "\n");
1447
1448 // Initialize the local TID stack location with the argument value.
1449 Builder.SetInsertPoint(PrivTID);
1450 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1451 Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
1452 Ptr: PrivTIDAddr);
1453
1454 // Remove redundant call to the outlined function.
1455 CI->eraseFromParent();
1456
1457 for (Instruction *I : ToBeDeleted) {
1458 I->eraseFromParent();
1459 }
1460}
1461
1462OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
1463 const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1464 BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1465 FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1466 omp::ProcBindKind ProcBind, bool IsCancellable) {
1467 assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1468
1469 if (!updateToLocation(Loc))
1470 return Loc.IP;
1471
1472 uint32_t SrcLocStrSize;
1473 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1474 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1475 Value *ThreadID = getOrCreateThreadID(Ident);
1476 // If we generate code for the target device, we need to allocate the
1477 // struct for aggregate params in the device default alloca address space.
1478 // The OpenMP runtime requires that the params of the extracted functions are
1479 // passed as zero address space pointers. This flag ensures that extracted
1480 // function arguments are declared in zero address space.
1481 bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1482
1483 // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1484 // only if we compile for the host side.
1485 if (NumThreads && !Config.isTargetDevice()) {
1486 Value *Args[] = {
1487 Ident, ThreadID,
1488 Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
1489 Builder.CreateCall(
1490 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
1491 }
1492
1493 if (ProcBind != OMP_PROC_BIND_default) {
1494 // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1495 Value *Args[] = {
1496 Ident, ThreadID,
1497 ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
1498 Builder.CreateCall(
1499 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
1500 }
1501
1502 BasicBlock *InsertBB = Builder.GetInsertBlock();
1503 Function *OuterFn = InsertBB->getParent();
1504
1505 // Save the outer alloca block because the insertion iterator may get
1506 // invalidated and we still need this later.
1507 BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1508
1509 // Vector to remember instructions we used only during the modeling but which
1510 // we want to delete at the end.
1511 SmallVector<Instruction *, 4> ToBeDeleted;
1512
1513 // Change the location to the outer alloca insertion point to create and
1514 // initialize the allocas we pass into the parallel region.
1515 InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1516 Builder.restoreIP(IP: NewOuter);
1517 AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
1518 AllocaInst *ZeroAddrAlloca =
1519 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
1520 Instruction *TIDAddr = TIDAddrAlloca;
1521 Instruction *ZeroAddr = ZeroAddrAlloca;
1522 if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1523 // Add additional casts to enforce pointers in zero address space
1524 TIDAddr = new AddrSpaceCastInst(
1525 TIDAddrAlloca, PointerType::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
1526 TIDAddr->insertAfter(InsertPos: TIDAddrAlloca->getIterator());
1527 ToBeDeleted.push_back(Elt: TIDAddr);
1528 ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1529 PointerType::get(C&: M.getContext(), AddressSpace: 0),
1530 "zero.addr.ascast");
1531 ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca->getIterator());
1532 ToBeDeleted.push_back(Elt: ZeroAddr);
1533 }
1534
1535 // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1536 // associated arguments in the outlined function, so we delete them later.
1537 ToBeDeleted.push_back(Elt: TIDAddrAlloca);
1538 ToBeDeleted.push_back(Elt: ZeroAddrAlloca);
1539
1540 // Create an artificial insertion point that will also ensure the blocks we
1541 // are about to split are not degenerate.
1542 auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1543
1544 BasicBlock *EntryBB = UI->getParent();
1545 BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
1546 BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
1547 BasicBlock *PRegPreFiniBB =
1548 PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
1549 BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");
1550
1551 auto FiniCBWrapper = [&](InsertPointTy IP) {
1552 // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1553 // target to the region exit block.
1554 if (IP.getBlock()->end() == IP.getPoint()) {
1555 IRBuilder<>::InsertPointGuard IPG(Builder);
1556 Builder.restoreIP(IP);
1557 Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
1558 IP = InsertPointTy(I->getParent(), I->getIterator());
1559 }
1560 assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1561 IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1562 "Unexpected insertion point for finalization call!");
1563 return FiniCB(IP);
1564 };
1565
1566 FinalizationStack.push_back(Elt: {.FiniCB: FiniCBWrapper, .DK: OMPD_parallel, .IsCancellable: IsCancellable});
1567
1568 // Generate the privatization allocas in the block that will become the entry
1569 // of the outlined function.
1570 Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1571 InsertPointTy InnerAllocaIP = Builder.saveIP();
1572
1573 AllocaInst *PrivTIDAddr =
1574 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
1575 Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");
1576
1577 // Add some fake uses for OpenMP provided arguments.
1578 ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
1579 Instruction *ZeroAddrUse =
1580 Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
1581 ToBeDeleted.push_back(Elt: ZeroAddrUse);
1582
1583 // EntryBB
1584 // |
1585 // V
1586 // PRegionEntryBB <- Privatization allocas are placed here.
1587 // |
1588 // V
1589 // PRegionBodyBB <- BodyGen is invoked here.
1590 // |
1591 // V
1592 // PRegPreFiniBB <- The block we will start finalization from.
1593 // |
1594 // V
1595 // PRegionExitBB <- A common exit to simplify block collection.
1596 //
1597
1598 LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1599
1600 // Let the caller create the body.
1601 assert(BodyGenCB && "Expected body generation callback!");
1602 InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1603 if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
1604 return Err;
1605
1606 LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
1607
1608 OutlineInfo OI;
1609 if (Config.isTargetDevice()) {
1610 // Generate OpenMP target specific runtime call
1611 OI.PostOutlineCB = [=, ToBeDeletedVec =
1612 std::move(ToBeDeleted)](Function &OutlinedFn) {
1613 targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
1614 IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1615 ThreadID, ToBeDeleted: ToBeDeletedVec);
1616 };
1617 } else {
1618 // Generate OpenMP host runtime call
1619 OI.PostOutlineCB = [=, ToBeDeletedVec =
1620 std::move(ToBeDeleted)](Function &OutlinedFn) {
1621 hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
1622 PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
1623 };
1624 }
1625
1626 OI.OuterAllocaBB = OuterAllocaBlock;
1627 OI.EntryBB = PRegEntryBB;
1628 OI.ExitBB = PRegExitBB;
1629
1630 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1631 SmallVector<BasicBlock *, 32> Blocks;
1632 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
1633
1634 CodeExtractorAnalysisCache CEAC(*OuterFn);
1635 CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1636 /* AggregateArgs */ false,
1637 /* BlockFrequencyInfo */ nullptr,
1638 /* BranchProbabilityInfo */ nullptr,
1639 /* AssumptionCache */ nullptr,
1640 /* AllowVarArgs */ true,
1641 /* AllowAlloca */ true,
1642 /* AllocationBlock */ OuterAllocaBlock,
1643 /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1644
1645 // Find inputs to, outputs from the code region.
1646 BasicBlock *CommonExit = nullptr;
1647 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1648 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
1649
1650 Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands,
1651 /*CollectGlobalInputs=*/true);
1652
1653 Inputs.remove_if(P: [&](Value *I) {
1654 if (auto *GV = dyn_cast_if_present<GlobalVariable>(Val: I))
1655 return GV->getValueType() == OpenMPIRBuilder::Ident;
1656
1657 return false;
1658 });
1659
1660 LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1661
1662 FunctionCallee TIDRTLFn =
1663 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);
1664
1665 auto PrivHelper = [&](Value &V) -> Error {
1666 if (&V == TIDAddr || &V == ZeroAddr) {
1667 OI.ExcludeArgsFromAggregate.push_back(Elt: &V);
1668 return Error::success();
1669 }
1670
1671 SetVector<Use *> Uses;
1672 for (Use &U : V.uses())
1673 if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
1674 if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
1675 Uses.insert(X: &U);
1676
1677 // __kmpc_fork_call expects extra arguments as pointers. If the input
1678 // already has a pointer type, everything is fine. Otherwise, store the
1679 // value onto stack and load it back inside the to-be-outlined region. This
1680 // will ensure only the pointer will be passed to the function.
1681 // FIXME: if there are more than 15 trailing arguments, they must be
1682 // additionally packed in a struct.
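// For example, an i32 input %x is forwarded schematically as:
//   %x.reloaded = alloca i32           ; at OuterAllocaIP
//   store i32 %x, ptr %x.reloaded      ; before branching into the region
//   %inner = load i32, ptr %x.reloaded ; at InnerAllocaIP, used in the region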
1683 Value *Inner = &V;
1684 if (!V.getType()->isPointerTy()) {
1685 IRBuilder<>::InsertPointGuard Guard(Builder);
1686 LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1687
1688 Builder.restoreIP(IP: OuterAllocaIP);
1689 Value *Ptr =
1690 Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr, Name: V.getName() + ".reloaded");
1691
1692 // Store to stack at end of the block that currently branches to the entry
1693 // block of the to-be-outlined region.
1694 Builder.SetInsertPoint(TheBB: InsertBB,
1695 IP: InsertBB->getTerminator()->getIterator());
1696 Builder.CreateStore(Val: &V, Ptr);
1697
1698 // Load back next to allocations in the to-be-outlined region.
1699 Builder.restoreIP(IP: InnerAllocaIP);
1700 Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
1701 }
1702
1703 Value *ReplacementValue = nullptr;
1704 CallInst *CI = dyn_cast<CallInst>(Val: &V);
1705 if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1706 ReplacementValue = PrivTID;
1707 } else {
1708 InsertPointOrErrorTy AfterIP =
1709 PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
1710 if (!AfterIP)
1711 return AfterIP.takeError();
1712 Builder.restoreIP(IP: *AfterIP);
1713 InnerAllocaIP = {
1714 InnerAllocaIP.getBlock(),
1715 InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1716
1717 assert(ReplacementValue &&
1718 "Expected copy/create callback to set replacement value!");
1719 if (ReplacementValue == &V)
1720 return Error::success();
1721 }
1722
1723 for (Use *UPtr : Uses)
1724 UPtr->set(ReplacementValue);
1725
1726 return Error::success();
1727 };
1728
1729 // Reset the inner alloca insertion as it will be used for loading the values
1730 // wrapped into pointers before passing them into the to-be-outlined region.
1731 // Configure it to insert immediately after the fake use of the zero address
1732 // so that the reloaded values are available in the generated body and so
1733 // that the OpenMP-related values (thread ID and zero address pointers)
1734 // remain leading in the argument list.
1735 InnerAllocaIP = IRBuilder<>::InsertPoint(
1736 ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1737
1738 // Reset the outer alloca insertion point to the entry of the relevant block
1739 // in case it was invalidated.
1740 OuterAllocaIP = IRBuilder<>::InsertPoint(
1741 OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1742
1743 for (Value *Input : Inputs) {
1744 LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1745 if (Error Err = PrivHelper(*Input))
1746 return Err;
1747 }
1748 LLVM_DEBUG({
1749 for (Value *Output : Outputs)
1750 LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1751 });
1752 assert(Outputs.empty() &&
1753 "OpenMP outlining should not produce live-out values!");
1754
1755 LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
1756 LLVM_DEBUG({
1757 for (auto *BB : Blocks)
1758 dbgs() << " PBR: " << BB->getName() << "\n";
1759 });
1760
1761 // Adjust the finalization stack, verify the adjustment, and call the
1762 // finalize function one last time to finalize values between the pre-fini
1763 // block and the exit block if we left the parallel region "the normal way".
1764 auto FiniInfo = FinalizationStack.pop_back_val();
1765 (void)FiniInfo;
1766 assert(FiniInfo.DK == OMPD_parallel &&
1767 "Unexpected finalization stack state!");
1768
1769 Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1770
1771 InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1772 if (Error Err = FiniCB(PreFiniIP))
1773 return Err;
1774
1775 // Register the outlined info.
1776 addOutlineInfo(OI: std::move(OI));
1777
1778 InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1779 UI->eraseFromParent();
1780
1781 return AfterIP;
1782}
1783
1784void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1785 // Build call void __kmpc_flush(ident_t *loc)
1786 uint32_t SrcLocStrSize;
1787 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1788 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1789
1790 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush), Args);
1791}
1792
1793void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1794 if (!updateToLocation(Loc))
1795 return;
1796 emitFlush(Loc);
1797}
1798
1799void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1800 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1801 // global_tid);
1802 uint32_t SrcLocStrSize;
1803 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1804 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1805 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1806
1807 // Ignore return result until untied tasks are supported.
1808 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait),
1809 Args);
1810}
1811
1812void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1813 if (!updateToLocation(Loc))
1814 return;
1815 emitTaskwaitImpl(Loc);
1816}
1817
1818void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1819 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1820 uint32_t SrcLocStrSize;
1821 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1822 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1823 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1824 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1825
1826 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield),
1827 Args);
1828}
1829
1830void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1831 if (!updateToLocation(Loc))
1832 return;
1833 emitTaskyieldImpl(Loc);
1834}
1835
1836 // Processes the dependencies in Dependencies and does the following:
1837 // - Allocates stack space for an array of DependInfo objects.
1838 // - Populates each DependInfo object with the relevant information of
1839 // the corresponding dependence.
1840 // - All code is inserted in the entry block of the current function.
1841static Value *emitTaskDependencies(
1842 OpenMPIRBuilder &OMPBuilder,
1843 const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1844 // Early return if we have no dependencies to process
1845 if (Dependencies.empty())
1846 return nullptr;
1847
1848 // Given a vector of DependData objects, in this function we create an
1849 // array on the stack that holds kmp_dep_info objects corresponding
1850 // to each dependency. This is then passed to the OpenMP runtime.
1851 // For example, if there are 'n' dependencies then the following pseudo
1852 // code is generated. Assume the first dependence is on a variable 'a'.
1853 //
1854 // \code{c}
1855 // DepArray = alloc(n x sizeof(kmp_depend_info));
1856 // idx = 0;
1857 // DepArray[idx].base_addr = ptrtoint(&a);
1858 // DepArray[idx].len = 8;
1859 // DepArray[idx].flags = Dep.DepKind; /* (See OMPConstants.h for DepKind) */
1860 // ++idx;
1861 // DepArray[idx].base_addr = ...;
1862 // \endcode
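//
// Each array element mirrors the runtime's kmp_depend_info layout, roughly
// (a sketch; in the runtime the flags field is a bitfield of the same width):
//
// \code{c}
// struct kmp_depend_info {
//   intptr_t base_addr; // address of the dependence variable
//   size_t len;         // size of the dependence variable in bytes
//   uint8_t flags;      // dependence kind, cf. RTLDependInfoFields
// };
// \endcode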
1863
1864 IRBuilderBase &Builder = OMPBuilder.Builder;
1865 Type *DependInfo = OMPBuilder.DependInfo;
1866 Module &M = OMPBuilder.M;
1867
1868 Value *DepArray = nullptr;
1869 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1870 Builder.SetInsertPoint(
1871 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1872
1873 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
1874 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
1875
1876 Builder.restoreIP(IP: OldIP);
1877
1878 for (const auto &[DepIdx, Dep] : enumerate(First: Dependencies)) {
1879 Value *Base =
1880 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
1881 // Store the pointer to the variable
1882 Value *Addr = Builder.CreateStructGEP(
1883 Ty: DependInfo, Ptr: Base,
1884 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1885 Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
1886 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
1887 // Store the size of the variable
1888 Value *Size = Builder.CreateStructGEP(
1889 Ty: DependInfo, Ptr: Base, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
1890 Builder.CreateStore(
1891 Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
1892 Ptr: Size);
1893 // Store the dependency kind
1894 Value *Flags = Builder.CreateStructGEP(
1895 Ty: DependInfo, Ptr: Base,
1896 Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
1897 Builder.CreateStore(
1898 Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
1899 V: static_cast<unsigned int>(Dep.DepKind)),
1900 Ptr: Flags);
1901 }
1902 return DepArray;
1903}
1904
1905OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
1906 const LocationDescription &Loc, InsertPointTy AllocaIP,
1907 BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
1908 SmallVector<DependData> Dependencies, bool Mergeable, Value *EventHandle,
1909 Value *Priority) {
1910
1911 if (!updateToLocation(Loc))
1912 return InsertPointTy();
1913
1914 uint32_t SrcLocStrSize;
1915 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1916 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1917 // The current basic block is split into four basic blocks. After outlining,
1918 // they will be mapped as follows:
1919 // ```
1920 // def current_fn() {
1921 // current_basic_block:
1922 // br label %task.exit
1923 // task.exit:
1924 // ; instructions after task
1925 // }
1926 // def outlined_fn() {
1927 // task.alloca:
1928 // br label %task.body
1929 // task.body:
1930 // ret void
1931 // }
1932 // ```
1933 BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
1934 BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
1935 BasicBlock *TaskAllocaBB =
1936 splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");
1937
1938 InsertPointTy TaskAllocaIP =
1939 InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1940 InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1941 if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
1942 return Err;
1943
1944 OutlineInfo OI;
1945 OI.EntryBB = TaskAllocaBB;
1946 OI.OuterAllocaBB = AllocaIP.getBlock();
1947 OI.ExitBB = TaskExitBB;
1948
1949 // Add the thread ID argument.
1950 SmallVector<Instruction *, 4> ToBeDeleted;
1951 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
1952 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));
1953
1954 OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1955 Mergeable, Priority, EventHandle, TaskAllocaBB,
1956 ToBeDeleted](Function &OutlinedFn) mutable {
1957 // Replace the stale CI with the appropriate RTL function call.
1958 assert(OutlinedFn.hasOneUse() &&
1959 "there must be a single user for the outlined function");
1960 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
1961
1962 // HasShareds is true if any variables are captured in the outlined region,
1963 // false otherwise.
1964 bool HasShareds = StaleCI->arg_size() > 1;
1965 Builder.SetInsertPoint(StaleCI);
1966
1967 // Gather the arguments for emitting the runtime call for
1968 // @__kmpc_omp_task_alloc
1969 Function *TaskAllocFn =
1970 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
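// For reference, the allocation entry point is declared roughly as follows
// (a sketch; see the OpenMP host runtime for the authoritative declaration):
//
// \code{c}
// kmp_task_t *__kmpc_omp_task_alloc(ident_t *loc_ref, kmp_int32 gtid,
//                                   kmp_int32 flags, size_t sizeof_kmp_task_t,
//                                   size_t sizeof_shareds,
//                                   kmp_routine_entry_t task_entry);
// \endcode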
1971
1972 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the task
1973 // allocation call.
1974 Value *ThreadID = getOrCreateThreadID(Ident);
1975
1976 // Argument - `flags`
1977 // Task is tied iff (Flags & 1) == 1.
1978 // Task is untied iff (Flags & 1) == 0.
1979 // Task is final iff (Flags & 2) == 2.
1980 // Task is not final iff (Flags & 2) == 0.
1981 // Task is mergeable iff (Flags & 4) == 4.
1982 // Task is not mergeable iff (Flags & 4) == 0.
1983 // Task is priority iff (Flags & 32) == 32.
1984 // Task is not priority iff (Flags & 32) == 0.
1985 // TODO: Handle the other flags.
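// For example, a tied, final, priority task would end up with
// Flags == 1 | 2 | 32 == 35, while an untied, mergeable task would get
// Flags == 4 (with Final, the final bit is selected at runtime).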
1986 Value *Flags = Builder.getInt32(C: Tied);
1987 if (Final) {
1988 Value *FinalFlag =
1989 Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
1990 Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
1991 }
1992
1993 if (Mergeable)
1994 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
1995 if (Priority)
1996 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
1997
1998 // Argument - `sizeof_kmp_task_t` (TaskSize)
1999 // TaskSize refers to the size in bytes of the kmp_task_t data structure,
2000 // including private vars accessed in the task.
2001 // TODO: add kmp_task_t_with_privates (privates)
2002 Value *TaskSize = Builder.getInt64(
2003 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2004
2005 // Argument - `sizeof_shareds` (SharedsSize)
2006 // SharedsSize refers to the shareds array size in the kmp_task_t data
2007 // structure.
2008 Value *SharedsSize = Builder.getInt64(C: 0);
2009 if (HasShareds) {
2010 AllocaInst *ArgStructAlloca =
2011 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2012 assert(ArgStructAlloca &&
2013 "Unable to find the alloca instruction corresponding to arguments "
2014 "for extracted function");
2015 StructType *ArgStructType =
2016 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
2017 assert(ArgStructType && "Unable to find struct type corresponding to "
2018 "arguments for extracted function");
2019 SharedsSize =
2020 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
2021 }
2022 // Emit the @__kmpc_omp_task_alloc runtime call
2023 // The runtime call returns a pointer to an area where the task captured
2024 // variables must be copied before the task is run (TaskData)
2025 CallInst *TaskData = Builder.CreateCall(
2026 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2027 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2028 /*task_func=*/&OutlinedFn});
2029
2030 // Emit detach clause initialization.
2031 // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
2032 // task_descriptor);
2033 if (EventHandle) {
2034 Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
2035 FnID: OMPRTL___kmpc_task_allow_completion_event);
2036 llvm::Value *EventVal =
2037 Builder.CreateCall(Callee: TaskDetachFn, Args: {Ident, ThreadID, TaskData});
2038 llvm::Value *EventHandleAddr =
2039 Builder.CreatePointerBitCastOrAddrSpaceCast(V: EventHandle,
2040 DestTy: Builder.getPtrTy(AddrSpace: 0));
2041 EventVal = Builder.CreatePtrToInt(V: EventVal, DestTy: Builder.getInt64Ty());
2042 Builder.CreateStore(Val: EventVal, Ptr: EventHandleAddr);
2043 }
2044 // Copy the arguments for the outlined function.
2045 if (HasShareds) {
2046 Value *Shareds = StaleCI->getArgOperand(i: 1);
2047 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2048 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2049 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2050 Size: SharedsSize);
2051 }
2052
2053 if (Priority) {
2054 //
2055 // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *";
2056 // we populate the priority information into the "kmp_task_t" here.
2057 //
2058 // The struct "kmp_task_t" definition is available in kmp.h
2059 // kmp_task_t = { shareds, routine, part_id, data1, data2 }
2060 // data2 is used for priority
2061 //
2062 Type *Int32Ty = Builder.getInt32Ty();
2063 Constant *Zero = ConstantInt::get(Ty: Int32Ty, V: 0);
2064 // kmp_task_t* => { ptr }
2065 Type *TaskPtr = StructType::get(elt1: VoidPtr);
2066 Value *TaskGEP =
2067 Builder.CreateInBoundsGEP(Ty: TaskPtr, Ptr: TaskData, IdxList: {Zero, Zero});
2068 // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
2069 Type *TaskStructType = StructType::get(
2070 elt1: VoidPtr, elts: VoidPtr, elts: Builder.getInt32Ty(), elts: VoidPtr, elts: VoidPtr);
2071 Value *PriorityData = Builder.CreateInBoundsGEP(
2072 Ty: TaskStructType, Ptr: TaskGEP, IdxList: {Zero, ConstantInt::get(Ty: Int32Ty, V: 4)});
2073 // kmp_cmplrdata_t => { ptr, ptr }
2074 Type *CmplrStructType = StructType::get(elt1: VoidPtr, elts: VoidPtr);
2075 Value *CmplrData = Builder.CreateInBoundsGEP(Ty: CmplrStructType,
2076 Ptr: PriorityData, IdxList: {Zero, Zero});
2077 Builder.CreateStore(Val: Priority, Ptr: CmplrData);
2078 }
2079
2080 Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);
2081
2082 // In the presence of the `if` clause, the following IR is generated:
2083 // ...
2084 // %data = call @__kmpc_omp_task_alloc(...)
2085 // br i1 %if_condition, label %then, label %else
2086 // then:
2087 // call @__kmpc_omp_task(...)
2088 // br label %exit
2089 // else:
2090 // ;; Wait for resolution of dependencies, if any, before
2091 // ;; beginning the task
2092 // call @__kmpc_omp_wait_deps(...)
2093 // call @__kmpc_omp_task_begin_if0(...)
2094 // call @outlined_fn(...)
2095 // call @__kmpc_omp_task_complete_if0(...)
2096 // br label %exit
2097 // exit:
2098 // ...
2099 if (IfCondition) {
2100 // `SplitBlockAndInsertIfThenElse` requires the block to have a
2101 // terminator.
2102 splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
2103 Instruction *IfTerminator =
2104 Builder.GetInsertPoint()->getParent()->getTerminator();
2105 Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
2106 Builder.SetInsertPoint(IfTerminator);
2107 SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
2108 ElseTerm: &ElseTI);
2109 Builder.SetInsertPoint(ElseTI);
2110
2111 if (Dependencies.size()) {
2112 Function *TaskWaitFn =
2113 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
2114 Builder.CreateCall(
2115 Callee: TaskWaitFn,
2116 Args: {Ident, ThreadID, Builder.getInt32(C: Dependencies.size()), DepArray,
2117 ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2118 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2119 }
2120 Function *TaskBeginFn =
2121 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
2122 Function *TaskCompleteFn =
2123 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
2124 Builder.CreateCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
2125 CallInst *CI = nullptr;
2126 if (HasShareds)
2127 CI = Builder.CreateCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
2128 else
2129 CI = Builder.CreateCall(Callee: &OutlinedFn, Args: {ThreadID});
2130 CI->setDebugLoc(StaleCI->getDebugLoc());
2131 Builder.CreateCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
2132 Builder.SetInsertPoint(ThenTI);
2133 }
2134
2135 if (Dependencies.size()) {
2136 Function *TaskFn =
2137 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
2138 Builder.CreateCall(
2139 Callee: TaskFn,
2140 Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
2141 DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2142 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2143
2144 } else {
2145 // Emit the @__kmpc_omp_task runtime call to spawn the task
2146 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
2147 Builder.CreateCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
2148 }
2149
2150 StaleCI->eraseFromParent();
2151
2152 Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
2153 if (HasShareds) {
2154 LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2155 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2156 New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
2157 }
2158
2159 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
2160 I->eraseFromParent();
2161 };
2162
2163 addOutlineInfo(OI: std::move(OI));
2164 Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());
2165
2166 return Builder.saveIP();
2167}
2168
2169OpenMPIRBuilder::InsertPointOrErrorTy
2170OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2171 InsertPointTy AllocaIP,
2172 BodyGenCallbackTy BodyGenCB) {
2173 if (!updateToLocation(Loc))
2174 return InsertPointTy();
2175
2176 uint32_t SrcLocStrSize;
2177 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2178 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2179 Value *ThreadID = getOrCreateThreadID(Ident);
2180
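// Schematically, the construct lowers to the following bracket around the
// user body (a sketch of the emitted IR):
//
// call void @__kmpc_taskgroup(ptr %ident, i32 %thread_id)
// ; ... body generated by BodyGenCB ...
// call void @__kmpc_end_taskgroup(ptr %ident, i32 %thread_id)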
2181 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2182 Function *TaskgroupFn =
2183 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2184 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2185
2186 BasicBlock *TaskgroupExitBB = splitBB(Builder, CreateBranch: true, Name: "taskgroup.exit");
2187 if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2188 return Err;
2189
2190 Builder.SetInsertPoint(TaskgroupExitBB);
2191 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2192 Function *EndTaskgroupFn =
2193 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2194 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2195
2196 return Builder.saveIP();
2197}
2198
2199OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
2200 const LocationDescription &Loc, InsertPointTy AllocaIP,
2201 ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2202 FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2203 assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2204
2205 if (!updateToLocation(Loc))
2206 return Loc.IP;
2207
2208 // FiniCBWrapper needs to create a branch to the loop finalization block, but
2209 // that block may not have been created yet when this callback runs.
2210 SmallVector<BranchInst *> CancellationBranches;
2211 auto FiniCBWrapper = [&](InsertPointTy IP) {
2212 if (IP.getBlock()->end() != IP.getPoint())
2213 return FiniCB(IP);
2214 // This must be done; otherwise any nested constructs using FinalizeOMPRegion
2215 // will fail because that function requires the finalization basic block to
2216 // have a terminator, which has already been removed by EmitOMPRegionBody.
2217 // IP is currently at the cancellation block.
2218 BranchInst *DummyBranch = Builder.CreateBr(Dest: IP.getBlock());
2219 IP = InsertPointTy(DummyBranch->getParent(), DummyBranch->getIterator());
2220 CancellationBranches.push_back(Elt: DummyBranch);
2221 return FiniCB(IP);
2222 };
2223
2224 FinalizationStack.push_back(Elt: {.FiniCB: FiniCBWrapper, .DK: OMPD_sections, .IsCancellable: IsCancellable});
2225
2226 // Each section is emitted as a switch case
2227 // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2228 // -> OMP.createSection() which generates the IR for each section
2229 // Iterate through all sections and emit a switch construct:
2230 // switch (IV) {
2231 // case 0:
2232 // <SectionStmt[0]>;
2233 // break;
2234 // ...
2235 // case <NumSection> - 1:
2236 // <SectionStmt[<NumSection> - 1]>;
2237 // break;
2238 // }
2239 // ...
2240 // section_loop.after:
2241 // <FiniCB>;
2242 auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
2243 Builder.restoreIP(IP: CodeGenIP);
2244 BasicBlock *Continue =
2245 splitBBWithSuffix(Builder, /*CreateBranch=*/false, Suffix: ".sections.after");
2246 Function *CurFn = Continue->getParent();
2247 SwitchInst *SwitchStmt = Builder.CreateSwitch(V: IndVar, Dest: Continue);
2248
2249 unsigned CaseNumber = 0;
2250 for (auto SectionCB : SectionCBs) {
2251 BasicBlock *CaseBB = BasicBlock::Create(
2252 Context&: M.getContext(), Name: "omp_section_loop.body.case", Parent: CurFn, InsertBefore: Continue);
2253 SwitchStmt->addCase(OnVal: Builder.getInt32(C: CaseNumber), Dest: CaseBB);
2254 Builder.SetInsertPoint(CaseBB);
2255 BranchInst *CaseEndBr = Builder.CreateBr(Dest: Continue);
2256 if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
2257 CaseEndBr->getIterator()}))
2258 return Err;
2259 CaseNumber++;
2260 }
2261 // Remove the existing terminator from the body BB since there can be no
2262 // terminators after a switch/case.
2263 return Error::success();
2264 };
2265 // Loop body ends here
2266 // LowerBound, UpperBound, and Stride for createCanonicalLoop
2267 Type *I32Ty = Type::getInt32Ty(C&: M.getContext());
2268 Value *LB = ConstantInt::get(Ty: I32Ty, V: 0);
2269 Value *UB = ConstantInt::get(Ty: I32Ty, V: SectionCBs.size());
2270 Value *ST = ConstantInt::get(Ty: I32Ty, V: 1);
2271 Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
2272 Loc, BodyGenCB: LoopBodyGenCB, Start: LB, Stop: UB, Step: ST, IsSigned: true, InclusiveStop: false, ComputeIP: AllocaIP, Name: "section_loop");
2273 if (!LoopInfo)
2274 return LoopInfo.takeError();
2275
2276 InsertPointOrErrorTy WsloopIP =
2277 applyStaticWorkshareLoop(DL: Loc.DL, CLI: *LoopInfo, AllocaIP,
2278 LoopType: WorksharingLoopType::ForStaticLoop, NeedsBarrier: !IsNowait);
2279 if (!WsloopIP)
2280 return WsloopIP.takeError();
2281 InsertPointTy AfterIP = *WsloopIP;
2282
2283 BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
2284 assert(LoopFini && "Bad structure of static workshare loop finalization");
2285
2286 // Apply the finalization callback in LoopAfterBB
2287 auto FiniInfo = FinalizationStack.pop_back_val();
2288 assert(FiniInfo.DK == OMPD_sections &&
2289 "Unexpected finalization stack state!");
2290 if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2291 Builder.restoreIP(IP: AfterIP);
2292 BasicBlock *FiniBB =
2293 splitBBWithSuffix(Builder, /*CreateBranch=*/true, Suffix: "sections.fini");
2294 if (Error Err = CB(Builder.saveIP()))
2295 return Err;
2296 AfterIP = {FiniBB, FiniBB->begin()};
2297 }
2298
2299 // Now we can fix the dummy branches to point to the right place
2300 for (BranchInst *DummyBranch : CancellationBranches) {
2301 assert(DummyBranch->getNumSuccessors() == 1);
2302 DummyBranch->setSuccessor(idx: 0, NewSucc: LoopFini);
2303 }
2304
2305 return AfterIP;
2306}
2307
2308OpenMPIRBuilder::InsertPointOrErrorTy
2309OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2310 BodyGenCallbackTy BodyGenCB,
2311 FinalizeCallbackTy FiniCB) {
2312 if (!updateToLocation(Loc))
2313 return Loc.IP;
2314
2315 auto FiniCBWrapper = [&](InsertPointTy IP) {
2316 if (IP.getBlock()->end() != IP.getPoint())
2317 return FiniCB(IP);
2318 // This must be done; otherwise any nested constructs using FinalizeOMPRegion
2319 // will fail because that function requires the finalization basic block to
2320 // have a terminator, which has already been removed by EmitOMPRegionBody.
2321 // IP is currently at the cancellation block.
2322 // We need to backtrack to the condition block to fetch
2323 // the exit block and create a branch from the cancellation
2324 // block to the exit block.
2325 IRBuilder<>::InsertPointGuard IPG(Builder);
2326 Builder.restoreIP(IP);
2327 auto *CaseBB = Loc.IP.getBlock();
2328 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2329 auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
2330 Instruction *I = Builder.CreateBr(Dest: ExitBB);
2331 IP = InsertPointTy(I->getParent(), I->getIterator());
2332 return FiniCB(IP);
2333 };
2334
2335 Directive OMPD = Directive::OMPD_sections;
2336 // Since we are using Finalization Callback here, HasFinalize
2337 // and IsCancellable have to be true
2338 return EmitOMPInlinedRegion(OMPD, EntryCall: nullptr, ExitCall: nullptr, BodyGenCB, FiniCB: FiniCBWrapper,
2339 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true,
2340 /*IsCancellable*/ true);
2341}
2342
2343static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2344 BasicBlock::iterator IT(I);
2345 IT++;
2346 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2347}
2348
2349Value *OpenMPIRBuilder::getGPUThreadID() {
2350 return Builder.CreateCall(
2351 Callee: getOrCreateRuntimeFunction(M,
2352 FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block),
2353 Args: {});
2354}
2355
2356Value *OpenMPIRBuilder::getGPUWarpSize() {
2357 return Builder.CreateCall(
2358 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___kmpc_get_warp_size), Args: {});
2359}
2360
2361Value *OpenMPIRBuilder::getNVPTXWarpID() {
2362 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2363 return Builder.CreateAShr(LHS: getGPUThreadID(), RHS: LaneIDBits, Name: "nvptx_warp_id");
2364}
2365
2366Value *OpenMPIRBuilder::getNVPTXLaneID() {
2367 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2368 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2369 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
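// E.g., with the common warp size of 32, LaneIDBits == 5 and
// LaneIDMask == 0x1f, so the lane ID is the thread ID modulo 32.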
2370 return Builder.CreateAnd(LHS: getGPUThreadID(), RHS: Builder.getInt32(C: LaneIDMask),
2371 Name: "nvptx_lane_id");
2372}
2373
2374Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2375 Type *ToType) {
2376 Type *FromType = From->getType();
2377 uint64_t FromSize = M.getDataLayout().getTypeStoreSize(Ty: FromType);
2378 uint64_t ToSize = M.getDataLayout().getTypeStoreSize(Ty: ToType);
2379 assert(FromSize > 0 && "From size must be greater than zero");
2380 assert(ToSize > 0 && "To size must be greater than zero");
2381 if (FromType == ToType)
2382 return From;
2383 if (FromSize == ToSize)
2384 return Builder.CreateBitCast(V: From, DestTy: ToType);
2385 if (ToType->isIntegerTy() && FromType->isIntegerTy())
2386 return Builder.CreateIntCast(V: From, DestTy: ToType, /*isSigned*/ true);
2387 InsertPointTy SaveIP = Builder.saveIP();
2388 Builder.restoreIP(IP: AllocaIP);
2389 Value *CastItem = Builder.CreateAlloca(Ty: ToType);
2390 Builder.restoreIP(IP: SaveIP);
2391
2392 Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2393 V: CastItem, DestTy: Builder.getPtrTy(AddrSpace: 0));
2394 Builder.CreateStore(Val: From, Ptr: ValCastItem);
2395 return Builder.CreateLoad(Ty: ToType, Ptr: CastItem);
2396}
2397
2398Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2399 Value *Element,
2400 Type *ElementType,
2401 Value *Offset) {
2402 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElementType);
2403 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2404
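// For reference, the shuffle routines are declared roughly as follows
// (a sketch; see the OpenMP device runtime for the authoritative
// declarations):
//
// \code{c}
// int32_t __kmpc_shuffle_int32(int32_t val, int16_t delta, int16_t size);
// int64_t __kmpc_shuffle_int64(int64_t val, int16_t delta, int16_t size);
// \endcode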
2405 // Cast all types to 32- or 64-bit values before calling shuffle routines.
2406 Type *CastTy = Builder.getIntNTy(N: Size <= 4 ? 32 : 64);
2407 Value *ElemCast = castValueToType(AllocaIP, From: Element, ToType: CastTy);
2408 Value *WarpSize =
2409 Builder.CreateIntCast(V: getGPUWarpSize(), DestTy: Builder.getInt16Ty(), isSigned: true);
2410 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2411 FnID: Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2412 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2413 Value *WarpSizeCast =
2414 Builder.CreateIntCast(V: WarpSize, DestTy: Builder.getInt16Ty(), /*isSigned=*/true);
2415 Value *ShuffleCall =
2416 Builder.CreateCall(Callee: ShuffleFunc, Args: {ElemCast, Offset, WarpSizeCast});
2417 return castValueToType(AllocaIP, From: ShuffleCall, ToType: CastTy);
2418}
2419
2420void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
2421 Value *DstAddr, Type *ElemType,
2422 Value *Offset, Type *ReductionArrayTy) {
2423 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElemType);
2424 // Create the loop over the big-sized data.
2425 // ptr = (void*)Elem;
2426 // ptrEnd = (void*) Elem + 1;
2427 // Step = 8;
2428 // while (ptr + Step < ptrEnd)
2429 // shuffle((int64_t)*ptr);
2430 // Step = 4;
2431 // while (ptr + Step < ptrEnd)
2432 // shuffle((int32_t)*ptr);
2433 // ...
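// E.g., a 7-byte element is transferred as one 4-byte shuffle, then one
// 2-byte shuffle, then one 1-byte shuffle (Size % IntSize keeps the
// remainder for the next, smaller width).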
2434 Type *IndexTy = Builder.getIndexTy(
2435 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
2436 Value *ElemPtr = DstAddr;
2437 Value *Ptr = SrcAddr;
2438 for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
2439 if (Size < IntSize)
2440 continue;
2441 Type *IntType = Builder.getIntNTy(N: IntSize * 8);
2442 Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2443 V: Ptr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: Ptr->getName() + ".ascast");
2444 Value *SrcAddrGEP =
2445 Builder.CreateGEP(Ty: ElemType, Ptr: SrcAddr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2446 ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2447 V: ElemPtr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: ElemPtr->getName() + ".ascast");
2448
2449 Function *CurFunc = Builder.GetInsertBlock()->getParent();
2450 if ((Size / IntSize) > 1) {
2451 Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
2452 V: SrcAddrGEP, DestTy: Builder.getPtrTy());
2453 BasicBlock *PreCondBB =
2454 BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.pre_cond");
2455 BasicBlock *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.then");
2456 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.exit");
2457 BasicBlock *CurrentBB = Builder.GetInsertBlock();
2458 emitBlock(BB: PreCondBB, CurFn: CurFunc);
2459 PHINode *PhiSrc =
2460 Builder.CreatePHI(Ty: Ptr->getType(), /*NumReservedValues=*/2);
2461 PhiSrc->addIncoming(V: Ptr, BB: CurrentBB);
2462 PHINode *PhiDest =
2463 Builder.CreatePHI(Ty: ElemPtr->getType(), /*NumReservedValues=*/2);
2464 PhiDest->addIncoming(V: ElemPtr, BB: CurrentBB);
2465 Ptr = PhiSrc;
2466 ElemPtr = PhiDest;
2467 Value *PtrDiff = Builder.CreatePtrDiff(
2468 ElemTy: Builder.getInt8Ty(), LHS: PtrEnd,
2469 RHS: Builder.CreatePointerBitCastOrAddrSpaceCast(V: Ptr, DestTy: Builder.getPtrTy()));
2470 Builder.CreateCondBr(
2471 Cond: Builder.CreateICmpSGT(LHS: PtrDiff, RHS: Builder.getInt64(C: IntSize - 1)), True: ThenBB,
2472 False: ExitBB);
2473 emitBlock(BB: ThenBB, CurFn: CurFunc);
2474 Value *Res = createRuntimeShuffleFunction(
2475 AllocaIP,
2476 Element: Builder.CreateAlignedLoad(
2477 Ty: IntType, Ptr, Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType)),
2478 ElementType: IntType, Offset);
2479 Builder.CreateAlignedStore(Val: Res, Ptr: ElemPtr,
2480 Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType));
2481 Value *LocalPtr =
2482 Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2483 Value *LocalElemPtr =
2484 Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2485 PhiSrc->addIncoming(V: LocalPtr, BB: ThenBB);
2486 PhiDest->addIncoming(V: LocalElemPtr, BB: ThenBB);
2487 emitBranch(Target: PreCondBB);
2488 emitBlock(BB: ExitBB, CurFn: CurFunc);
2489 } else {
2490 Value *Res = createRuntimeShuffleFunction(
2491 AllocaIP, Element: Builder.CreateLoad(Ty: IntType, Ptr), ElementType: IntType, Offset);
2492 if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
2493 Res->getType()->getScalarSizeInBits())
2494 Res = Builder.CreateTrunc(V: Res, DestTy: ElemType);
2495 Builder.CreateStore(Val: Res, Ptr: ElemPtr);
2496 Ptr = Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2497 ElemPtr =
2498 Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2499 }
2500 Size = Size % IntSize;
2501 }
2502}
2503
2504void OpenMPIRBuilder::emitReductionListCopy(
2505 InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
2506 ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
2507 CopyOptionsTy CopyOptions) {
2508 Type *IndexTy = Builder.getIndexTy(
2509 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
2510 Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
2511
2512 // Iterate, element by element, through the source Reduce list and
2513 // make a copy.
2514 for (auto En : enumerate(First&: ReductionInfos)) {
2515 const ReductionInfo &RI = En.value();
2516 Value *SrcElementAddr = nullptr;
2517 Value *DestElementAddr = nullptr;
2518 Value *DestElementPtrAddr = nullptr;
2519 // Should we shuffle in an element from a remote lane?
2520 bool ShuffleInElement = false;
2521 // Set to true to update the pointer in the dest Reduce list to a
2522 // newly created element.
2523 bool UpdateDestListPtr = false;
2524
2525 // Step 1.1: Get the address for the src element in the Reduce list.
2526 Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
2527 Ty: ReductionArrayTy, Ptr: SrcBase,
2528 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
2529 SrcElementAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrAddr);
2530
2531 // Step 1.2: Create a temporary to store the element in the destination
2532 // Reduce list.
2533 DestElementPtrAddr = Builder.CreateInBoundsGEP(
2534 Ty: ReductionArrayTy, Ptr: DestBase,
2535 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
2536 switch (Action) {
2537 case CopyAction::RemoteLaneToThread: {
2538 InsertPointTy CurIP = Builder.saveIP();
2539 Builder.restoreIP(IP: AllocaIP);
2540 AllocaInst *DestAlloca = Builder.CreateAlloca(Ty: RI.ElementType, ArraySize: nullptr,
2541 Name: ".omp.reduction.element");
2542 DestAlloca->setAlignment(
2543 M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType));
2544 DestElementAddr = DestAlloca;
2545 DestElementAddr =
2546 Builder.CreateAddrSpaceCast(V: DestElementAddr, DestTy: Builder.getPtrTy(),
2547 Name: DestElementAddr->getName() + ".ascast");
2548 Builder.restoreIP(IP: CurIP);
2549 ShuffleInElement = true;
2550 UpdateDestListPtr = true;
2551 break;
2552 }
2553 case CopyAction::ThreadCopy: {
2554 DestElementAddr =
2555 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DestElementPtrAddr);
2556 break;
2557 }
2558 }
2559
2560 // Now that all active lanes have read the element in the
2561 // Reduce list, shuffle over the value from the remote lane.
2562 if (ShuffleInElement) {
2563 shuffleAndStore(AllocaIP, SrcAddr: SrcElementAddr, DstAddr: DestElementAddr, ElemType: RI.ElementType,
2564 Offset: RemoteLaneOffset, ReductionArrayTy);
2565 } else {
2566 switch (RI.EvaluationKind) {
2567 case EvalKind::Scalar: {
2568 Value *Elem = Builder.CreateLoad(Ty: RI.ElementType, Ptr: SrcElementAddr);
2569 // Store the source element value to the dest element address.
2570 Builder.CreateStore(Val: Elem, Ptr: DestElementAddr);
2571 break;
2572 }
2573 case EvalKind::Complex: {
2574 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2575 Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
2576 Value *SrcReal = Builder.CreateLoad(
2577 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
2578 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2579 Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
2580 Value *SrcImg = Builder.CreateLoad(
2581 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
2582
2583 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2584 Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
2585 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2586 Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
2587 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
2588 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
2589 break;
2590 }
2591 case EvalKind::Aggregate: {
2592 Value *SizeVal = Builder.getInt64(
2593 C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
2594 Builder.CreateMemCpy(
2595 Dst: DestElementAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
2596 Src: SrcElementAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
2597 Size: SizeVal, isVolatile: false);
2598 break;
2599 }
2600 };
2601 }
2602
2603 // Step 3.1: Modify reference in dest Reduce list as needed.
2604 // Modifying the reference in Reduce list to point to the newly
2605 // created element. The element is live in the current function
2606 // scope and that of functions it invokes (i.e., reduce_function).
2607 // RemoteReduceData[i] = (void*)&RemoteElem
2608 if (UpdateDestListPtr) {
2609 Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2610 V: DestElementAddr, DestTy: Builder.getPtrTy(),
2611 Name: DestElementAddr->getName() + ".ascast");
2612 Builder.CreateStore(Val: CastDestAddr, Ptr: DestElementPtrAddr);
2613 }
2614 }
2615}
2616
2617Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
2618 const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
2619 AttributeList FuncAttrs) {
2620 InsertPointTy SavedIP = Builder.saveIP();
2621 LLVMContext &Ctx = M.getContext();
2622 FunctionType *FuncTy = FunctionType::get(
2623 Result: Builder.getVoidTy(), Params: {Builder.getPtrTy(), Builder.getInt32Ty()},
2624 /* IsVarArg */ isVarArg: false);
2625 Function *WcFunc =
2626 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
2627 N: "_omp_reduction_inter_warp_copy_func", M: &M);
2628 WcFunc->setAttributes(FuncAttrs);
2629 WcFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
2630 WcFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
2631 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: WcFunc);
2632 Builder.SetInsertPoint(EntryBB);
2633
2634 // ReduceList: thread local Reduce list.
2635 // At the stage of the computation when this function is called, partially
2636 // aggregated values reside in the first lane of every active warp.
2637 Argument *ReduceListArg = WcFunc->getArg(i: 0);
2638 // NumWarps: number of warps active in the parallel region. This could
2639 // be smaller than 32 (max warps in a CTA) for partial block reduction.
2640 Argument *NumWarpsArg = WcFunc->getArg(i: 1);
2641
2642 // This array is used as a medium to transfer, one reduce element at a time,
2643 // the data from the first lane of every warp to lanes in the first warp
2644 // in order to perform the final step of a reduction in a parallel region
2645 // (reduction across warps). The array is placed in NVPTX __shared__ memory
2646 // for reduced latency, as well as to have a distinct copy for concurrently
2647 // executing target regions. The array is declared with weak linkage so
2648 // that a single copy is shared across compilation units.
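// Schematically, the emitted global looks like this (assuming a warp size
// of 32):
//
// @__openmp_nvptx_data_transfer_temporary_storage =
//     weak addrspace(3) global [32 x i32] undef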
2649 StringRef TransferMediumName =
2650 "__openmp_nvptx_data_transfer_temporary_storage";
2651 GlobalVariable *TransferMedium = M.getGlobalVariable(Name: TransferMediumName);
2652 unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
2653 ArrayType *ArrayTy = ArrayType::get(ElementType: Builder.getInt32Ty(), NumElements: WarpSize);
2654 if (!TransferMedium) {
2655 TransferMedium = new GlobalVariable(
2656 M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
2657 UndefValue::get(T: ArrayTy), TransferMediumName,
2658 /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
2659 /*AddressSpace=*/3);
2660 }
2661
2662 // Get the CUDA thread id of the current OpenMP thread on the GPU.
2663 Value *GPUThreadID = getGPUThreadID();
2664 // nvptx_lane_id = nvptx_id % warpsize
2665 Value *LaneID = getNVPTXLaneID();
2666 // nvptx_warp_id = nvptx_id / warpsize
2667 Value *WarpID = getNVPTXWarpID();
2668
2669 InsertPointTy AllocaIP =
2670 InsertPointTy(Builder.GetInsertBlock(),
2671 Builder.GetInsertBlock()->getFirstInsertionPt());
2672 Type *Arg0Type = ReduceListArg->getType();
2673 Type *Arg1Type = NumWarpsArg->getType();
2674 Builder.restoreIP(IP: AllocaIP);
2675 AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
2676 Ty: Arg0Type, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
2677 AllocaInst *NumWarpsAlloca =
2678 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: NumWarpsArg->getName() + ".addr");
2679 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2680 V: ReduceListAlloca, DestTy: Arg0Type, Name: ReduceListAlloca->getName() + ".ascast");
2681 Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2682 V: NumWarpsAlloca, DestTy: Builder.getPtrTy(AddrSpace: 0),
2683 Name: NumWarpsAlloca->getName() + ".ascast");
2684 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
2685 Builder.CreateStore(Val: NumWarpsArg, Ptr: NumWarpsAddrCast);
2686 AllocaIP = getInsertPointAfterInstr(I: NumWarpsAlloca);
2687 InsertPointTy CodeGenIP =
2688 getInsertPointAfterInstr(I: &Builder.GetInsertBlock()->back());
2689 Builder.restoreIP(IP: CodeGenIP);
2690
2691 Value *ReduceList =
2692 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListAddrCast);
2693
2694 for (auto En : enumerate(First&: ReductionInfos)) {
2695 //
2696 // Warp master copies reduce element to transfer medium in __shared__
2697 // memory.
2698 //
2699 const ReductionInfo &RI = En.value();
2700 unsigned RealTySize = M.getDataLayout().getTypeAllocSize(Ty: RI.ElementType);
2701 for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
2702 Type *CType = Builder.getIntNTy(N: TySize * 8);
2703
2704 unsigned NumIters = RealTySize / TySize;
2705 if (NumIters == 0)
2706 continue;
2707 Value *Cnt = nullptr;
2708 Value *CntAddr = nullptr;
2709 BasicBlock *PrecondBB = nullptr;
2710 BasicBlock *ExitBB = nullptr;
2711 if (NumIters > 1) {
2712 CodeGenIP = Builder.saveIP();
2713 Builder.restoreIP(IP: AllocaIP);
2714 CntAddr =
2715 Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr, Name: ".cnt.addr");
2716
2717 CntAddr = Builder.CreateAddrSpaceCast(V: CntAddr, DestTy: Builder.getPtrTy(),
2718 Name: CntAddr->getName() + ".ascast");
2719 Builder.restoreIP(IP: CodeGenIP);
2720 Builder.CreateStore(Val: Constant::getNullValue(Ty: Builder.getInt32Ty()),
2721 Ptr: CntAddr,
2722 /*Volatile=*/isVolatile: false);
2723 PrecondBB = BasicBlock::Create(Context&: Ctx, Name: "precond");
2724 ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit");
2725 BasicBlock *BodyBB = BasicBlock::Create(Context&: Ctx, Name: "body");
2726 emitBlock(BB: PrecondBB, CurFn: Builder.GetInsertBlock()->getParent());
2727 Cnt = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: CntAddr,
2728 /*Volatile=*/isVolatile: false);
2729 Value *Cmp = Builder.CreateICmpULT(
2730 LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), V: NumIters));
2731 Builder.CreateCondBr(Cond: Cmp, True: BodyBB, False: ExitBB);
2732 emitBlock(BB: BodyBB, CurFn: Builder.GetInsertBlock()->getParent());
2733 }
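      // When an element is wider than a single chunk (NumIters > 1), the
      // emitted control flow is roughly (an informal sketch):
      //   cnt = 0;
      // precond:
      //   if (cnt >= NumIters) goto exit;
      //   barrier; warp masters copy chunk cnt; barrier; warp 0 reads back;
      //   cnt += 1; goto precond;
      // exit: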
2734
2735 // kmpc_barrier.
2736 InsertPointOrErrorTy BarrierIP1 =
2737 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
2738 Kind: omp::Directive::OMPD_unknown,
2739 /* ForceSimpleCall */ false,
2740 /* CheckCancelFlag */ true);
2741 if (!BarrierIP1)
2742 return BarrierIP1.takeError();
2743 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
2744 BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
2745 BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
2746
2747 // if (lane_id == 0)
2748 Value *IsWarpMaster = Builder.CreateIsNull(Arg: LaneID, Name: "warp_master");
2749 Builder.CreateCondBr(Cond: IsWarpMaster, True: ThenBB, False: ElseBB);
2750 emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
2751
2752 // Reduce element = LocalReduceList[i]
2753 auto *RedListArrayTy =
2754 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
2755 Type *IndexTy = Builder.getIndexTy(
2756 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
2757 Value *ElemPtrPtr =
2758 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
2759 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
2760 ConstantInt::get(Ty: IndexTy, V: En.index())});
2761 // elemptr = ((CopyType*)(elemptrptr)) + I
2762 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
2763 if (NumIters > 1)
2764 ElemPtr = Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: ElemPtr, IdxList: Cnt);
2765
2766 // Get pointer to location in transfer medium.
2767 // MediumPtr = &medium[warp_id]
2768 Value *MediumPtr = Builder.CreateInBoundsGEP(
2769 Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), WarpID});
      // elem = *elemptr
      // *MediumPtr = elem
2772 Value *Elem = Builder.CreateLoad(Ty: CType, Ptr: ElemPtr);
2773 // Store the source element value to the dest element address.
2774 Builder.CreateStore(Val: Elem, Ptr: MediumPtr,
2775 /*IsVolatile*/ isVolatile: true);
2776 Builder.CreateBr(Dest: MergeBB);
2777
2778 // else
2779 emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
2780 Builder.CreateBr(Dest: MergeBB);
2781
2782 // endif
2783 emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
2784 InsertPointOrErrorTy BarrierIP2 =
2785 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
2786 Kind: omp::Directive::OMPD_unknown,
2787 /* ForceSimpleCall */ false,
2788 /* CheckCancelFlag */ true);
2789 if (!BarrierIP2)
2790 return BarrierIP2.takeError();
2791
2792 // Warp 0 copies reduce element from transfer medium
2793 BasicBlock *W0ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
2794 BasicBlock *W0ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
2795 BasicBlock *W0MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
2796
2797 Value *NumWarpsVal =
2798 Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: NumWarpsAddrCast);
2799 // Up to 32 threads in warp 0 are active.
2800 Value *IsActiveThread =
2801 Builder.CreateICmpULT(LHS: GPUThreadID, RHS: NumWarpsVal, Name: "is_active_thread");
2802 Builder.CreateCondBr(Cond: IsActiveThread, True: W0ThenBB, False: W0ElseBB);
2803
2804 emitBlock(BB: W0ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
2805
      // SrcMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
2808 Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
2809 Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), GPUThreadID});
2810 // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
2811 Value *TargetElemPtrPtr =
2812 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
2813 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
2814 ConstantInt::get(Ty: IndexTy, V: En.index())});
2815 Value *TargetElemPtrVal =
2816 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtrPtr);
2817 Value *TargetElemPtr = TargetElemPtrVal;
2818 if (NumIters > 1)
2819 TargetElemPtr =
2820 Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: TargetElemPtr, IdxList: Cnt);
2821
2822 // *TargetElemPtr = SrcMediumVal;
2823 Value *SrcMediumValue =
2824 Builder.CreateLoad(Ty: CType, Ptr: SrcMediumPtrVal, /*IsVolatile*/ isVolatile: true);
2825 Builder.CreateStore(Val: SrcMediumValue, Ptr: TargetElemPtr);
2826 Builder.CreateBr(Dest: W0MergeBB);
2827
2828 emitBlock(BB: W0ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
2829 Builder.CreateBr(Dest: W0MergeBB);
2830
2831 emitBlock(BB: W0MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
2832
2833 if (NumIters > 1) {
2834 Cnt = Builder.CreateNSWAdd(
2835 LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), /*V=*/1));
2836 Builder.CreateStore(Val: Cnt, Ptr: CntAddr, /*Volatile=*/isVolatile: false);
2837
2838 auto *CurFn = Builder.GetInsertBlock()->getParent();
2839 emitBranch(Target: PrecondBB);
2840 emitBlock(BB: ExitBB, CurFn);
2841 }
2842 RealTySize %= TySize;
2843 }
2844 }
2845
2846 Builder.CreateRetVoid();
2847 Builder.restoreIP(IP: SavedIP);
2848
2849 return WcFunc;
2850}
2851
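/// An informal sketch of the helper generated below (C-like shape only;
/// the emitted IR is authoritative):
///
///   void shuffle_and_reduce(void *reduce_list, short lane_id,
///                           short remote_lane_offset, short algo_ver) {
///     // Gather the remote lane's reduce list via GPU shuffles.
///     copy reduce_list from lane (lane_id + remote_lane_offset) into
///         remote_reduce_list;
///     if (reduce_condition(algo_ver, lane_id, remote_lane_offset))
///       reduce_fn(reduce_list, remote_reduce_list);
///     if (algo_ver == 1 && lane_id >= remote_lane_offset)
///       copy remote_reduce_list into reduce_list;
///   }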
2852Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
2853 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2854 AttributeList FuncAttrs) {
2855 LLVMContext &Ctx = M.getContext();
2856 FunctionType *FuncTy =
2857 FunctionType::get(Result: Builder.getVoidTy(),
2858 Params: {Builder.getPtrTy(), Builder.getInt16Ty(),
2859 Builder.getInt16Ty(), Builder.getInt16Ty()},
2860 /* IsVarArg */ isVarArg: false);
2861 Function *SarFunc =
2862 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
2863 N: "_omp_reduction_shuffle_and_reduce_func", M: &M);
2864 SarFunc->setAttributes(FuncAttrs);
2865 SarFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
2866 SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
2867 SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
2868 SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
2869 SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::SExt);
2870 SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::SExt);
2871 SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::SExt);
2872 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: SarFunc);
2873 Builder.SetInsertPoint(EntryBB);
2874
2875 // Thread local Reduce list used to host the values of data to be reduced.
2876 Argument *ReduceListArg = SarFunc->getArg(i: 0);
2877 // Current lane id; could be logical.
2878 Argument *LaneIDArg = SarFunc->getArg(i: 1);
2879 // Offset of the remote source lane relative to the current lane.
2880 Argument *RemoteLaneOffsetArg = SarFunc->getArg(i: 2);
2881 // Algorithm version. This is expected to be known at compile time.
2882 Argument *AlgoVerArg = SarFunc->getArg(i: 3);
2883
2884 Type *ReduceListArgType = ReduceListArg->getType();
2885 Type *LaneIDArgType = LaneIDArg->getType();
2886 Type *LaneIDArgPtrType = Builder.getPtrTy(AddrSpace: 0);
2887 Value *ReduceListAlloca = Builder.CreateAlloca(
2888 Ty: ReduceListArgType, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
2889 Value *LaneIdAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
2890 Name: LaneIDArg->getName() + ".addr");
2891 Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
2892 Ty: LaneIDArgType, ArraySize: nullptr, Name: RemoteLaneOffsetArg->getName() + ".addr");
2893 Value *AlgoVerAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
2894 Name: AlgoVerArg->getName() + ".addr");
2895 ArrayType *RedListArrayTy =
2896 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
2897
2898 // Create a local thread-private variable to host the Reduce list
2899 // from a remote lane.
2900 Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
2901 Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.remote_reduce_list");
2902
2903 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2904 V: ReduceListAlloca, DestTy: ReduceListArgType,
2905 Name: ReduceListAlloca->getName() + ".ascast");
2906 Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2907 V: LaneIdAlloca, DestTy: LaneIDArgPtrType, Name: LaneIdAlloca->getName() + ".ascast");
2908 Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2909 V: RemoteLaneOffsetAlloca, DestTy: LaneIDArgPtrType,
2910 Name: RemoteLaneOffsetAlloca->getName() + ".ascast");
2911 Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2912 V: AlgoVerAlloca, DestTy: LaneIDArgPtrType, Name: AlgoVerAlloca->getName() + ".ascast");
2913 Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2914 V: RemoteReductionListAlloca, DestTy: Builder.getPtrTy(),
2915 Name: RemoteReductionListAlloca->getName() + ".ascast");
2916
2917 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
2918 Builder.CreateStore(Val: LaneIDArg, Ptr: LaneIdAddrCast);
2919 Builder.CreateStore(Val: RemoteLaneOffsetArg, Ptr: RemoteLaneOffsetAddrCast);
2920 Builder.CreateStore(Val: AlgoVerArg, Ptr: AlgoVerAddrCast);
2921
2922 Value *ReduceList = Builder.CreateLoad(Ty: ReduceListArgType, Ptr: ReduceListAddrCast);
2923 Value *LaneId = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: LaneIdAddrCast);
2924 Value *RemoteLaneOffset =
2925 Builder.CreateLoad(Ty: LaneIDArgType, Ptr: RemoteLaneOffsetAddrCast);
2926 Value *AlgoVer = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: AlgoVerAddrCast);
2927
2928 InsertPointTy AllocaIP = getInsertPointAfterInstr(I: RemoteReductionListAlloca);
2929
2930 // This loop iterates through the list of reduce elements and copies,
2931 // element by element, from a remote lane in the warp to RemoteReduceList,
2932 // hosted on the thread's stack.
2933 emitReductionListCopy(
2934 AllocaIP, Action: CopyAction::RemoteLaneToThread, ReductionArrayTy: RedListArrayTy, ReductionInfos,
2935 SrcBase: ReduceList, DestBase: RemoteListAddrCast, CopyOptions: {.RemoteLaneOffset: RemoteLaneOffset, .ScratchpadIndex: nullptr, .ScratchpadWidth: nullptr});
2936
  // The action to be performed on the Remote Reduce list depends on the
  // algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread-local Reduce list is mutated in place to host the reduced
  // data, which is the aggregated value produced from local and remote
  // lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first disjunct is true, so the entire predicate
  // folds to true at compile time.
  // When AlgoVer==1, only the (LaneId < Offset) comparison in the second
  // disjunct must be evaluated at runtime; the other disjuncts fold to
  // false at compile time.
  // When AlgoVer==2, only the lane and offset comparisons in the third
  // disjunct must be evaluated at runtime; the other disjuncts fold to
  // false at compile time.
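  // For example, with AlgoVer==1 and RemoteLaneOffset==16 (a halving step
  // over a 32-lane warp), lanes 0..15 aggregate the values shuffled in
  // from lanes 16..31, while lanes 16..31 skip the reduction and instead
  // take the copy path below.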
2958 Value *CondAlgo0 = Builder.CreateIsNull(Arg: AlgoVer);
2959 Value *Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
2960 Value *LaneComp = Builder.CreateICmpULT(LHS: LaneId, RHS: RemoteLaneOffset);
2961 Value *CondAlgo1 = Builder.CreateAnd(LHS: Algo1, RHS: LaneComp);
2962 Value *Algo2 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 2));
2963 Value *LaneIdAnd1 = Builder.CreateAnd(LHS: LaneId, RHS: Builder.getInt16(C: 1));
2964 Value *LaneIdComp = Builder.CreateIsNull(Arg: LaneIdAnd1);
2965 Value *Algo2AndLaneIdComp = Builder.CreateAnd(LHS: Algo2, RHS: LaneIdComp);
2966 Value *RemoteOffsetComp =
2967 Builder.CreateICmpSGT(LHS: RemoteLaneOffset, RHS: Builder.getInt16(C: 0));
2968 Value *CondAlgo2 = Builder.CreateAnd(LHS: Algo2AndLaneIdComp, RHS: RemoteOffsetComp);
2969 Value *CA0OrCA1 = Builder.CreateOr(LHS: CondAlgo0, RHS: CondAlgo1);
2970 Value *CondReduce = Builder.CreateOr(LHS: CA0OrCA1, RHS: CondAlgo2);
2971
2972 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
2973 BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
2974 BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
2975
2976 Builder.CreateCondBr(Cond: CondReduce, True: ThenBB, False: ElseBB);
2977 emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
2978 Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2979 V: ReduceList, DestTy: Builder.getPtrTy());
2980 Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2981 V: RemoteListAddrCast, DestTy: Builder.getPtrTy());
2982 Builder.CreateCall(Callee: ReduceFn, Args: {LocalReduceListPtr, RemoteReduceListPtr})
2983 ->addFnAttr(Kind: Attribute::NoUnwind);
2984 Builder.CreateBr(Dest: MergeBB);
2985
2986 emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
2987 Builder.CreateBr(Dest: MergeBB);
2988
2989 emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
2990
2991 // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
2992 // Reduce list.
2993 Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
2994 Value *LaneIdGtOffset = Builder.CreateICmpUGE(LHS: LaneId, RHS: RemoteLaneOffset);
2995 Value *CondCopy = Builder.CreateAnd(LHS: Algo1, RHS: LaneIdGtOffset);
2996
2997 BasicBlock *CpyThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
2998 BasicBlock *CpyElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
2999 BasicBlock *CpyMergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
3000 Builder.CreateCondBr(Cond: CondCopy, True: CpyThenBB, False: CpyElseBB);
3001
3002 emitBlock(BB: CpyThenBB, CurFn: Builder.GetInsertBlock()->getParent());
3003 emitReductionListCopy(AllocaIP, Action: CopyAction::ThreadCopy, ReductionArrayTy: RedListArrayTy,
3004 ReductionInfos, SrcBase: RemoteListAddrCast, DestBase: ReduceList);
3005 Builder.CreateBr(Dest: CpyMergeBB);
3006
3007 emitBlock(BB: CpyElseBB, CurFn: Builder.GetInsertBlock()->getParent());
3008 Builder.CreateBr(Dest: CpyMergeBB);
3009
3010 emitBlock(BB: CpyMergeBB, CurFn: Builder.GetInsertBlock()->getParent());
3011
3012 Builder.CreateRetVoid();
3013
3014 return SarFunc;
3015}
3016
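/// An informal sketch of the list-to-global copy helper generated below
/// (field_i stands for the struct member holding reduction i):
///
///   void list_to_global_copy(void *buffer, int idx, void *reduce_list) {
///     for each reduction element i:
///       buffer[idx].field_i = *reduce_list[i]; // scalar/complex/memcpy
///   }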
3017Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
3018 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3019 AttributeList FuncAttrs) {
3020 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3021 LLVMContext &Ctx = M.getContext();
3022 FunctionType *FuncTy = FunctionType::get(
3023 Result: Builder.getVoidTy(),
3024 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3025 /* IsVarArg */ isVarArg: false);
3026 Function *LtGCFunc =
3027 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3028 N: "_omp_reduction_list_to_global_copy_func", M: &M);
3029 LtGCFunc->setAttributes(FuncAttrs);
3030 LtGCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3031 LtGCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3032 LtGCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3033
3034 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGCFunc);
3035 Builder.SetInsertPoint(EntryBlock);
3036
3037 // Buffer: global reduction buffer.
3038 Argument *BufferArg = LtGCFunc->getArg(i: 0);
3039 // Idx: index of the buffer.
3040 Argument *IdxArg = LtGCFunc->getArg(i: 1);
3041 // ReduceList: thread local Reduce list.
3042 Argument *ReduceListArg = LtGCFunc->getArg(i: 2);
3043
3044 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3045 Name: BufferArg->getName() + ".addr");
3046 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3047 Name: IdxArg->getName() + ".addr");
3048 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3049 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3050 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3051 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3052 Name: BufferArgAlloca->getName() + ".ascast");
3053 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3054 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3055 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3056 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3057 Name: ReduceListArgAlloca->getName() + ".ascast");
3058
3059 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3060 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3061 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3062
3063 Value *LocalReduceList =
3064 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3065 Value *BufferArgVal =
3066 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3067 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3068 Type *IndexTy = Builder.getIndexTy(
3069 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3070 for (auto En : enumerate(First&: ReductionInfos)) {
3071 const ReductionInfo &RI = En.value();
3072 auto *RedListArrayTy =
3073 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3074 // Reduce element = LocalReduceList[i]
3075 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3076 Ty: RedListArrayTy, Ptr: LocalReduceList,
3077 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3078 // elemptr = ((CopyType*)(elemptrptr)) + I
3079 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3080
3081 // Global = Buffer.VD[Idx];
3082 Value *BufferVD =
3083 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferArgVal, IdxList: Idxs);
3084 Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3085 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3086
3087 switch (RI.EvaluationKind) {
3088 case EvalKind::Scalar: {
3089 Value *TargetElement = Builder.CreateLoad(Ty: RI.ElementType, Ptr: ElemPtr);
3090 Builder.CreateStore(Val: TargetElement, Ptr: GlobVal);
3091 break;
3092 }
3093 case EvalKind::Complex: {
3094 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3095 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
3096 Value *SrcReal = Builder.CreateLoad(
3097 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
3098 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3099 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
3100 Value *SrcImg = Builder.CreateLoad(
3101 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
3102
3103 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3104 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 0, Name: ".realp");
3105 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3106 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 1, Name: ".imagp");
3107 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
3108 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
3109 break;
3110 }
3111 case EvalKind::Aggregate: {
3112 Value *SizeVal =
3113 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
3114 Builder.CreateMemCpy(
3115 Dst: GlobVal, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Src: ElemPtr,
3116 SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Size: SizeVal, isVolatile: false);
3117 break;
3118 }
3119 }
3120 }
3121
3122 Builder.CreateRetVoid();
3123 Builder.restoreIP(IP: OldIP);
3124 return LtGCFunc;
3125}
3126
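/// An informal sketch of the list-to-global reduce helper generated below
/// (field_i stands for the struct member holding reduction i):
///
///   void list_to_global_reduce(void *buffer, int idx, void *reduce_list) {
///     void *glob_list[n] = {&buffer[idx].field_0, ...};
///     reduce_fn(glob_list, reduce_list);
///   }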
3127Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
3128 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3129 Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3130 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3131 LLVMContext &Ctx = M.getContext();
3132 FunctionType *FuncTy = FunctionType::get(
3133 Result: Builder.getVoidTy(),
3134 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3135 /* IsVarArg */ isVarArg: false);
3136 Function *LtGRFunc =
3137 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3138 N: "_omp_reduction_list_to_global_reduce_func", M: &M);
3139 LtGRFunc->setAttributes(FuncAttrs);
3140 LtGRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3141 LtGRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3142 LtGRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3143
3144 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGRFunc);
3145 Builder.SetInsertPoint(EntryBlock);
3146
3147 // Buffer: global reduction buffer.
3148 Argument *BufferArg = LtGRFunc->getArg(i: 0);
3149 // Idx: index of the buffer.
3150 Argument *IdxArg = LtGRFunc->getArg(i: 1);
3151 // ReduceList: thread local Reduce list.
3152 Argument *ReduceListArg = LtGRFunc->getArg(i: 2);
3153
3154 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3155 Name: BufferArg->getName() + ".addr");
3156 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3157 Name: IdxArg->getName() + ".addr");
3158 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3159 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3160 auto *RedListArrayTy =
3161 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3162
3163 // 1. Build a list of reduction variables.
3164 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3165 Value *LocalReduceList =
3166 Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
3167
3168 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3169 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3170 Name: BufferArgAlloca->getName() + ".ascast");
3171 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3172 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3173 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3174 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3175 Name: ReduceListArgAlloca->getName() + ".ascast");
3176 Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3177 V: LocalReduceList, DestTy: Builder.getPtrTy(),
3178 Name: LocalReduceList->getName() + ".ascast");
3179
3180 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3181 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3182 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3183
3184 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3185 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3186 Type *IndexTy = Builder.getIndexTy(
3187 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3188 for (auto En : enumerate(First&: ReductionInfos)) {
3189 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3190 Ty: RedListArrayTy, Ptr: LocalReduceListAddrCast,
3191 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3192 Value *BufferVD =
3193 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
3194 // Global = Buffer.VD[Idx];
3195 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3196 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3197 Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
3198 }
3199
3200 // Call reduce_function(GlobalReduceList, ReduceList)
3201 Value *ReduceList =
3202 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3203 Builder.CreateCall(Callee: ReduceFn, Args: {LocalReduceListAddrCast, ReduceList})
3204 ->addFnAttr(Kind: Attribute::NoUnwind);
3205 Builder.CreateRetVoid();
3206 Builder.restoreIP(IP: OldIP);
3207 return LtGRFunc;
3208}
3209
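/// An informal sketch of the global-to-list copy helper generated below;
/// it is the inverse of list_to_global_copy:
///
///   void global_to_list_copy(void *buffer, int idx, void *reduce_list) {
///     for each reduction element i:
///       *reduce_list[i] = buffer[idx].field_i; // scalar/complex/memcpy
///   }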
3210Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
3211 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3212 AttributeList FuncAttrs) {
3213 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3214 LLVMContext &Ctx = M.getContext();
3215 FunctionType *FuncTy = FunctionType::get(
3216 Result: Builder.getVoidTy(),
3217 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3218 /* IsVarArg */ isVarArg: false);
3219 Function *LtGCFunc =
3220 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3221 N: "_omp_reduction_global_to_list_copy_func", M: &M);
3222 LtGCFunc->setAttributes(FuncAttrs);
3223 LtGCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3224 LtGCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3225 LtGCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3226
3227 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGCFunc);
3228 Builder.SetInsertPoint(EntryBlock);
3229
3230 // Buffer: global reduction buffer.
3231 Argument *BufferArg = LtGCFunc->getArg(i: 0);
3232 // Idx: index of the buffer.
3233 Argument *IdxArg = LtGCFunc->getArg(i: 1);
3234 // ReduceList: thread local Reduce list.
3235 Argument *ReduceListArg = LtGCFunc->getArg(i: 2);
3236
3237 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3238 Name: BufferArg->getName() + ".addr");
3239 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3240 Name: IdxArg->getName() + ".addr");
3241 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3242 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3243 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3244 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3245 Name: BufferArgAlloca->getName() + ".ascast");
3246 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3247 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3248 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3249 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3250 Name: ReduceListArgAlloca->getName() + ".ascast");
3251 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3252 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3253 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3254
3255 Value *LocalReduceList =
3256 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3257 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3258 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3259 Type *IndexTy = Builder.getIndexTy(
3260 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3261 for (auto En : enumerate(First&: ReductionInfos)) {
3262 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3263 auto *RedListArrayTy =
3264 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3265 // Reduce element = LocalReduceList[i]
3266 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3267 Ty: RedListArrayTy, Ptr: LocalReduceList,
3268 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3269 // elemptr = ((CopyType*)(elemptrptr)) + I
3270 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3271 // Global = Buffer.VD[Idx];
3272 Value *BufferVD =
3273 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
3274 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3275 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3276
3277 switch (RI.EvaluationKind) {
3278 case EvalKind::Scalar: {
3279 Value *TargetElement = Builder.CreateLoad(Ty: RI.ElementType, Ptr: GlobValPtr);
3280 Builder.CreateStore(Val: TargetElement, Ptr: ElemPtr);
3281 break;
3282 }
3283 case EvalKind::Complex: {
3284 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3285 Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 0, Name: ".realp");
3286 Value *SrcReal = Builder.CreateLoad(
3287 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
3288 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3289 Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
3290 Value *SrcImg = Builder.CreateLoad(
3291 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
3292
3293 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3294 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
3295 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3296 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
3297 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
3298 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
3299 break;
3300 }
3301 case EvalKind::Aggregate: {
3302 Value *SizeVal =
3303 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
3304 Builder.CreateMemCpy(
3305 Dst: ElemPtr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
3306 Src: GlobValPtr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
3307 Size: SizeVal, isVolatile: false);
3308 break;
3309 }
3310 }
3311 }
3312
3313 Builder.CreateRetVoid();
3314 Builder.restoreIP(IP: OldIP);
3315 return LtGCFunc;
3316}
3317
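/// An informal sketch of the global-to-list reduce helper generated below;
/// note the operand order is swapped relative to list_to_global_reduce:
///
///   void global_to_list_reduce(void *buffer, int idx, void *reduce_list) {
///     void *glob_list[n] = {&buffer[idx].field_0, ...};
///     reduce_fn(reduce_list, glob_list);
///   }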
3318Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
3319 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3320 Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3321 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3322 LLVMContext &Ctx = M.getContext();
3323 auto *FuncTy = FunctionType::get(
3324 Result: Builder.getVoidTy(),
3325 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3326 /* IsVarArg */ isVarArg: false);
3327 Function *LtGRFunc =
3328 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3329 N: "_omp_reduction_global_to_list_reduce_func", M: &M);
3330 LtGRFunc->setAttributes(FuncAttrs);
3331 LtGRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3332 LtGRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3333 LtGRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3334
3335 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGRFunc);
3336 Builder.SetInsertPoint(EntryBlock);
3337
3338 // Buffer: global reduction buffer.
3339 Argument *BufferArg = LtGRFunc->getArg(i: 0);
3340 // Idx: index of the buffer.
3341 Argument *IdxArg = LtGRFunc->getArg(i: 1);
3342 // ReduceList: thread local Reduce list.
3343 Argument *ReduceListArg = LtGRFunc->getArg(i: 2);
3344
3345 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3346 Name: BufferArg->getName() + ".addr");
3347 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3348 Name: IdxArg->getName() + ".addr");
3349 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3350 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3351 ArrayType *RedListArrayTy =
3352 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3353
3354 // 1. Build a list of reduction variables.
3355 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3356 Value *LocalReduceList =
3357 Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
3358
3359 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3360 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3361 Name: BufferArgAlloca->getName() + ".ascast");
3362 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3363 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3364 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3365 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3366 Name: ReduceListArgAlloca->getName() + ".ascast");
3367 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3368 V: LocalReduceList, DestTy: Builder.getPtrTy(),
3369 Name: LocalReduceList->getName() + ".ascast");
3370
3371 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3372 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3373 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3374
3375 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3376 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3377 Type *IndexTy = Builder.getIndexTy(
3378 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3379 for (auto En : enumerate(First&: ReductionInfos)) {
3380 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3381 Ty: RedListArrayTy, Ptr: ReductionList,
3382 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3383 // Global = Buffer.VD[Idx];
3384 Value *BufferVD =
3385 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
3386 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3387 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3388 Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
3389 }
3390
3391 // Call reduce_function(ReduceList, GlobalReduceList)
3392 Value *ReduceList =
3393 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3394 Builder.CreateCall(Callee: ReduceFn, Args: {ReduceList, ReductionList})
3395 ->addFnAttr(Kind: Attribute::NoUnwind);
3396 Builder.CreateRetVoid();
3397 Builder.restoreIP(IP: OldIP);
3398 return LtGRFunc;
3399}
3400
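/// For example, assuming '.' for both platform separators, a reducer named
/// "foo" maps to "foo.omp.reduction.reduction_func"; the exact separators
/// come from the OpenMPIRBuilder configuration.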
3401std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
3402 std::string Suffix =
3403 createPlatformSpecificName(Parts: {"omp", "reduction", "reduction_func"});
3404 return (Name + Suffix).str();
3405}
3406
3407Expected<Function *> OpenMPIRBuilder::createReductionFunction(
3408 StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
3409 ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
3410 auto *FuncTy = FunctionType::get(Result: Builder.getVoidTy(),
3411 Params: {Builder.getPtrTy(), Builder.getPtrTy()},
3412 /* IsVarArg */ isVarArg: false);
3413 std::string Name = getReductionFuncName(Name: ReducerName);
3414 Function *ReductionFunc =
3415 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage, N: Name, M: &M);
3416 ReductionFunc->setAttributes(FuncAttrs);
3417 ReductionFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3418 ReductionFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3419 BasicBlock *EntryBB =
3420 BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: ReductionFunc);
3421 Builder.SetInsertPoint(EntryBB);
3422
  // Allocate stack slots for the incoming argument pointers and cast them
  // to the generic address space before extracting the LHS/RHS pointers.
3425 Value *LHSArrayPtr = nullptr;
3426 Value *RHSArrayPtr = nullptr;
3427 Argument *Arg0 = ReductionFunc->getArg(i: 0);
3428 Argument *Arg1 = ReductionFunc->getArg(i: 1);
3429 Type *Arg0Type = Arg0->getType();
3430 Type *Arg1Type = Arg1->getType();
3431
3432 Value *LHSAlloca =
3433 Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
3434 Value *RHSAlloca =
3435 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
3436 Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3437 V: LHSAlloca, DestTy: Arg0Type, Name: LHSAlloca->getName() + ".ascast");
3438 Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3439 V: RHSAlloca, DestTy: Arg1Type, Name: RHSAlloca->getName() + ".ascast");
3440 Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
3441 Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
3442 LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
3443 RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
3444
3445 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3446 Type *IndexTy = Builder.getIndexTy(
3447 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3448 SmallVector<Value *> LHSPtrs, RHSPtrs;
3449 for (auto En : enumerate(First&: ReductionInfos)) {
3450 const ReductionInfo &RI = En.value();
3451 Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
3452 Ty: RedArrayTy, Ptr: RHSArrayPtr,
3453 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3454 Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
3455 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3456 V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType(),
3457 Name: RHSI8Ptr->getName() + ".ascast");
3458
3459 Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
3460 Ty: RedArrayTy, Ptr: LHSArrayPtr,
3461 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3462 Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
3463 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3464 V: LHSI8Ptr, DestTy: RI.Variable->getType(), Name: LHSI8Ptr->getName() + ".ascast");
3465
3466 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3467 LHSPtrs.emplace_back(Args&: LHSPtr);
3468 RHSPtrs.emplace_back(Args&: RHSPtr);
3469 } else {
3470 Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
3471 Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
3472 Value *Reduced;
3473 InsertPointOrErrorTy AfterIP =
3474 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3475 if (!AfterIP)
3476 return AfterIP.takeError();
3477 if (!Builder.GetInsertBlock())
3478 return ReductionFunc;
3479 Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
3480 }
3481 }
3482
3483 if (ReductionGenCBKind == ReductionGenCBKind::Clang)
3484 for (auto En : enumerate(First&: ReductionInfos)) {
3485 unsigned Index = En.index();
3486 const ReductionInfo &RI = En.value();
3487 Value *LHSFixupPtr, *RHSFixupPtr;
3488 Builder.restoreIP(IP: RI.ReductionGenClang(
3489 Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
3490
      // Fix up the callback-generated code to use the correct Values for
      // the LHS and RHS.
3493 LHSFixupPtr->replaceUsesWithIf(
3494 New: LHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
3495 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
3496 ReductionFunc;
3497 });
3498 RHSFixupPtr->replaceUsesWithIf(
3499 New: RHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
3500 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
3501 ReductionFunc;
3502 });
3503 }
3504
3505 Builder.CreateRetVoid();
3506 return ReductionFunc;
3507}
3508
3509static void
3510checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3511 bool IsGPU) {
3512 for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
3513 (void)RI;
3514 assert(RI.Variable && "expected non-null variable");
3515 assert(RI.PrivateVariable && "expected non-null private variable");
3516 assert((RI.ReductionGen || RI.ReductionGenClang) &&
3517 "expected non-null reduction generator callback");
3518 if (!IsGPU) {
3519 assert(
3520 RI.Variable->getType() == RI.PrivateVariable->getType() &&
3521 "expected variables and their private equivalents to have the same "
3522 "type");
3523 }
3524 assert(RI.Variable->getType()->isPointerTy() &&
3525 "expected variables to be pointers");
3526 }
3527}
3528
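/// High-level flow of the GPU reduction emitted below (an informal
/// sketch):
///
///   1. Pack pointers to the private reduction values into RedList.
///   2. For parallel reductions, call
///      __kmpc_nvptx_parallel_reduce_nowait_v2 with the shuffle-and-reduce
///      and inter-warp copy helpers; for teams reductions, call
///      __kmpc_nvptx_teams_reduce_nowait_v2 with the additional
///      list<->global helpers and a runtime-provided buffer.
///   3. If the runtime returns 1, the master thread folds the reduced
///      private values into the original variables.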
3529OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3530 const LocationDescription &Loc, InsertPointTy AllocaIP,
3531 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3532 bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
3533 std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
3534 Value *SrcLocInfo) {
3535 if (!updateToLocation(Loc))
3536 return InsertPointTy();
3537 Builder.restoreIP(IP: CodeGenIP);
3538 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
3539 LLVMContext &Ctx = M.getContext();
3540
3541 // Source location for the ident struct
3542 if (!SrcLocInfo) {
3543 uint32_t SrcLocStrSize;
3544 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3545 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3546 }
3547
  if (ReductionInfos.empty())
    return Builder.saveIP();
3550
3551 BasicBlock *ContinuationBlock = nullptr;
3552 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
    // Mirror the basic-block splitting done in createReductions so the
    // reduction gets a continuation block.
3554 BasicBlock *InsertBlock = Loc.IP.getBlock();
3555 ContinuationBlock =
3556 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
3557 InsertBlock->getTerminator()->eraseFromParent();
3558 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
3559 }
3560
3561 Function *CurFunc = Builder.GetInsertBlock()->getParent();
3562 AttributeList FuncAttrs;
3563 AttrBuilder AttrBldr(Ctx);
3564 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
3565 AttrBldr.addAttribute(A: Attr);
3566 AttrBldr.removeAttribute(Val: Attribute::OptimizeNone);
3567 FuncAttrs = FuncAttrs.addFnAttributes(C&: Ctx, B: AttrBldr);
3568
3569 CodeGenIP = Builder.saveIP();
3570 Expected<Function *> ReductionResult =
3571 createReductionFunction(ReducerName: Builder.GetInsertBlock()->getParent()->getName(),
3572 ReductionInfos, ReductionGenCBKind, FuncAttrs);
3573 if (!ReductionResult)
3574 return ReductionResult.takeError();
3575 Function *ReductionFunc = *ReductionResult;
3576 Builder.restoreIP(IP: CodeGenIP);
3577
3578 // Set the grid value in the config needed for lowering later on
3579 if (GridValue.has_value())
3580 Config.setGridValue(GridValue.value());
3581 else
3582 Config.setGridValue(getGridValue(T, Kernel: ReductionFunc));
3583
  // Build res = __kmpc_nvptx_parallel_reduce_nowait_v2(<loc>, <data_size>,
  //     RedList, shuffle_reduce_func, interwarp_copy_func);
  // or, for teams reductions,
  // res = __kmpc_nvptx_teams_reduce_nowait_v2(<loc>, <buffer>, <buf_num>,
  //     <data_size>, RedList, shuffle_reduce_func, interwarp_copy_func,
  //     list_to_global_copy, list_to_global_reduce, global_to_list_copy,
  //     global_to_list_reduce);
  Value *Res;
3589
3590 // 1. Build a list of reduction variables.
3591 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3592 auto Size = ReductionInfos.size();
3593 Type *PtrTy = PointerType::getUnqual(C&: Ctx);
3594 Type *RedArrayTy = ArrayType::get(ElementType: PtrTy, NumElements: Size);
3595 CodeGenIP = Builder.saveIP();
3596 Builder.restoreIP(IP: AllocaIP);
3597 Value *ReductionListAlloca =
3598 Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
3599 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3600 V: ReductionListAlloca, DestTy: PtrTy, Name: ReductionListAlloca->getName() + ".ascast");
3601 Builder.restoreIP(IP: CodeGenIP);
3602 Type *IndexTy = Builder.getIndexTy(
3603 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3604 for (auto En : enumerate(First&: ReductionInfos)) {
3605 const ReductionInfo &RI = En.value();
3606 Value *ElemPtr = Builder.CreateInBoundsGEP(
3607 Ty: RedArrayTy, Ptr: ReductionList,
3608 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3609 Value *CastElem =
3610 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
3611 Builder.CreateStore(Val: CastElem, Ptr: ElemPtr);
3612 }
3613 CodeGenIP = Builder.saveIP();
3614 Function *SarFunc =
3615 emitShuffleAndReduceFunction(ReductionInfos, ReduceFn: ReductionFunc, FuncAttrs);
3616 Expected<Function *> CopyResult =
3617 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
3618 if (!CopyResult)
3619 return CopyResult.takeError();
3620 Function *WcFunc = *CopyResult;
3621 Builder.restoreIP(IP: CodeGenIP);
3622
3623 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(V: ReductionList, DestTy: PtrTy);
3624
3625 unsigned MaxDataSize = 0;
3626 SmallVector<Type *> ReductionTypeArgs;
3627 for (auto En : enumerate(First&: ReductionInfos)) {
3628 auto Size = M.getDataLayout().getTypeStoreSize(Ty: En.value().ElementType);
3629 if (Size > MaxDataSize)
3630 MaxDataSize = Size;
3631 ReductionTypeArgs.emplace_back(Args: En.value().ElementType);
3632 }
3633 Value *ReductionDataSize =
3634 Builder.getInt64(C: MaxDataSize * ReductionInfos.size());
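  // For example, reducing one double and one i32 yields MaxDataSize == 8
  // and ReductionDataSize == 8 * 2 == 16 bytes: the buffer is sized
  // pessimistically, using the widest element for every slot.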
3635 if (!IsTeamsReduction) {
3636 Value *SarFuncCast =
3637 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SarFunc, DestTy: PtrTy);
3638 Value *WcFuncCast =
3639 Builder.CreatePointerBitCastOrAddrSpaceCast(V: WcFunc, DestTy: PtrTy);
3640 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
3641 WcFuncCast};
3642 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
3643 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
3644 Res = Builder.CreateCall(Callee: Pv2Ptr, Args);
3645 } else {
3646 CodeGenIP = Builder.saveIP();
3647 StructType *ReductionsBufferTy = StructType::create(
3648 Context&: Ctx, Elements: ReductionTypeArgs, Name: "struct._globalized_locals_ty");
    Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
3651 Function *LtGCFunc = emitListToGlobalCopyFunction(
3652 ReductionInfos, ReductionsBufferTy, FuncAttrs);
3653 Function *LtGRFunc = emitListToGlobalReduceFunction(
3654 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs);
3655 Function *GtLCFunc = emitGlobalToListCopyFunction(
3656 ReductionInfos, ReductionsBufferTy, FuncAttrs);
3657 Function *GtLRFunc = emitGlobalToListReduceFunction(
3658 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs);
3659 Builder.restoreIP(IP: CodeGenIP);
3660
    Value *KernelTeamsReductionPtr = Builder.CreateCall(
        RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
3663
3664 Value *Args3[] = {SrcLocInfo,
3665 KernelTeamsReductionPtr,
3666 Builder.getInt32(C: ReductionBufNum),
3667 ReductionDataSize,
3668 RL,
3669 SarFunc,
3670 WcFunc,
3671 LtGCFunc,
3672 LtGRFunc,
3673 GtLCFunc,
3674 GtLRFunc};
3675
3676 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
3677 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
3678 Res = Builder.CreateCall(Callee: TeamsReduceFn, Args: Args3);
3679 }
3680
3681 // 5. Build if (res == 1)
3682 BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.done");
3683 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.then");
3684 Value *Cond = Builder.CreateICmpEQ(LHS: Res, RHS: Builder.getInt32(C: 1));
3685 Builder.CreateCondBr(Cond, True: ThenBB, False: ExitBB);
3686
  // 6. Build the then branch: the master thread of each team holds the
  //    partially reduced values and folds them into the original
  //    variables.
3691 emitBlock(BB: ThenBB, CurFn: CurFunc);
3692
  // Fold each private reduction value into its original variable.
3694 for (auto En : enumerate(First&: ReductionInfos)) {
3695 const ReductionInfo &RI = En.value();
3696 Value *LHS = RI.Variable;
3697 Value *RHS =
3698 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
3699
3700 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3701 Value *LHSPtr, *RHSPtr;
3702 Builder.restoreIP(IP: RI.ReductionGenClang(Builder.saveIP(), En.index(),
3703 &LHSPtr, &RHSPtr, CurFunc));
3704
      // Fix up the callback-generated code to use the correct Values for
      // the LHS and RHS.
3707 LHSPtr->replaceUsesWithIf(New: LHS, ShouldReplace: [ReductionFunc](const Use &U) {
3708 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
3709 ReductionFunc;
3710 });
3711 RHSPtr->replaceUsesWithIf(New: RHS, ShouldReplace: [ReductionFunc](const Use &U) {
3712 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
3713 ReductionFunc;
3714 });
3715 } else {
3716 Value *LHSValue = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHS, Name: "final.lhs");
3717 Value *RHSValue = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHS, Name: "final.rhs");
3718 Value *Reduced;
3719 InsertPointOrErrorTy AfterIP =
3720 RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3721 if (!AfterIP)
3722 return AfterIP.takeError();
3723 Builder.CreateStore(Val: Reduced, Ptr: LHS, isVolatile: false);
3724 }
3725 }
3726 emitBlock(BB: ExitBB, CurFn: CurFunc);
3727 if (ContinuationBlock) {
3728 Builder.CreateBr(Dest: ContinuationBlock);
3729 Builder.SetInsertPoint(ContinuationBlock);
3730 }
3731 Config.setEmitLLVMUsed();
3732
3733 return Builder.saveIP();
3734}
3735
3736static Function *getFreshReductionFunc(Module &M) {
3737 Type *VoidTy = Type::getVoidTy(C&: M.getContext());
3738 Type *Int8PtrTy = PointerType::getUnqual(C&: M.getContext());
3739 auto *FuncTy =
3740 FunctionType::get(Result: VoidTy, Params: {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ isVarArg: false);
3741 return Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3742 N: ".omp.reduction.func", M: &M);
3743}
3744
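/// Fills in the body of a function created by getFreshReductionFunc.
/// Conceptually (an informal sketch):
///
///   void .omp.reduction.func(void *lhs_arr, void *rhs_arr) {
///     for each reduction i:
///       *lhs[i] = reduce_i(*lhs[i], *rhs[i]); // store elided for by-ref
///   }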
3745static Error populateReductionFunction(
3746 Function *ReductionFunc,
3747 ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3748 IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
3749 Module *Module = ReductionFunc->getParent();
3750 BasicBlock *ReductionFuncBlock =
3751 BasicBlock::Create(Context&: Module->getContext(), Name: "", Parent: ReductionFunc);
3752 Builder.SetInsertPoint(ReductionFuncBlock);
3753 Value *LHSArrayPtr = nullptr;
3754 Value *RHSArrayPtr = nullptr;
3755 if (IsGPU) {
    // Need to alloca memory here and address-space cast the argument
    // pointers before extracting the LHS/RHS array pointers.
3759 Argument *Arg0 = ReductionFunc->getArg(i: 0);
3760 Argument *Arg1 = ReductionFunc->getArg(i: 1);
3761 Type *Arg0Type = Arg0->getType();
3762 Type *Arg1Type = Arg1->getType();
3763
3764 Value *LHSAlloca =
3765 Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
3766 Value *RHSAlloca =
3767 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
3768 Value *LHSAddrCast =
3769 Builder.CreatePointerBitCastOrAddrSpaceCast(V: LHSAlloca, DestTy: Arg0Type);
3770 Value *RHSAddrCast =
3771 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHSAlloca, DestTy: Arg1Type);
3772 Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
3773 Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
3774 LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
3775 RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
3776 } else {
3777 LHSArrayPtr = ReductionFunc->getArg(i: 0);
3778 RHSArrayPtr = ReductionFunc->getArg(i: 1);
3779 }
3780
3781 unsigned NumReductions = ReductionInfos.size();
3782 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
3783
3784 for (auto En : enumerate(First&: ReductionInfos)) {
3785 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3786 Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3787 Ty: RedArrayTy, Ptr: LHSArrayPtr, Idx0: 0, Idx1: En.index());
3788 Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
3789 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3790 V: LHSI8Ptr, DestTy: RI.Variable->getType());
3791 Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
3792 Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3793 Ty: RedArrayTy, Ptr: RHSArrayPtr, Idx0: 0, Idx1: En.index());
3794 Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
3795 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3796 V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType());
3797 Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
3798 Value *Reduced;
3799 OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3800 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3801 if (!AfterIP)
3802 return AfterIP.takeError();
3803
3804 Builder.restoreIP(IP: *AfterIP);
3805 // TODO: Consider flagging an error.
3806 if (!Builder.GetInsertBlock())
3807 return Error::success();
3808
    // The store is inside the reduction region when using by-ref.
3810 if (!IsByRef[En.index()])
3811 Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
3812 }
3813 Builder.CreateRetVoid();
3814 return Error::success();
3815}
3816
3817OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
3818 const LocationDescription &Loc, InsertPointTy AllocaIP,
3819 ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
3820 bool IsNoWait, bool IsTeamsReduction) {
3821 assert(ReductionInfos.size() == IsByRef.size());
3822 if (Config.isGPU())
3823 return createReductionsGPU(Loc, AllocaIP, CodeGenIP: Builder.saveIP(), ReductionInfos,
3824 IsNoWait, IsTeamsReduction);
3825
3826 checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
3827
3828 if (!updateToLocation(Loc))
3829 return InsertPointTy();
3830
  if (ReductionInfos.empty())
    return Builder.saveIP();
3833
3834 BasicBlock *InsertBlock = Loc.IP.getBlock();
3835 BasicBlock *ContinuationBlock =
3836 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
3837 InsertBlock->getTerminator()->eraseFromParent();
3838
3839 // Create and populate array of type-erased pointers to private reduction
3840 // values.
3841 unsigned NumReductions = ReductionInfos.size();
3842 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
3843 Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
3844 Value *RedArray = Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: "red.array");
3845
3846 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
3847
3848 for (auto En : enumerate(First&: ReductionInfos)) {
3849 unsigned Index = En.index();
3850 const ReductionInfo &RI = En.value();
3851 Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
3852 Ty: RedArrayTy, Ptr: RedArray, Idx0: 0, Idx1: Index, Name: "red.array.elem." + Twine(Index));
3853 Builder.CreateStore(Val: RI.PrivateVariable, Ptr: RedArrayElemPtr);
3854 }
3855
3856 // Emit a call to the runtime function that orchestrates the reduction.
3857 // Declare the reduction function in the process.
3858 Type *IndexTy = Builder.getIndexTy(
3859 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3860 Function *Func = Builder.GetInsertBlock()->getParent();
3861 Module *Module = Func->getParent();
3862 uint32_t SrcLocStrSize;
3863 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3864 bool CanGenerateAtomic = all_of(Range&: ReductionInfos, P: [](const ReductionInfo &RI) {
3865 return RI.AtomicReductionGen;
3866 });
3867 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
3868 LocFlags: CanGenerateAtomic
3869 ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
3870 : IdentFlag(0));
3871 Value *ThreadId = getOrCreateThreadID(Ident);
3872 Constant *NumVariables = Builder.getInt32(C: NumReductions);
3873 const DataLayout &DL = Module->getDataLayout();
3874 unsigned RedArrayByteSize = DL.getTypeStoreSize(Ty: RedArrayTy);
3875 Constant *RedArraySize = ConstantInt::get(Ty: IndexTy, V: RedArrayByteSize);
3876 Function *ReductionFunc = getFreshReductionFunc(M&: *Module);
3877 Value *Lock = getOMPCriticalRegionLock(CriticalName: ".reduction");
3878 Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
3879 FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
3880 : RuntimeFunction::OMPRTL___kmpc_reduce);
3881 CallInst *ReduceCall =
3882 Builder.CreateCall(Callee: ReduceFunc,
3883 Args: {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
3884 ReductionFunc, Lock},
3885 Name: "reduce");
3886
3887 // Create final reduction entry blocks for the atomic and non-atomic case.
3888 // Emit IR that dispatches control flow to one of the blocks based on the
3889 // reduction supporting the atomic mode.
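  // Schematically, and following the KMP convention that __kmpc_reduce
  // returns 1 for the non-atomic path, 2 for the atomic path, and 0 when
  // there is nothing to do, the emitted dispatch is roughly:
  //
  //   %res = call i32 @__kmpc_reduce(...)
  //   switch i32 %res, label %reduce.finalize [
  //     i32 1, label %reduce.switch.nonatomic
  //     i32 2, label %reduce.switch.atomic
  //   ]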
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
  SwitchInst *Switch =
      Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
  Switch->addCase(Builder.getInt32(2), AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction
  // function. This loads the elements from the global and private variables
  // and reduces them before storing back the result to the global variable.
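  // For illustration only: for a '+' reduction over i32, a typical
  // ReductionGen callback would emit something like
  //   %sum = add i32 %red.value.0, %red.private.value.0
  // and hand %sum back through the Reduced out-parameter.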
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    // We have one less load for the by-ref case because that load is now
    // inside the reduction region.
    Value *RedValue = RI.Variable;
    if (!IsByRef[En.index()]) {
      RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                    "red.value." + Twine(En.index()));
    }
    Value *PrivateRedValue =
        Builder.CreateLoad(ValueType, RI.PrivateVariable,
                           "red.private.value." + Twine(En.index()));
    Value *Reduced;
    InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(*AfterIP);

    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // For the by-ref case, the store back to the variable happens inside the
    // reduction region.
    if (!IsByRef[En.index()])
      Builder.CreateStore(Reduced, RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
  Builder.CreateBr(ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
    for (const ReductionInfo &RI : ReductionInfos) {
      InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
          Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(*AfterIP);
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(ContinuationBlock);
  } else {
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
                                        IsByRef, /*isGPU=*/false);
  if (Err)
    return Err;

  if (!Builder.GetInsertBlock())
    return InsertPointTy();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_master;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ true, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB, Value *Filter) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_masked;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, Filter};
  Value *ArgsEnd[] = {Ident, ThreadId};

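  // The Filter argument selects which thread runs the region: by convention,
  // __kmpc_masked returns nonzero only in the thread whose id matches Filter,
  // which is why the region is emitted conditionally below.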
  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ true, /*hasFinalize*/ true);
}

CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
    DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
    BasicBlock *PostInsertBefore, const Twine &Name) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  Type *IndVarTy = TripCount->getType();

  // Create the basic block structure.
  BasicBlock *Preheader =
      BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
  BasicBlock *Header =
      BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
  BasicBlock *Cond =
      BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
  BasicBlock *Exit =
      BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
  BasicBlock *After =
      BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
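
  // A rough sketch of the skeleton wired up below (block names carry the
  // "omp_" prefix plus Name):
  //
  //   preheader:
  //     br header
  //   header:                 ; iv = phi [0, preheader], [next, inc]
  //     br cond
  //   cond:
  //     br (iv < tripcount), body, exit
  //   body:
  //     br inc
  //   inc:                    ; next = add nuw iv, 1
  //     br header
  //   exit:
  //     br after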

  // Use specified DebugLoc for new instructions.
  Builder.SetCurrentDebugLocation(DL);

  Builder.SetInsertPoint(Preheader);
  Builder.CreateBr(Header);

  Builder.SetInsertPoint(Header);
  PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
  IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
  Builder.CreateBr(Cond);

  Builder.SetInsertPoint(Cond);
  Value *Cmp =
      Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
  Builder.CreateCondBr(Cmp, Body, Exit);

  Builder.SetInsertPoint(Body);
  Builder.CreateBr(Latch);

  Builder.SetInsertPoint(Latch);
  Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
                                  "omp_" + Name + ".next", /*HasNUW=*/true);
  Builder.CreateBr(Header);
  IndVarPHI->addIncoming(Next, Latch);

  Builder.SetInsertPoint(Exit);
  Builder.CreateBr(After);

  // Remember and return the canonical control flow.
  LoopInfos.emplace_front();
  CanonicalLoopInfo *CL = &LoopInfos.front();

  CL->Header = Header;
  CL->Cond = Cond;
  CL->Latch = Latch;
  CL->Exit = Exit;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}

Expected<CanonicalLoopInfo *>
OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
                                     LoopBodyGenCallbackTy BodyGenCB,
                                     Value *TripCount, const Twine &Name) {
  BasicBlock *BB = Loc.IP.getBlock();
  BasicBlock *NextBB = BB->getNextNode();

  CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
                                             NextBB, NextBB, Name);
  BasicBlock *After = CL->getAfter();

  // If location is not set, don't connect the loop.
  if (updateToLocation(Loc)) {
    // Split the loop at the insertion point: Branch to the preheader and move
    // every following instruction to after the loop (the After BB). Also, the
    // new successor is the loop's after block.
    spliceBB(Builder, After, /*CreateBranch=*/false);
    Builder.CreateBr(CL->getPreheader());
  }

  // Emit the body content. We do it after connecting the loop to the CFG to
  // avoid that the callback encounters degenerate BBs.
  if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
    return Err;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}

Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
    bool IsSigned, bool InclusiveStop, const Twine &Name) {

  // Consider the following difficulties (assuming 8-bit signed integers):
  // * Adding \p Step to the loop counter which passes \p Stop may overflow:
  //   DO I = 1, 100, 50
  // * A \p Step of INT_MIN cannot be normalized to a positive direction:
  //   DO I = 100, 0, -128

  // Start, Stop and Step must be of the same integer type.
  auto *IndVarTy = cast<IntegerType>(Start->getType());
  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
  assert(IndVarTy == Step->getType() && "Step type mismatch");

  updateToLocation(Loc);

  ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
  ConstantInt *One = ConstantInt::get(IndVarTy, 1);

  // Like Step, but always positive.
  Value *Incr = Step;

  // Distance between Start and Stop; always positive.
  Value *Span;

  // Condition for whether no iterations are executed at all, e.g. because
  // UB < LB.
  Value *ZeroCmp;

  if (IsSigned) {
    // Ensure that increment is positive. If not, negate and invert LB and UB.
    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
    Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
    Span = Builder.CreateSub(UB, LB, "", false, true);
    ZeroCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
  } else {
    Span = Builder.CreateSub(Stop, Start, "", true);
    ZeroCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
  }

  Value *CountIfLooping;
  if (InclusiveStop) {
    CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
  } else {
    // Avoid incrementing past stop since it could overflow.
    Value *CountIfTwo = Builder.CreateAdd(
        Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
    Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
    CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
  }

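  // Worked example (unsigned, exclusive stop): Start=0, Stop=10, Step=3 gives
  // Span=10 and CountIfTwo=(10-1)/3+1=4; since Span > Incr, the trip count is
  // 4, covering the iterations 0, 3, 6, 9.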
  return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
                              "omp_" + Name + ".tripcount");
}

Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name) {
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

  Value *TripCount = calculateCanonicalLoopTripCount(
      ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

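  // Map the logical iteration number IV in [0, TripCount) back to the
  // user-visible induction variable: IndVar = Start + IV * Step.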
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(CodeGenIP);
    Value *Span = Builder.CreateMul(IV, Step);
    Value *IndVar = Builder.CreateAdd(Span, Start);
    return BodyGenCB(Builder.saveIP(), IndVar);
  };
  LocationDescription LoopLoc =
      ComputeIP.isSet()
          ? Loc
          : LocationDescription(Builder.saveIP(),
                                Builder.getCurrentDebugLocation());
  return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
}

// Returns an LLVM function to call for initializing loop bounds using OpenMP
// static scheduling for composite `distribute parallel for` depending on
// `type`. Only i32 and i64 are supported by the runtime. Always interpret
// integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcDistForStaticInitForType(Type *Ty, Module &M,
                                OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

// Returns an LLVM function to call for initializing loop bounds using OpenMP
// static scheduling depending on `type`. Only i32 and i64 are supported by the
// runtime. Always interpret integers as unsigned similarly to
// CanonicalLoopInfo.
static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
                                                  OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    WorksharingLoopType LoopType, bool NeedsBarrier) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit =
      LoopType == WorksharingLoopType::DistributeForStaticLoop
          ? getKmpcDistForStaticInitForType(IVTy, M, *this)
          : getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init"
  // function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());

  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  OMPScheduleType SchedType =
      (LoopType == WorksharingLoopType::DistributeStaticLoop)
          ? OMPScheduleType::OrderedDistribute
          : OMPScheduleType::UnorderedStatic;
  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  SmallVector<Value *, 10> Args(
      {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound, PUpperBound});
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
    Value *PDistUpperBound =
        Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
    Args.push_back(PDistUpperBound);
  }
  Args.append({PStride, One, Zero});
  Builder.CreateCall(StaticInit, Args);
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  CLI->setTripCount(TripCount);
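
  // For example, a loop of 100 iterations statically divided over 4 threads
  // would typically come back with LB=25*tid and an inclusive UB=25*tid+24,
  // so each thread's rebased trip count is 25 (the exact split is up to the
  // runtime).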

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(CLI->getBody(),
                           CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(OldIV, LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy BarrierIP =
        createBarrier(LocationDescription(Builder.saveIP(), DL),
                      omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
                                                 CanonicalLoopInfo *CLI,
                                                 InsertPointTy AllocaIP,
                                                 bool NeedsBarrier,
                                                 Value *ChunkSize) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(ChunkSize && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
                                                        : Type::getInt64Ty(Ctx);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Constant *Zero = ConstantInt::get(InternalIVTy, 0);
  Constant *One = ConstantInt::get(InternalIVTy, 1);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(InternalIVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init"
  // function.
  Builder.restoreIP(AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
  CLI->setLastIter(PLastIter);

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize =
      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
  Builder.CreateStore(Zero, PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
  Builder.CreateStore(OrigUpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(SrcLoc);
  Builder.CreateCall(StaticInit,
                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                      /*pstride=*/PStride, /*incr=*/One,
                      /*chunk=*/CastedChunkSize});

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
  Value *ChunkRange =
      Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");

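  // Schematically, the nest built below executes the body as follows:
  //
  //   for (counter = firstChunkStart; counter < tripCount; counter += stride)
  //     for (iv = 0; iv < min(chunkRange, tripCount - counter); ++iv)
  //       body(counter + iv);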
  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, true);
  Value *DispatchCounter;

  // It is safe to assume this didn't return an error because the callback
  // passed into createCanonicalLoop is the only possible error source, and it
  // always returns success.
  CanonicalLoopInfo *DispatchCLI = cantFail(createCanonicalLoop(
      {Builder.saveIP(), DL},
      [&](InsertPointTy BodyIP, Value *Counter) {
        DispatchCounter = Counter;
        return Error::success();
      },
      FirstChunkStart, CastedTripCount, NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      "dispatch"));

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
  // not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch
  // loop.
  redirectTo(DispatchAfter, CLI->getAfter(), DL);
  redirectTo(CLI->getExit(), DispatchLatch, DL);
  redirectTo(DispatchBody, DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(CastedTripCount, DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
  CLI->mapIndVar([&](Instruction *) -> Value * {
    Builder.restoreIP(CLI->getBodyIP());
    return Builder.CreateAdd(IV, BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
                      /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
    if (!AfterIP)
      return AfterIP.takeError();
  }

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
}

// Returns an LLVM function to call for executing an OpenMP static worksharing
// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
                            WorksharingLoopType LoopType) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  Module &M = OMPBuilder->M;
  switch (LoopType) {
  case WorksharingLoopType::ForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
    break;
  }
  if (Bitwidth != 32 && Bitwidth != 64)
    llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
  llvm_unreachable("Unknown type of OpenMP worksharing loop");
}
// Inserts a call to the proper OpenMP Device RTL function that handles loop
// worksharing.
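// Based on the argument lists assembled below, the emitted call has roughly
// the shape (for the plain worksharing case):
//   __kmpc_for_static_loop_*(ident, body_fn, body_args, tripcount,
//                            num_threads, 0)
// with the distribute variants dropping num_threads or appending an extra
// zero argument, exactly as encoded in the branches of this function.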
static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
                                          WorksharingLoopType LoopType,
                                          BasicBlock *InsertBlock, Value *Ident,
                                          Value *LoopBodyArg, Value *TripCount,
                                          Function &LoopBodyFn) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Ident);
  RealArgs.push_back(&LoopBodyFn);
  RealArgs.push_back(LoopBodyArg);
  RealArgs.push_back(TripCount);
  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
    Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
    Builder.CreateCall(RTLFn, RealArgs);
    return;
  }
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});

  RealArgs.push_back(
      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
  RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop)
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));

  Builder.CreateCall(RTLFn, RealArgs);
}

static void workshareLoopTargetCallback(
    OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
    Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
    WorksharingLoopType LoopType) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, we need to move the setup of the loop body arguments
  // into the loop preheader.
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop; we do not need it anymore.
  // That is why we make an unconditional branch from the loop preheader to
  // the loop exit block.
  Builder.restoreIP({Preheader, Preheader->end()});
  Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete dead loop blocks.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the instruction which corresponds to the loop body argument
  // structure and remove the call to the loop body function instruction.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, TripCount, OutlinedFn);

  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          WorksharingLoopType LoopType) {
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  OutlineInfo OI;
  OI.OuterAllocaBB = CLI->getPreheader();
  Function *OuterFn = CLI->getPreheader()->getParent();

  // Instructions which need to be deleted at the end of code generation.
  SmallVector<Instruction *, 4> ToBeDeleted;

  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Mark the body loop as the region which needs to be extracted.
  OI.EntryBB = CLI->getBody();
  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
                                               "omp.prelatch", true);

  // Prepare the loop body for extraction.
  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

  // Insert a new loop counter variable which will be used only in the loop
  // body.
  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
  Instruction *NewLoopCntLoad =
      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
  // The new loop counter instructions are redundant in the loop preheader
  // once code generation for the workshare loop is finished. That is why we
  // mark them as ready for deletion.
  ToBeDeleted.push_back(NewLoopCntLoad);
  ToBeDeleted.push_back(NewLoopCnt);

  // Analyse the loop body region: find all input variables which are used
  // inside the loop body region.
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks,
                          /* DominatorTree */ nullptr,
                          /* AggregateArgs */ true,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ CLI->getPreheader(),
                          /* Suffix */ ".omp_wsloop",
                          /* AggrArgsIn0AddrSpace */ true);

  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> SinkingCands, HoistingCands;

  // Find allocas outside the loop body region which are used inside the loop
  // body.
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

  // We need to model the loop body region as the function f(cnt, loop_arg).
  // That is why we replace the loop induction variable with the new counter,
  // which will become one of the loop body function's arguments.
  SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
                            CLI->getIndVar()->user_end());
  for (auto Use : Users) {
    if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
      if (ParallelRegionBlockSet.count(Inst->getParent())) {
        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
      }
    }
  }
  // Make sure that the loop counter variable is not merged into the loop body
  // function argument structure; it is passed as a separate variable.
  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);

  // The PostOutline callback is invoked when the loop body function has been
  // outlined and the loop body replaced by a call to the outlined function.
  // We need to add a call to the OpenMP device RTL inside the loop preheader;
  // the OpenMP device RTL function will handle the loop control logic.
  OI.PostOutlineCB = [=, ToBeDeletedVec =
                             std::move(ToBeDeleted)](Function &OutlinedFn) {
    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
                                LoopType);
  };
  addOutlineInfo(std::move(OI));
  return CLI->getAfterIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
    bool HasSimdModifier, bool HasMonotonicModifier,
    bool HasNonmonotonicModifier, bool HasOrderedClause,
    WorksharingLoopType LoopType) {
  if (Config.isTargetDevice())
    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
  OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
      SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);
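
  // For example, a plain 'schedule(static)' computes to BaseStatic and takes
  // the static path below, while 'schedule(dynamic, n)' computes to
  // BaseDynamicChunked and is handled by the dispatch-based runtime calls.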

  bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                   OMPScheduleType::ModifierOrdered;
  switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
  case OMPScheduleType::BaseStatic:
    assert(!ChunkSize && "No chunk size with static-chunked schedule");
    if (IsOrdered)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier);

  case OMPScheduleType::BaseStaticChunked:
    if (IsOrdered)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
                                           ChunkSize);

  case OMPScheduleType::BaseRuntime:
  case OMPScheduleType::BaseAuto:
  case OMPScheduleType::BaseGreedy:
  case OMPScheduleType::BaseBalanced:
  case OMPScheduleType::BaseSteal:
  case OMPScheduleType::BaseGuidedSimd:
  case OMPScheduleType::BaseRuntimeSimd:
    assert(!ChunkSize &&
           "schedule type does not support user-defined chunk sizes");
    [[fallthrough]];
  case OMPScheduleType::BaseDynamicChunked:
  case OMPScheduleType::BaseGuidedChunked:
  case OMPScheduleType::BaseGuidedIterativeChunked:
  case OMPScheduleType::BaseGuidedAnalyticalChunked:
  case OMPScheduleType::BaseStaticBalancedChunked:
    return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                     NeedsBarrier, ChunkSize);

  default:
    llvm_unreachable("Unknown/unimplemented schedule kind");
  }
}

/// Returns an LLVM function to call for initializing loop bounds using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for updating the next loop using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for finalizing the dynamic loop depending
/// on `type`. Only i32 and i64 are supported by the runtime. Always interpret
/// integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                           InsertPointTy AllocaIP,
                                           OMPScheduleType SchedType,
                                           bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init"
  // function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});
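
  // Schematically, the rewritten loop fetches chunks from the runtime:
  //
  //   __kmpc_dispatch_init(loc, tid, sched, /*lb=*/1, tripcount, /*st=*/1,
  //                        chunk);
  //   while (__kmpc_dispatch_next(loc, tid, &last, &lb, &ub, &st)) {
  //     for (iv = lb - 1; iv < ub; ++iv)  // reuses the original loop blocks
  //       body(iv);
  //   }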

  // An outer loop around the existing one.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  // This needs to be 32-bit always, so can't use the IVTy Zero above.
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change the PHI node in the loop header to use the outer cond rather than
  // the preheader, and set the IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the pre-header to jump to the OuterCond.
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * Jump to the outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to the outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Call the "fini" function if "ordered" is present in the wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    InsertPointOrErrorTy BarrierIP =
        createBarrier(LocationDescription(Builder.saveIP(), DL),
                      omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  CLI->invalidate();
  return AfterIP;
}

/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
/// after this \p OldTarget will be orphaned.
static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
                                      BasicBlock *NewTarget, DebugLoc DL) {
  for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
    redirectTo(Pred, NewTarget, DL);
}

/// Determine which blocks in \p BBs are reachable from outside and remove the
/// ones that are not reachable from the function.
static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
  SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
  auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
    for (Use &U : BB->uses()) {
      auto *UseInst = dyn_cast<Instruction>(U.getUser());
      if (!UseInst)
        continue;
      if (BBsToErase.count(UseInst->getParent()))
        continue;
      return true;
    }
    return false;
  };

  while (BBsToErase.remove_if(HasRemainingUses)) {
    // Try again if anything was removed.
  }

  SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
  DeleteDeadBlocks(BBVec);
}

CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Set up the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
                                           {}, /*HasNUW=*/true);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;
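
  // For instance, collapsing two loops with trip counts TC0 and TC1 yields
  // iv1 = iv % TC1 and iv0 = iv / TC1, so the iterations still run in the
  // original lexicographic order.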
5058
5059 // Construct the loop body control flow.
5060 // We progressively construct the branch structure following in direction of
5061 // the control flow, from the leading in-between code, the loop nest body, the
5062 // trailing in-between code, and rejoining the collapsed loop's latch.
5063 // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
5064 // the ContinueBlock is set, continue with that block. If ContinuePred, use
5065 // its predecessors as sources.
5066 BasicBlock *ContinueBlock = Result->getBody();
5067 BasicBlock *ContinuePred = nullptr;
5068 auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
5069 BasicBlock *NextSrc) {
5070 if (ContinueBlock)
5071 redirectTo(Source: ContinueBlock, Target: Dest, DL);
5072 else
5073 redirectAllPredecessorsTo(OldTarget: ContinuePred, NewTarget: Dest, DL);
5074
5075 ContinueBlock = nullptr;
5076 ContinuePred = NextSrc;
5077 };
5078
5079 // The code before the nested loop of each level.
5080 // Because we are sinking it into the nest, it will be executed more often
5081 // that the original loop. More sophisticated schemes could keep track of what
5082 // the in-between code is and instantiate it only once per thread.
5083 for (size_t i = 0; i < NumLoops - 1; ++i)
5084 ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
5085
5086 // Connect the loop nest body.
5087 ContinueWith(Innermost->getBody(), Innermost->getLatch());
5088
5089 // The code after the nested loop at each level.
5090 for (size_t i = NumLoops - 1; i > 0; --i)
5091 ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
5092
5093 // Connect the finished loop to the collapsed loop latch.
5094 ContinueWith(Result->getLatch(), nullptr);
5095
5096 // Replace the input loops with the new collapsed loop.
5097 redirectTo(Source: Outermost->getPreheader(), Target: Result->getPreheader(), DL);
5098 redirectTo(Source: Result->getAfter(), Target: Outermost->getAfter(), DL);
5099
5100 // Replace the input loop indvars with the derived ones.
5101 for (size_t i = 0; i < NumLoops; ++i)
5102 Loops[i]->getIndVar()->replaceAllUsesWith(V: NewIndVars[i]);
5103
5104 // Remove unused parts of the input loops.
5105 removeUnusedBlocksFromParent(BBs: OldControlBBs);
5106
5107 for (CanonicalLoopInfo *L : Loops)
5108 L->invalidate();
5109
5110#ifndef NDEBUG
5111 Result->assertOK();
5112#endif
5113 return Result;
5114}
5115
5116std::vector<CanonicalLoopInfo *>
5117OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
5118 ArrayRef<Value *> TileSizes) {
5119 assert(TileSizes.size() == Loops.size() &&
5120 "Must pass as many tile sizes as there are loops");
5121 int NumLoops = Loops.size();
5122 assert(NumLoops >= 1 && "At least one loop to tile required");
5123
5124 CanonicalLoopInfo *OutermostLoop = Loops.front();
5125 CanonicalLoopInfo *InnermostLoop = Loops.back();
5126 Function *F = OutermostLoop->getBody()->getParent();
5127 BasicBlock *InnerEnter = InnermostLoop->getBody();
5128 BasicBlock *InnerLatch = InnermostLoop->getLatch();
5129
5130 // Loop control blocks that may become orphaned later.
5131 SmallVector<BasicBlock *, 12> OldControlBBs;
5132 OldControlBBs.reserve(N: 6 * Loops.size());
5133 for (CanonicalLoopInfo *Loop : Loops)
5134 Loop->collectControlBlocks(BBs&: OldControlBBs);
5135
5136 // Collect original trip counts and induction variables to be accessible by
5137 // index. Also, the structure of the original loops is not preserved during
5138 // the construction of the tiled loops, so do it before we scavenge the BBs of
5139 // any original CanonicalLoopInfo.
5140 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
5141 for (CanonicalLoopInfo *L : Loops) {
5142 assert(L->isValid() && "All input loops must be valid canonical loops");
5143 OrigTripCounts.push_back(Elt: L->getTripCount());
5144 OrigIndVars.push_back(Elt: L->getIndVar());
5145 }
5146
5147 // Collect the code between loop headers. These may contain SSA definitions
5148 // that are used in the loop nest body. To be usable within the innermost
5149 // body, these BasicBlocks will be sunk into the loop nest body. That is,
5150 // these instructions may be executed more often than before the tiling.
5151 // TODO: It would be sufficient to only sink them into body of the
5152 // corresponding tile loop.
5153 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
5154 for (int i = 0; i < NumLoops - 1; ++i) {
5155 CanonicalLoopInfo *Surrounding = Loops[i];
5156 CanonicalLoopInfo *Nested = Loops[i + 1];
5157
5158 BasicBlock *EnterBB = Surrounding->getBody();
5159 BasicBlock *ExitBB = Nested->getHeader();
5160 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
5161 }
5162
5163 // Compute the trip counts of the floor loops.
5164 Builder.SetCurrentDebugLocation(DL);
5165 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
5166 SmallVector<Value *, 4> FloorCount, FloorRems, FloorDivs;
5167 for (int i = 0; i < NumLoops; ++i) {
5168 Value *TileSize = TileSizes[i];
5169 Value *OrigTripCount = OrigTripCounts[i];
5170 Type *IVType = OrigTripCount->getType();
5171
5172 Value *FloorTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
5173 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
5174
5175 // 0 if the tile size divides the trip count, 1 otherwise.
5176 // 1 means we need an additional iteration for a partial tile.
5177 //
5178 // Unfortunately we cannot just use the roundup formula
5179 // (tripcount + tilesize - 1) / tilesize
5180 // because the summation might overflow. We do not want to introduce undefined
5181 // behavior where the untiled loop nest had none.
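//
// For example, a trip count of 10 with tile size 4 gives quotient 2 and
// remainder 2, so the floor loop runs 2 + 1 = 3 times; its last iteration
// (induction variable == 2, the unrounded quotient) executes the partial
// tile of the remaining 2 iterations.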
5182 Value *FloorTripOverflow =
5183 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
5184
5185 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
// Keep the unrounded quotient around: the floor loop's final iteration is
// the partial tile exactly when its induction variable reaches this value.
FloorDivs.push_back(Elt: FloorTripCount);
5186 FloorTripCount =
5187 Builder.CreateAdd(LHS: FloorTripCount, RHS: FloorTripOverflow,
5188 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
5189
5190 // Remember some values for later use.
5191 FloorCount.push_back(Elt: FloorTripCount);
5192 FloorRems.push_back(Elt: FloorTripRem);
5193 }
5194
5195 // Generate the new loop nest, from the outermost to the innermost.
5196 std::vector<CanonicalLoopInfo *> Result;
5197 Result.reserve(n: NumLoops * 2);
5198
5199 // The basic block of the surrounding loop that enters the newly generated
5200 // loop nest.
5201 BasicBlock *Enter = OutermostLoop->getPreheader();
5202
5203 // The basic block of the surrounding loop where the inner code should
5204 // continue.
5205 BasicBlock *Continue = OutermostLoop->getAfter();
5206
5207 // Where the next loop basic block should be inserted.
5208 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
5209
5210 auto EmbeddNewLoop =
5211 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
5212 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
5213 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
5214 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
5215 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
5216 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
5217
5218 // Setup the position where the next embedded loop connects to this loop.
5219 Enter = EmbeddedLoop->getBody();
5220 Continue = EmbeddedLoop->getLatch();
5221 OutroInsertBefore = EmbeddedLoop->getLatch();
5222 return EmbeddedLoop;
5223 };
5224
5225 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
5226 const Twine &NameBase) {
5227 for (auto P : enumerate(First&: TripCounts)) {
5228 CanonicalLoopInfo *EmbeddedLoop =
5229 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
5230 Result.push_back(x: EmbeddedLoop);
5231 }
5232 };
5233
5234 EmbeddNewLoops(FloorCount, "floor");
5235
5236 // Within the innermost floor loop, emit the code that computes each tile
5237 // loop's trip count.
5238 Builder.SetInsertPoint(Enter->getTerminator());
5239 SmallVector<Value *, 4> TileCounts;
5240 for (int i = 0; i < NumLoops; ++i) {
5241 CanonicalLoopInfo *FloorLoop = Result[i];
5242 Value *TileSize = TileSizes[i];
5243
// The last floor iteration handles the partial tile (the "epilogue") iff the
// tile size does not evenly divide the trip count; in that case the floor
// trip count was rounded up, so the final value of the induction variable
// equals the unrounded quotient remembered in FloorDivs. (Comparing against
// FloorCount[i], the rounded-up trip count itself, would never be true since
// the induction variable stays strictly below it.)
5244 Value *FloorIsEpilogue =
5245 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorDivs[i]);
5246 Value *TileTripCount =
5247 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
5248
5249 TileCounts.push_back(Elt: TileTripCount);
5250 }
5251
5252 // Create the tile loops.
5253 EmbeddNewLoops(TileCounts, "tile");
5254
5255 // Insert the in-between code into the body.
5256 BasicBlock *BodyEnter = Enter;
5257 BasicBlock *BodyEntered = nullptr;
5258 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
5259 BasicBlock *EnterBB = P.first;
5260 BasicBlock *ExitBB = P.second;
5261
5262 if (BodyEnter)
5263 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
5264 else
5265 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
5266
5267 BodyEnter = nullptr;
5268 BodyEntered = ExitBB;
5269 }
5270
5271 // Append the original loop nest body into the generated loop nest body.
5272 if (BodyEnter)
5273 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
5274 else
5275 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
5276 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
5277
5278 // Replace the original induction variables with induction variables computed
5279 // from the tile and floor induction variables.
5280 Builder.restoreIP(IP: Result.back()->getBodyIP());
5281 for (int i = 0; i < NumLoops; ++i) {
5282 CanonicalLoopInfo *FloorLoop = Result[i];
5283 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
5284 Value *OrigIndVar = OrigIndVars[i];
5285 Value *Size = TileSizes[i];
5286
5287 Value *Scale =
5288 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
5289 Value *Shift =
5290 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
5291 OrigIndVar->replaceAllUsesWith(V: Shift);
5292 }
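
// In pseudo-code, the net effect for a single loop with trip count TC and
// tile size TS is:
//   for (F = 0; F < ceildiv(TC, TS); ++F)        // floor loop
//     for (T = 0; T < min(TS, TC - F * TS); ++T) // tile loop
//       body(F * TS + T);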
5293
5294 // Remove unused parts of the original loops.
5295 removeUnusedBlocksFromParent(BBs: OldControlBBs);
5296
5297 for (CanonicalLoopInfo *L : Loops)
5298 L->invalidate();
5299
5300#ifndef NDEBUG
5301 for (CanonicalLoopInfo *GenL : Result)
5302 GenL->assertOK();
5303#endif
5304 return Result;
5305}
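
// Illustrative use (hypothetical names; the tile-size values must have the
// same integer type as the corresponding loops' induction variables, here
// assumed to be i64):
//   std::vector<CanonicalLoopInfo *> Tiled =
//       OMPBuilder.tileLoops(DL, {Outer, Inner},
//                            {Builder.getInt64(32), Builder.getInt64(32)});
//   // Tiled == {floor0, floor1, tile0, tile1}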
5306
5307/// Attach metadata \p Properties to the basic block described by \p BB. If the
5308/// basic block already has metadata, the basic block properties are appended.
5309static void addBasicBlockMetadata(BasicBlock *BB,
5310 ArrayRef<Metadata *> Properties) {
5311 // Nothing to do if no property to attach.
5312 if (Properties.empty())
5313 return;
5314
5315 LLVMContext &Ctx = BB->getContext();
5316 SmallVector<Metadata *> NewProperties;
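// Reserve the first operand for the self-reference that the LLVM loop
// metadata format requires; it can only be set once the distinct node
// exists, see replaceOperandWith below.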
5317 NewProperties.push_back(Elt: nullptr);
5318
5319 // If the basic block already has metadata, prepend it to the new metadata.
5320 MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
5321 if (Existing)
5322 append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));
5323
5324 append_range(C&: NewProperties, R&: Properties);
5325 MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
5326 BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);
5327
5328 BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
5329}
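
// For illustration, attaching the single property "llvm.loop.unroll.enable" to
// a latch whose terminator carried no prior metadata yields IR of the form:
//   br label %header, !llvm.loop !0
//   !0 = distinct !{!0, !1}
//   !1 = !{!"llvm.loop.unroll.enable"}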
5330
5331/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
5332/// loop already has metadata, the loop properties are appended.
5333static void addLoopMetadata(CanonicalLoopInfo *Loop,
5334 ArrayRef<Metadata *> Properties) {
5335 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
5336
5337 // Attach metadata to the loop's latch
5338 BasicBlock *Latch = Loop->getLatch();
5339 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
5340 addBasicBlockMetadata(BB: Latch, Properties);
5341}
5342
5343/// Attach llvm.access.group metadata to the memref instructions of \p Block
5344static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
5345 LoopInfo &LI) {
5346 for (Instruction &I : *Block) {
5347 if (I.mayReadOrWriteMemory()) {
5348 // TODO: This instruction may already have an access group from
5349 // other pragmas, e.g. #pragma clang loop vectorize. Append
5350 // so that the existing metadata is not overwritten.
5351 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
5352 }
5353 }
5354}
5355
5356void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
5357 LLVMContext &Ctx = Builder.getContext();
5358 addLoopMetadata(
5359 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
5360 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
5361}
5362
5363void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
5364 LLVMContext &Ctx = Builder.getContext();
5365 addLoopMetadata(
5366 Loop, Properties: {
5367 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
5368 });
5369}
5370
5371void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
5372 Value *IfCond, ValueToValueMapTy &VMap,
5373 const Twine &NamePrefix) {
5374 Function *F = CanonicalLoop->getFunction();
5375
5376 // Define where the if branch should be inserted.
5377 Instruction *SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
5378
5379 // TODO: We should not rely on pass manager. Currently we use pass manager
5380 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5381 // object. We should have a method which returns all blocks between
5382 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5383 FunctionAnalysisManager FAM;
5384 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5385 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
5386 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5387
5388 // Get the loop which needs to be cloned
5389 LoopAnalysis LIA;
5390 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5391 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
5392
5393 // Create additional blocks for the if statement
5394 BasicBlock *Head = SplitBefore->getParent();
5395 Instruction *HeadOldTerm = Head->getTerminator();
5396 llvm::LLVMContext &C = Head->getContext();
5397 llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
5398 Context&: C, Name: NamePrefix + ".if.then", Parent: Head->getParent(), InsertBefore: Head->getNextNode());
5399 llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
5400 Context&: C, Name: NamePrefix + ".if.else", Parent: Head->getParent(), InsertBefore: CanonicalLoop->getExit());
5401
5402 // Create if condition branch.
5403 Builder.SetInsertPoint(HeadOldTerm);
5404 Instruction *BrInstr =
5405 Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
5406 InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
5407 // The then block contains the branch to the OpenMP loop, which is to be vectorized.
5408 spliceBB(IP, New: ThenBlock, CreateBranch: false, DL: Builder.getCurrentDebugLocation());
5409 ThenBlock->replaceSuccessorsPhiUsesWith(Old: Head, New: ThenBlock);
5410
5411 Builder.SetInsertPoint(ElseBlock);
5412
5413 // Clone loop for the else branch
5414 SmallVector<BasicBlock *, 8> NewBlocks;
5415
5416 VMap[CanonicalLoop->getPreheader()] = ElseBlock;
5417 for (BasicBlock *Block : L->getBlocks()) {
5418 BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);
5419 NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
5420 VMap[Block] = NewBB;
5421 NewBlocks.push_back(Elt: NewBB);
5422 }
5423 remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
5424 Builder.CreateBr(Dest: NewBlocks.front());
5425}
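
// Schematically, for NamePrefix == "simd" the emitted structure is:
//   preheader:
//     br i1 %IfCond, label %simd.if.then, label %simd.if.else
//   simd.if.then:   ; branches to the original loop, to be vectorized
//   simd.if.else:   ; branches to the cloned loop; the caller can look up the
//                   ; cloned blocks in VMap, e.g. to disable their vectorization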
5426
5427unsigned
5428OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
5429 const StringMap<bool> &Features) {
5430 if (TargetTriple.isX86()) {
5431 if (Features.lookup(Key: "avx512f"))
5432 return 512;
5433 else if (Features.lookup(Key: "avx"))
5434 return 256;
5435 return 128;
5436 }
5437 if (TargetTriple.isPPC())
5438 return 128;
5439 if (TargetTriple.isWasm())
5440 return 128;
5441 return 0;
5442}
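
// For example, an x86-64 target with the "avx" feature but without "avx512f"
// gets a default simd alignment of 256 (bits).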
5443
5444void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5445 MapVector<Value *, Value *> AlignedVars,
5446 Value *IfCond, OrderKind Order,
5447 ConstantInt *Simdlen, ConstantInt *Safelen) {
5448 LLVMContext &Ctx = Builder.getContext();
5449
5450 Function *F = CanonicalLoop->getFunction();
5451
5452 // TODO: We should not rely on pass manager. Currently we use pass manager
5453 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5454 // object. We should have a method which returns all blocks between
5455 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5456 FunctionAnalysisManager FAM;
5457 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5458 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
5459 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5460
5461 LoopAnalysis LIA;
5462 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5463
5464 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
5465 if (!AlignedVars.empty()) {
5466 InsertPointTy IP = Builder.saveIP();
5467 for (auto &AlignedItem : AlignedVars) {
5468 Value *AlignedPtr = AlignedItem.first;
5469 Value *Alignment = AlignedItem.second;
// The aligned value is expected to be defined by an instruction; insert the
// alignment assumption right after its definition.
5470 Instruction *AlignedInst = cast<Instruction>(Val: AlignedPtr);
5471 Builder.SetInsertPoint(AlignedInst->getNextNode());
5472 Builder.CreateAlignmentAssumption(DL: F->getDataLayout(), PtrValue: AlignedPtr,
5473 Alignment);
5474 }
5475 Builder.restoreIP(IP);
5476 }
5477
5478 if (IfCond) {
5479 ValueToValueMapTy VMap;
5480 createIfVersion(CanonicalLoop, IfCond, VMap, NamePrefix: "simd");
5481 // Add metadata to the cloned loop which disables vectorization
5482 Value *MappedLatch = VMap.lookup(Val: CanonicalLoop->getLatch());
5483 assert(MappedLatch &&
5484 "Cannot find value which corresponds to original loop latch");
5485 BasicBlock *NewLatchBlock = cast<BasicBlock>(Val: MappedLatch);
5488 ConstantAsMetadata *BoolConst =
5489 ConstantAsMetadata::get(C: ConstantInt::getFalse(Ty: Type::getInt1Ty(C&: Ctx)));
5490 addBasicBlockMetadata(
5491 BB: NewLatchBlock,
5492 Properties: {MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"),
5493 BoolConst})});
5494 }
5495
5496 SmallSet<BasicBlock *, 8> Reachable;
5497
5498 // Get the basic blocks from the loop in which memref instructions
5499 // can be found.
5500 // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
5501 // preferably without running any passes.
5502 for (BasicBlock *Block : L->getBlocks()) {
5503 if (Block == CanonicalLoop->getCond() ||
5504 Block == CanonicalLoop->getHeader())
5505 continue;
5506 Reachable.insert(Ptr: Block);
5507 }
5508
5509 SmallVector<Metadata *> LoopMDList;
5510
5511 // In the presence of a finite 'safelen', it may be unsafe to mark all
5512 // the memory instructions as parallel, because loop-carried
5513 // dependences at a distance of 'safelen' or more iterations are possible.
5514 // If the order(concurrent) clause is specified, the memory instructions
5515 // are marked parallel even if 'safelen' is finite.
5516 if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
5517 // Add access group metadata to memory-access instructions.
5518 MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
5519 for (BasicBlock *BB : Reachable)
5520 addSimdMetadata(Block: BB, AccessGroup, LI);
5521 // TODO: If the loop has existing parallel access metadata, have
5522 // to combine two lists.
5523 LoopMDList.push_back(Elt: MDNode::get(
5524 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
5525 }
5526
5527 // Use the above access group metadata to create loop level
5528 // metadata, which should be distinct for each loop.
5529 ConstantAsMetadata *BoolConst =
5530 ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
5531 LoopMDList.push_back(Elt: MDNode::get(
5532 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));
5533
5534 if (Simdlen || Safelen) {
5535 // If both simdlen and safelen clauses are specified, the value of the
5536 // simdlen parameter must be less than or equal to the value of the safelen
5537 // parameter. Therefore, use safelen only in the absence of simdlen.
5538 ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
5539 LoopMDList.push_back(
5540 Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
5541 ConstantAsMetadata::get(C: VectorizeWidth)}));
5542 }
5543
5544 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
5545}
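
// For illustration, a call with Order == concurrent and Simdlen == 8 (and no
// Safelen) leaves the loop latch with metadata roughly of the form:
//   br label %header, !llvm.loop !0
//   !0 = distinct !{!0, !1, !2, !3}
//   !1 = !{!"llvm.loop.parallel_accesses", !4}
//   !2 = !{!"llvm.loop.vectorize.enable", i1 true}
//   !3 = !{!"llvm.loop.vectorize.width", i32 8}
//   !4 = distinct !{}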
5546
5547/// Create the TargetMachine object to query the backend for optimization
5548/// preferences.
5549///
5550/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
5551/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
5552/// needed for the LLVM pass pipeline. We use some default options to avoid
5553/// having to pass too many settings from the frontend that probably do not
5554/// matter.
5555///
5556/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
5557/// method. If we are going to use TargetMachine for more purposes, especially
5558/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
5559/// might become worth requiring front-ends to pass on their TargetMachine,
5560/// or at least cache it between methods. Note that while front-ends such as
5561/// Clang have just a single main TargetMachine per translation unit, "target-cpu"
5562/// and "target-features" that determine the TargetMachine are per-function and
5563/// can be overridden using __attribute__((target("OPTIONS"))).
5564static std::unique_ptr<TargetMachine>
5565createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
5566 Module *M = F->getParent();
5567
5568 StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
5569 StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
5570 const llvm::Triple &Triple = M->getTargetTriple();
5571
5572 std::string Error;
5573 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(TheTriple: Triple, Error);
5574 if (!TheTarget)
5575 return {};
5576
5577 llvm::TargetOptions Options;
5578 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
5579 TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
5580 /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
5581}
5582
5583/// Heuristically determine the best-performing unroll factor for \p CLI. This
5584/// depends on the target processor. We are reusing the same heuristics as the
5585/// LoopUnrollPass.
5586static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
5587 Function *F = CLI->getFunction();
5588
5589 // Assume the user requests the most aggressive unrolling, even if the rest of
5590 // the code is optimized using a lower setting.
5591 CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
5592 std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
5593
5594 FunctionAnalysisManager FAM;
5595 FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
5596 FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
5597 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5598 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
5599 FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
5600 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5601 TargetIRAnalysis TIRA;
5602 if (TM)
5603 TIRA = TargetIRAnalysis(
5604 [&](const Function &F) { return TM->getTargetTransformInfo(F); });
5605 FAM.registerPass(PassBuilder: [&]() { return TIRA; });
5606
5607 TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
5608 ScalarEvolutionAnalysis SEA;
5609 ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
5610 DominatorTreeAnalysis DTA;
5611 DominatorTree &&DT = DTA.run(F&: *F, FAM);
5612 LoopAnalysis LIA;
5613 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5614 AssumptionAnalysis ACT;
5615 AssumptionCache &&AC = ACT.run(F&: *F, FAM);
5616 OptimizationRemarkEmitter ORE{F};
5617
5618 Loop *L = LI.getLoopFor(BB: CLI->getHeader());
5619 assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
5620
5621 TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
5622 L, SE, TTI,
5623 /*BlockFrequencyInfo=*/BFI: nullptr,
5624 /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
5625 /*UserThreshold=*/std::nullopt,
5626 /*UserCount=*/std::nullopt,
5627 /*UserAllowPartial=*/true,
5628 /*UserAllowRuntime=*/UserRuntime: true,
5629 /*UserUpperBound=*/std::nullopt,
5630 /*UserFullUnrollMaxCount=*/std::nullopt);
5631
5632 UP.Force = true;
5633
5634 // Account for additional optimizations taking place before the LoopUnrollPass
5635 // would unroll the loop.
5636 UP.Threshold *= UnrollThresholdFactor;
5637 UP.PartialThreshold *= UnrollThresholdFactor;
5638
5639 // Use normal unroll factors even if the rest of the code is optimized for
5640 // size.
5641 UP.OptSizeThreshold = UP.Threshold;
5642 UP.PartialOptSizeThreshold = UP.PartialThreshold;
5643
5644 LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
5645 << " Threshold=" << UP.Threshold << "\n"
5646 << " PartialThreshold=" << UP.PartialThreshold << "\n"
5647 << " OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
5648 << " PartialOptSizeThreshold="
5649 << UP.PartialOptSizeThreshold << "\n");
5650
5651 // Disable peeling.
5652 TargetTransformInfo::PeelingPreferences PP =
5653 gatherPeelingPreferences(L, SE, TTI,
5654 /*UserAllowPeeling=*/false,
5655 /*UserAllowProfileBasedPeeling=*/false,
5656 /*UnrollingSpecficValues=*/false);
5657
5658 SmallPtrSet<const Value *, 32> EphValues;
5659 CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);
5660
5661 // Assume that reads and writes to stack variables can be eliminated by
5662 // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
5663 // size.
5664 for (BasicBlock *BB : L->blocks()) {
5665 for (Instruction &I : *BB) {
5666 Value *Ptr;
5667 if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
5668 Ptr = Load->getPointerOperand();
5669 } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
5670 Ptr = Store->getPointerOperand();
5671 } else
5672 continue;
5673
5674 Ptr = Ptr->stripPointerCasts();
5675
5676 if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
5677 if (Alloca->getParent() == &F->getEntryBlock())
5678 EphValues.insert(Ptr: &I);
5679 }
5680 }
5681 }
5682
5683 UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
5684
5685 // The loop is not unrollable if it contains certain instructions.
5686 if (!UCE.canUnroll()) {
5687 LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
5688 return 1;
5689 }
5690
5691 LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
5692 << "\n");
5693
5694 // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
5695 // be able to use it.
5696 int TripCount = 0;
5697 int MaxTripCount = 0;
5698 bool MaxOrZero = false;
5699 unsigned TripMultiple = 0;
5700
5701 bool UseUpperBound = false;
5702 computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
5703 MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
5704 UseUpperBound);
5705 unsigned Factor = UP.Count;
5706 LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
5707
5708 // This function returns 1 to signal that the loop should not be unrolled.
5709 if (Factor == 0)
5710 return 1;
5711 return Factor;
5712}
5713
5714void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
5715 int32_t Factor,
5716 CanonicalLoopInfo **UnrolledCLI) {
5717 assert(Factor >= 0 && "Unroll factor must not be negative");
5718
5719 Function *F = Loop->getFunction();
5720 LLVMContext &Ctx = F->getContext();
5721
5722 // If the unrolled loop is not used for another loop-associated directive, it
5723 // is sufficient to add metadata for the LoopUnrollPass.
5724 if (!UnrolledCLI) {
5725 SmallVector<Metadata *, 2> LoopMetadata;
5726 LoopMetadata.push_back(
5727 Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));
5728
5729 if (Factor >= 1) {
5730 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5731 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
5732 LoopMetadata.push_back(Elt: MDNode::get(
5733 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
5734 }
5735
5736 addLoopMetadata(Loop, Properties: LoopMetadata);
5737 return;
5738 }
5739
5740 // Heuristically determine the unroll factor.
5741 if (Factor == 0)
5742 Factor = computeHeuristicUnrollFactor(CLI: Loop);
5743
5744 // No change required with unroll factor 1.
5745 if (Factor == 1) {
5746 *UnrolledCLI = Loop;
5747 return;
5748 }
5749
5750 assert(Factor >= 2 &&
5751 "unrolling only makes sense with a factor of 2 or larger");
5752
5753 Type *IndVarTy = Loop->getIndVarType();
5754
5755 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
5756 // unroll the inner loop.
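//
// That is, for a factor of 4 the loop `for (I = 0; I < TC; ++I) body(I);`
// becomes
//   for (F = 0; F < ceildiv(TC, 4); ++F)         // *UnrolledCLI
//     for (T = 0; T < min(4, TC - F * 4); ++T)   // inner tile loop
//       body(F * 4 + T);
// and the inner loop is marked up below so the LoopUnrollPass unrolls it by 4,
// with an epilogue for the remainder iterations.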
5757 Value *FactorVal =
5758 ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
5759 /*isSigned=*/false));
5760 std::vector<CanonicalLoopInfo *> LoopNest =
5761 tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
5762 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
5763 *UnrolledCLI = LoopNest[0];
5764 CanonicalLoopInfo *InnerLoop = LoopNest[1];
5765
5766 // LoopUnrollPass can only fully unroll loops with constant trip count.
5767 // Unroll by the unroll factor with a fallback epilog for the remainder
5768 // iterations if necessary.
5769 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5770 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
5771 addLoopMetadata(
5772 Loop: InnerLoop,
5773 Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
5774 MDNode::get(
5775 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});
5776
5777#ifndef NDEBUG
5778 (*UnrolledCLI)->assertOK();
5779#endif
5780}
5781
5782OpenMPIRBuilder::InsertPointTy
5783OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
5784 llvm::Value *BufSize, llvm::Value *CpyBuf,
5785 llvm::Value *CpyFn, llvm::Value *DidIt) {
5786 if (!updateToLocation(Loc))
5787 return Loc.IP;
5788
5789 uint32_t SrcLocStrSize;
5790 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5791 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5792 Value *ThreadId = getOrCreateThreadID(Ident);
5793
5794 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
5795
5796 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
5797
5798 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
5799 Builder.CreateCall(Callee: Fn, Args);
5800
5801 return Builder.saveIP();
5802}
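
// The emitted sequence is roughly (assuming a 64-bit target where size_t
// lowers to i64):
//   %did_it = load i32, ptr %DidIt
//   call void @__kmpc_copyprivate(ptr @ident, i32 %thread_id, i64 %BufSize,
//                                 ptr %CpyBuf, ptr %CpyFn, i32 %did_it)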
5803
5804OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
5805 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5806 FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
5807 ArrayRef<llvm::Function *> CPFuncs) {
5808
5809 if (!updateToLocation(Loc))
5810 return Loc.IP;
5811
5812 // If needed, allocate `DidIt` and initialize it with 0.
5813 // DidIt is a flag variable: 1 = this thread executed the single region; 0 = it did not.
5814 llvm::Value *DidIt = nullptr;
5815 if (!CPVars.empty()) {
5816 DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
5817 Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
5818 }
5819
5820 Directive OMPD = Directive::OMPD_single;
5821 uint32_t SrcLocStrSize;
5822 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5823 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5824 Value *ThreadId = getOrCreateThreadID(Ident);
5825 Value *Args[] = {Ident, ThreadId};
5826
5827 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
5828 Instruction *EntryCall = Builder.CreateCall(Callee: EntryRTLFn, Args);
5829
5830 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
5831 Instruction *ExitCall = Builder.CreateCall(Callee: ExitRTLFn, Args);
5832
5833 auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
5834 if (Error Err = FiniCB(IP))
5835 return Err;
5836
5837 // The thread that executes the single region must set `DidIt` to 1.
5838 // This is used by __kmpc_copyprivate, to know if the caller is the
5839 // single thread or not.
5840 if (DidIt)
5841 Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);
5842
5843 return Error::success();
5844 };
5845
5846 // generates the following:
5847 // if (__kmpc_single()) {
5848 // .... single region ...
5849 // __kmpc_end_single
5850 // }
5851 // __kmpc_copyprivate
5852 // __kmpc_barrier
5853
5854 InsertPointOrErrorTy AfterIP =
5855 EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
5856 /*Conditional*/ true,
5857 /*hasFinalize*/ HasFinalize: true);
5858 if (!AfterIP)
5859 return AfterIP.takeError();
5860
5861 if (DidIt) {
5862 for (size_t I = 0, E = CPVars.size(); I < E; ++I)
5863 // NOTE BufSize is currently unused, so just pass 0.
5864 createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
5865 /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
5866 CpyFn: CPFuncs[I], DidIt);
5867 // NOTE __kmpc_copyprivate already inserts a barrier
5868 } else if (!IsNowait) {
5869 InsertPointOrErrorTy AfterIP =
5870 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
5871 Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
5872 /* CheckCancelFlag */ false);
5873 if (!AfterIP)
5874 return AfterIP.takeError();
5875 }
5876 return Builder.saveIP();
5877}
5878
5879OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
5880 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5881 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
5882
5883 if (!updateToLocation(Loc))
5884 return Loc.IP;
5885
5886 Directive OMPD = Directive::OMPD_critical;
5887 uint32_t SrcLocStrSize;
5888 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5889 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5890 Value *ThreadId = getOrCreateThreadID(Ident);
5891 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
5892 Value *Args[] = {Ident, ThreadId, LockVar};
5893
5894 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
5895 Function *RTFn = nullptr;
5896 if (HintInst) {
5897 // Add Hint to entry Args and create call
5898 EnterArgs.push_back(Elt: HintInst);
5899 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
5900 } else {
5901 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
5902 }
5903 Instruction *EntryCall = Builder.CreateCall(Callee: RTFn, Args: EnterArgs);
5904
5905 Function *ExitRTLFn =
5906 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
5907 Instruction *ExitCall = Builder.CreateCall(Callee: ExitRTLFn, Args);
5908
5909 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5910 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
5911}
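
// Schematically, a critical region without a hint expands to:
//   call void @__kmpc_critical(ptr @ident, i32 %tid, ptr @<lock>)
//   <body>
//   <finalization>
//   call void @__kmpc_end_critical(ptr @ident, i32 %tid, ptr @<lock>)
// where <lock> is the per-name global produced by getOMPCriticalRegionLock;
// with a hint, the entry call is @__kmpc_critical_with_hint with the hint as
// an extra trailing argument.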
5912
5913OpenMPIRBuilder::InsertPointTy
5914OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
5915 InsertPointTy AllocaIP, unsigned NumLoops,
5916 ArrayRef<llvm::Value *> StoreValues,
5917 const Twine &Name, bool IsDependSource) {
5918 assert(
5919 llvm::all_of(StoreValues,
5920 [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
5921 "OpenMP runtime requires depend vec with i64 type");
5922
5923 if (!updateToLocation(Loc))
5924 return Loc.IP;
5925
5926 // Allocate space for the depend vector and generate the alloca instruction.
5927 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
5928 Builder.restoreIP(IP: AllocaIP);
5929 AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
5930 ArgsBase->setAlignment(Align(8));
5931 Builder.restoreIP(IP: Loc.IP);
5932
5933 // Store each index value at its offset in the depend vector.
5934 for (unsigned I = 0; I < NumLoops; ++I) {
5935 Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
5936 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
5937 StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
5938 STInst->setAlignment(Align(8));
5939 }
5940
5941 Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
5942 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});
5943
5944 uint32_t SrcLocStrSize;
5945 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5946 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5947 Value *ThreadId = getOrCreateThreadID(Ident);
5948 Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
5949
5950 Function *RTLFn = nullptr;
5951 if (IsDependSource)
5952 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
5953 else
5954 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
5955 Builder.CreateCall(Callee: RTLFn, Args);
5956
5957 return Builder.saveIP();
5958}
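
// For `ordered depend(source)` in a doacross nest over two i64 indices %i and
// %j, the generated code is roughly:
//   %vec = alloca [2 x i64], align 8            ; at AllocaIP
//   %e0 = getelementptr inbounds [2 x i64], ptr %vec, i64 0, i64 0
//   store i64 %i, ptr %e0, align 8
//   %e1 = getelementptr inbounds [2 x i64], ptr %vec, i64 0, i64 1
//   store i64 %j, ptr %e1, align 8
//   call void @__kmpc_doacross_post(ptr @ident, i32 %tid, ptr %e0)
// A `depend(sink: ...)` clause calls @__kmpc_doacross_wait instead.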
5959
5960OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
5961 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5962 FinalizeCallbackTy FiniCB, bool IsThreads) {
5963 if (!updateToLocation(Loc))
5964 return Loc.IP;
5965
5966 Directive OMPD = Directive::OMPD_ordered;
5967 Instruction *EntryCall = nullptr;
5968 Instruction *ExitCall = nullptr;
5969
5970 if (IsThreads) {
5971 uint32_t SrcLocStrSize;
5972 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5973 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5974 Value *ThreadId = getOrCreateThreadID(Ident);
5975 Value *Args[] = {Ident, ThreadId};
5976
5977 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
5978 EntryCall = Builder.CreateCall(Callee: EntryRTLFn, Args);
5979
5980 Function *ExitRTLFn =
5981 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
5982 ExitCall = Builder.CreateCall(Callee: ExitRTLFn, Args);
5983 }
5984
5985 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5986 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
5987}
5988
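// For reference, with Conditional == true the region emitted below has the
// shape:
//   entry:
//     %rc = <EntryCall>
//     br i1 (%rc != 0), label %omp_region.body, label %omp_region.end
//   omp_region.body:
//     <BodyGenCB code>
//     <FiniCB code>
//     <ExitCall>
//     br label %omp_region.end
//   omp_region.end: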
5989OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
5990 Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
5991 BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
5992 bool HasFinalize, bool IsCancellable) {
5993
5994 if (HasFinalize)
5995 FinalizationStack.push_back(Elt: {.FiniCB: FiniCB, .DK: OMPD, .IsCancellable: IsCancellable});
5996
5997 // Create the inlined region's entry and body blocks, in preparation
5998 // for conditional creation.
5999 BasicBlock *EntryBB = Builder.GetInsertBlock();
6000 Instruction *SplitPos = EntryBB->getTerminator();
6001 if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
6002 SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
6003 BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
6004 BasicBlock *FiniBB =
6005 EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");
6006
6007 Builder.SetInsertPoint(EntryBB->getTerminator());
6008 emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
6009
6010 // generate body
6011 if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
6012 /* CodeGenIP */ Builder.saveIP()))
6013 return Err;
6014
6015 // emit exit call and do any needed finalization.
6016 auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
6017 assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
6018 FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
6019 "Unexpected control flow graph state!!");
6020 InsertPointOrErrorTy AfterIP =
6021 emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
6022 if (!AfterIP)
6023 return AfterIP.takeError();
6024 assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
6025 "Unexpected Control Flow State!");
6026 MergeBlockIntoPredecessor(BB: FiniBB);
6027
6028 // If we are skipping the region of a non-conditional, remove the exit
6029 // block, and clear the builder's insertion point.
6030 assert(SplitPos->getParent() == ExitBB &&
6031 "Unexpected Insertion point location!");
6032 bool Merged = MergeBlockIntoPredecessor(BB: ExitBB);
6033 BasicBlock *ExitPredBB = SplitPos->getParent();
6034 BasicBlock *InsertBB = Merged ? ExitPredBB : ExitBB;
6035 if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
6036 SplitPos->eraseFromParent();
6037 Builder.SetInsertPoint(InsertBB);
6038
6039 return Builder.saveIP();
6040}
6041
6042OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
6043 Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
6044 // If there is nothing to do, return the current insertion point.
6045 if (!Conditional || !EntryCall)
6046 return Builder.saveIP();
6047
6048 BasicBlock *EntryBB = Builder.GetInsertBlock();
6049 Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
6050 auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
6051 auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
6052
6053 // Emit thenBB and set the Builder's insertion point there for
6054 // body generation next. Place the block after the current block.
6055 Function *CurFn = EntryBB->getParent();
6056 CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);
6057
6058 // Move Entry branch to end of ThenBB, and replace with conditional
6059 // branch (If-stmt)
6060 Instruction *EntryBBTI = EntryBB->getTerminator();
6061 Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
6062 EntryBBTI->removeFromParent();
6063 Builder.SetInsertPoint(UI);
6064 Builder.Insert(I: EntryBBTI);
6065 UI->eraseFromParent();
6066 Builder.SetInsertPoint(ThenBB->getTerminator());
6067
6068 // return an insertion point to ExitBB.
6069 return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
6070}
6071
6072OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
6073 omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
6074 bool HasFinalize) {
6075
6076 Builder.restoreIP(IP: FinIP);
6077
6078 // If there is finalization to do, emit it before the exit call
6079 if (HasFinalize) {
6080 assert(!FinalizationStack.empty() &&
6081 "Unexpected finalization stack state!");
6082
6083 FinalizationInfo Fi = FinalizationStack.pop_back_val();
6084 assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
6085
6086 if (Error Err = Fi.FiniCB(FinIP))
6087 return Err;
6088
6089 BasicBlock *FiniBB = FinIP.getBlock();
6090 Instruction *FiniBBTI = FiniBB->getTerminator();
6091
6092 // set Builder IP for call creation
6093 Builder.SetInsertPoint(FiniBBTI);
6094 }
6095
6096 if (!ExitCall)
6097 return Builder.saveIP();
6098
6099 // Place the exit call as the last instruction before the finalization block's terminator.
6100 ExitCall->removeFromParent();
6101 Builder.Insert(I: ExitCall);
6102
6103 return IRBuilder<>::InsertPoint(ExitCall->getParent(),
6104 ExitCall->getIterator());
6105}
6106
6107OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
6108 InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
6109 llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
6110 if (!IP.isSet())
6111 return IP;
6112
6113 IRBuilder<>::InsertPointGuard IPG(Builder);
6114
6115 // creates the following CFG structure
6116 // OMP_Entry : (MasterAddr != PrivateAddr)?
6117 // F T
6118 // | \
6119 // | copyin.not.master
6120 // | /
6121 // v /
6122 // copyin.not.master.end
6123 // |
6124 // v
6125 // OMP.Entry.Next
6126
6127 BasicBlock *OMP_Entry = IP.getBlock();
6128 Function *CurFn = OMP_Entry->getParent();
6129 BasicBlock *CopyBegin =
6130 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
6131 BasicBlock *CopyEnd = nullptr;
6132
6133 // If the entry block is terminated, split it to preserve the branch to the
6134 // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything as is.
6135 if (isa_and_nonnull<BranchInst>(Val: OMP_Entry->getTerminator())) {
6136 CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
6137 BBName: "copyin.not.master.end");
6138 OMP_Entry->getTerminator()->eraseFromParent();
6139 } else {
6140 CopyEnd =
6141 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
6142 }
6143
6144 Builder.SetInsertPoint(OMP_Entry);
6145 Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
6146 Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
6147 Value *Cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
6148 Builder.CreateCondBr(Cond: Cmp, True: CopyBegin, False: CopyEnd);
6149
6150 Builder.SetInsertPoint(CopyBegin);
6151 if (BranchtoEnd)
6152 Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));
6153
6154 return Builder.saveIP();
6155}
6156
6157CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
6158 Value *Size, Value *Allocator,
6159 std::string Name) {
6160 IRBuilder<>::InsertPointGuard IPG(Builder);
6161 updateToLocation(Loc);
6162
6163 uint32_t SrcLocStrSize;
6164 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6165 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6166 Value *ThreadId = getOrCreateThreadID(Ident);
6167 Value *Args[] = {ThreadId, Size, Allocator};
6168
6169 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
6170
6171 return Builder.CreateCall(Callee: Fn, Args, Name);
6172}
6173
6174CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
6175 Value *Addr, Value *Allocator,
6176 std::string Name) {
6177 IRBuilder<>::InsertPointGuard IPG(Builder);
6178 updateToLocation(Loc);
6179
6180 uint32_t SrcLocStrSize;
6181 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6182 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6183 Value *ThreadId = getOrCreateThreadID(Ident);
6184 Value *Args[] = {ThreadId, Addr, Allocator};
6185 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
6186 return Builder.CreateCall(Callee: Fn, Args, Name);
6187}
6188
6189CallInst *OpenMPIRBuilder::createOMPInteropInit(
6190 const LocationDescription &Loc, Value *InteropVar,
6191 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
6192 Value *DependenceAddress, bool HaveNowaitClause) {
6193 IRBuilder<>::InsertPointGuard IPG(Builder);
6194 updateToLocation(Loc);
6195
6196 uint32_t SrcLocStrSize;
6197 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6198 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6199 Value *ThreadId = getOrCreateThreadID(Ident);
6200 if (Device == nullptr)
6201 Device = Constant::getAllOnesValue(Ty: Int32);
6202 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
6203 if (NumDependences == nullptr) {
6204 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
6205 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
6206 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
6207 }
6208 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
6209 Value *Args[] = {
6210 Ident, ThreadId, InteropVar, InteropTypeVal,
6211 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
6212
6213 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
6214
6215 return Builder.CreateCall(Callee: Fn, Args);
6216}
6217
6218CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
6219 const LocationDescription &Loc, Value *InteropVar, Value *Device,
6220 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
6221 IRBuilder<>::InsertPointGuard IPG(Builder);
6222 updateToLocation(Loc);
6223
6224 uint32_t SrcLocStrSize;
6225 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6226 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6227 Value *ThreadId = getOrCreateThreadID(Ident);
6228 if (Device == nullptr)
6229 Device = Constant::getAllOnesValue(Ty: Int32);
6230 if (NumDependences == nullptr) {
6231 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
6232 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
6233 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
6234 }
6235 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
6236 Value *Args[] = {
6237 Ident, ThreadId, InteropVar, Device,
6238 NumDependences, DependenceAddress, HaveNowaitClauseVal};
6239
6240 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
6241
6242 return Builder.CreateCall(Callee: Fn, Args);
6243}
6244
6245CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
6246 Value *InteropVar, Value *Device,
6247 Value *NumDependences,
6248 Value *DependenceAddress,
6249 bool HaveNowaitClause) {
6250 IRBuilder<>::InsertPointGuard IPG(Builder);
6251 updateToLocation(Loc);
6252 uint32_t SrcLocStrSize;
6253 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6254 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6255 Value *ThreadId = getOrCreateThreadID(Ident);
6256 if (Device == nullptr)
6257 Device = Constant::getAllOnesValue(Ty: Int32);
6258 if (NumDependences == nullptr) {
6259 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
6260 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
6261 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
6262 }
6263 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
6264 Value *Args[] = {
6265 Ident, ThreadId, InteropVar, Device,
6266 NumDependences, DependenceAddress, HaveNowaitClauseVal};
6267
6268 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
6269
6270 return Builder.CreateCall(Callee: Fn, Args);
6271}
6272
6273CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
6274 const LocationDescription &Loc, llvm::Value *Pointer,
6275 llvm::ConstantInt *Size, const llvm::Twine &Name) {
6276 IRBuilder<>::InsertPointGuard IPG(Builder);
6277 updateToLocation(Loc);
6278
6279 uint32_t SrcLocStrSize;
6280 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6281 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6282 Value *ThreadId = getOrCreateThreadID(Ident);
6283 Constant *ThreadPrivateCache =
6284 getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
6285 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
6286
6287 Function *Fn =
6288 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);
6289
6290 return Builder.CreateCall(Callee: Fn, Args);
6291}
6292
6293OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6294 const LocationDescription &Loc,
6295 const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6296 assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
6297 "expected num_threads and num_teams to be specified");
6298
6299 if (!updateToLocation(Loc))
6300 return Loc.IP;
6301
6302 uint32_t SrcLocStrSize;
6303 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6304 Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6305 Constant *IsSPMDVal = ConstantInt::getSigned(Ty: Int8, V: Attrs.ExecFlags);
6306 Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
6307 Ty: Int8, V: Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
6308 Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
6309 Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);
6310
6311 Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
6312 Function *Kernel = DebugKernelWrapper;
6313
6314 // We need to strip the debug prefix to get the correct kernel name.
6315 StringRef KernelName = Kernel->getName();
6316 const std::string DebugPrefix = "_debug__";
6317 if (KernelName.ends_with(Suffix: DebugPrefix)) {
6318 KernelName = KernelName.drop_back(N: DebugPrefix.length());
6319 Kernel = M.getFunction(Name: KernelName);
6320 assert(Kernel && "Expected the real kernel to exist");
6321 }
6322
6323 // Manifest the launch configuration in the metadata matching the kernel
6324 // environment.
6325 if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
6326 writeTeamsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinTeams, UB: Attrs.MaxTeams.front());
6327
6328 // If MaxThreads is not set, select the maximum of the default workgroup
6329 // size and the MinThreads value.
6330 int32_t MaxThreadsVal = Attrs.MaxThreads.front();
6331 if (MaxThreadsVal < 0)
6332 MaxThreadsVal = std::max(
6333 a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), b: Attrs.MinThreads);
6334
6335 if (MaxThreadsVal > 0)
6336 writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinThreads, UB: MaxThreadsVal);
6337
6338 Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinThreads);
6339 Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
6340 Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinTeams);
6341 Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MaxTeams.front());
6342 Constant *ReductionDataSize =
6343 ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionDataSize);
6344 Constant *ReductionBufferLength =
6345 ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionBufferLength);
6346
6347 Function *Fn = getOrCreateRuntimeFunctionPtr(
6348 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
6349 const DataLayout &DL = Fn->getDataLayout();
6350
6351 Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
6352 Constant *DynamicEnvironmentInitializer =
6353 ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
6354 GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
6355 M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
6356 DynamicEnvironmentInitializer, DynamicEnvironmentName,
6357 /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6358 DL.getDefaultGlobalsAddressSpace());
6359 DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6360
6361 Constant *DynamicEnvironment =
6362 DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
6363 ? DynamicEnvironmentGV
6364 : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
6365 Ty: DynamicEnvironmentPtr);
6366
6367 Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
6368 T: ConfigurationEnvironment, V: {
6369 UseGenericStateMachineVal,
6370 MayUseNestedParallelismVal,
6371 IsSPMDVal,
6372 MinThreads,
6373 MaxThreads,
6374 MinTeams,
6375 MaxTeams,
6376 ReductionDataSize,
6377 ReductionBufferLength,
6378 });
6379 Constant *KernelEnvironmentInitializer = ConstantStruct::get(
6380 T: KernelEnvironment, V: {
6381 ConfigurationEnvironmentInitializer,
6382 Ident,
6383 DynamicEnvironment,
6384 });
6385 std::string KernelEnvironmentName =
6386 (KernelName + "_kernel_environment").str();
6387 GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
6388 M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
6389 KernelEnvironmentInitializer, KernelEnvironmentName,
6390 /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6391 DL.getDefaultGlobalsAddressSpace());
6392 KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6393
6394 Constant *KernelEnvironment =
6395 KernelEnvironmentGV->getType() == KernelEnvironmentPtr
6396 ? KernelEnvironmentGV
6397 : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
6398 Ty: KernelEnvironmentPtr);
6399 Value *KernelLaunchEnvironment = DebugKernelWrapper->getArg(i: 0);
6400 Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(i: 1);
6401 KernelLaunchEnvironment =
6402 KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
6403 ? KernelLaunchEnvironment
6404 : Builder.CreateAddrSpaceCast(V: KernelLaunchEnvironment,
6405 DestTy: KernelLaunchEnvParamTy);
6406 CallInst *ThreadKind =
6407 Builder.CreateCall(Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});
6408
6409 Value *ExecUserCode = Builder.CreateICmpEQ(
6410 LHS: ThreadKind, RHS: Constant::getAllOnesValue(Ty: ThreadKind->getType()),
6411 Name: "exec_user_code");
6412
6413 // ThreadKind = __kmpc_target_init(...)
6414 // if (ThreadKind == -1)
6415 // user_code
6416 // else
6417 // return;
6418
6419 auto *UI = Builder.CreateUnreachable();
6420 BasicBlock *CheckBB = UI->getParent();
6421 BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");
6422
6423 BasicBlock *WorkerExitBB = BasicBlock::Create(
6424 Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
6425 Builder.SetInsertPoint(WorkerExitBB);
6426 Builder.CreateRetVoid();
6427
6428 auto *CheckBBTI = CheckBB->getTerminator();
6429 Builder.SetInsertPoint(CheckBBTI);
6430 Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);
6431
6432 CheckBBTI->eraseFromParent();
6433 UI->eraseFromParent();
6434
6435 // Continue in the "user_code" block, see diagram above and in
6436 // openmp/libomptarget/deviceRTLs/common/include/target.h .
6437 return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
6438}
6439
6440void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
6441 int32_t TeamsReductionDataSize,
6442 int32_t TeamsReductionBufferLength) {
6443 if (!updateToLocation(Loc))
6444 return;
6445
6446 Function *Fn = getOrCreateRuntimeFunctionPtr(
6447 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
6448
6449 Builder.CreateCall(Callee: Fn, Args: {});
6450
6451 if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
6452 return;
6453
6454 Function *Kernel = Builder.GetInsertBlock()->getParent();
6455 // We need to strip the debug prefix to get the correct kernel name.
6456 StringRef KernelName = Kernel->getName();
6457 const std::string DebugPrefix = "_debug__";
6458 if (KernelName.ends_with(Suffix: DebugPrefix))
6459 KernelName = KernelName.drop_back(N: DebugPrefix.length());
6460 auto *KernelEnvironmentGV =
6461 M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
6462 assert(KernelEnvironmentGV && "Expected kernel environment global\n");
6463 auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
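// The index pairs below address the ConfigurationEnvironment member (field 0
// of the kernel environment struct); within it, fields 7 and 8 are
// ReductionDataSize and ReductionBufferLength, matching the initializer
// built in createTargetInit.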
6464 auto *NewInitializer = ConstantFoldInsertValueInstruction(
6465 Agg: KernelEnvironmentInitializer,
6466 Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
6467 NewInitializer = ConstantFoldInsertValueInstruction(
6468 Agg: NewInitializer, Val: ConstantInt::get(Ty: Int32, V: TeamsReductionBufferLength),
6469 Idxs: {0, 8});
6470 KernelEnvironmentGV->setInitializer(NewInitializer);
6471}
6472
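/// Merge the integer attribute \p Name on \p Kernel with \p Value, keeping the
/// minimum of the two when \p Min is set and the maximum otherwise. As a small
/// illustrative example (attribute values are hypothetical): if the kernel
/// already carries "nvvm.maxntid"="256", then
///   updateNVPTXAttr(Kernel, "nvvm.maxntid", 128, /*Min=*/true);
/// rewrites the attribute to "128", while /*Min=*/false would keep "256".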
6473static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
6474 bool Min) {
6475 if (Kernel.hasFnAttribute(Kind: Name)) {
6476 int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Kind: Name);
6477 Value = Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value);
6478 }
6479 Kernel.addFnAttr(Kind: Name, Val: llvm::utostr(X: Value));
6480}
6481
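/// Read the {lower, upper} kernel thread bounds. As an illustrative example
/// (attribute values are hypothetical): on AMDGPU, a kernel carrying
/// "amdgpu-flat-work-group-size"="1,256" together with
/// "omp_target_thread_limit"="128" yields {1, 128}, because the upper bound
/// is clamped by the OpenMP thread limit.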
6482std::pair<int32_t, int32_t>
6483OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
6484 int32_t ThreadLimit =
6485 Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");
6486
6487 if (T.isAMDGPU()) {
6488 const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
6489 if (!Attr.isValid() || !Attr.isStringAttribute())
6490 return {0, ThreadLimit};
6491 auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
6492 int32_t LB, UB;
6493 if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
6494 return {0, ThreadLimit};
6495 UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
6496 if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
6497 return {0, UB};
6498 return {LB, UB};
6499 }
6500
6501 if (Kernel.hasFnAttribute(Kind: "nvvm.maxntid")) {
6502 int32_t UB = Kernel.getFnAttributeAsParsedInteger(Kind: "nvvm.maxntid");
6503 return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
6504 }
6505 return {0, ThreadLimit};
6506}
6507
6508void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
6509 Function &Kernel, int32_t LB,
6510 int32_t UB) {
6511 Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));
6512
6513 if (T.isAMDGPU()) {
6514 Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
6515 Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
6516 return;
6517 }
6518
6519 updateNVPTXAttr(Kernel, Name: "nvvm.maxntid", Value: UB, Min: true);
6520}
6521
6522std::pair<int32_t, int32_t>
6523OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
6524 // TODO: Read from backend annotations if available.
6525 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
6526}
6527
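/// Record the team bounds on the kernel. For illustration (values are
/// hypothetical): writeTeamsForKernel(T, K, /*LB=*/4, /*UB=*/8) adds
/// "omp_target_num_teams"="4" and, on NVPTX, "nvvm.maxclusterrank"="8"; on
/// AMDGPU it adds "amdgpu-max-num-workgroups"="4,1,1" instead.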
6528void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
6529 int32_t LB, int32_t UB) {
6530 if (T.isNVPTX())
6531 if (UB > 0)
6532 Kernel.addFnAttr(Kind: "nvvm.maxclusterrank", Val: llvm::utostr(X: UB));
6533 if (T.isAMDGPU())
6534 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: LB) + ",1,1");
6535
6536 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
6537}
6538
6539void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
6540 Function *OutlinedFn) {
6541 if (Config.isTargetDevice()) {
6542 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
6543 // TODO: Determine if DSO local can be set to true.
6544 OutlinedFn->setDSOLocal(false);
6545 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
6546 if (T.isAMDGCN())
6547 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
6548 else if (T.isNVPTX())
6549 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
6550 else if (T.isSPIRV())
6551 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
6552 }
6553}
6554
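/// On the device, the outlined function itself is used as the ID. On the
/// host, a weak i8 global is emitted instead; since its name comes from
/// createPlatformSpecificName, the result is roughly (illustrative only):
///   @__omp_offloading_<...>_region_id = weak constant i8 0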
6555Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
6556 StringRef EntryFnIDName) {
6557 if (Config.isTargetDevice()) {
6558 assert(OutlinedFn && "The outlined function must exist if embedded");
6559 return OutlinedFn;
6560 }
6561
6562 return new GlobalVariable(
6563 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
6564 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
6565}
6566
6567Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
6568 StringRef EntryFnName) {
6569 if (OutlinedFn)
6570 return OutlinedFn;
6571
6572 assert(!M.getGlobalVariable(EntryFnName, true) &&
6573 "Named kernel already exists?");
6574 return new GlobalVariable(
6575 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
6576 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
6577}
6578
6579Error OpenMPIRBuilder::emitTargetRegionFunction(
6580 TargetRegionEntryInfo &EntryInfo,
6581 FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
6582 Function *&OutlinedFn, Constant *&OutlinedFnID) {
6583
6584 SmallString<64> EntryFnName;
6585 OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);
6586
6587 if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
6588 Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
6589 if (!CBResult)
6590 return CBResult.takeError();
6591 OutlinedFn = *CBResult;
6592 } else {
6593 OutlinedFn = nullptr;
6594 }
6595
  // If this target outline function is not an offload entry, we don't need to
  // register it. This may be the case for a false if clause, or when there are
  // no OpenMP targets.
6599 if (!IsOffloadEntry)
6600 return Error::success();
6601
6602 std::string EntryFnIDName =
6603 Config.isTargetDevice()
6604 ? std::string(EntryFnName)
6605 : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});
6606
6607 OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
6608 EntryFnName, EntryFnIDName);
6609 return Error::success();
6610}
6611
6612Constant *OpenMPIRBuilder::registerTargetRegionFunction(
6613 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
6614 StringRef EntryFnName, StringRef EntryFnIDName) {
6615 if (OutlinedFn)
6616 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
6617 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
6618 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
6619 OffloadInfoManager.registerTargetRegionEntryInfo(
6620 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
6621 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
6622 return OutlinedFnID;
6623}
6624
6625OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
6626 const LocationDescription &Loc, InsertPointTy AllocaIP,
6627 InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
6628 TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6629 CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
6630 function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
6631 BodyGenTy BodyGenType)>
6632 BodyGenCB,
6633 function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
6634 if (!updateToLocation(Loc))
6635 return InsertPointTy();
6636
6637 Builder.restoreIP(IP: CodeGenIP);
6638 // Disable TargetData CodeGen on Device pass.
6639 if (Config.IsTargetDevice.value_or(u: false)) {
6640 if (BodyGenCB) {
6641 InsertPointOrErrorTy AfterIP =
6642 BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
6643 if (!AfterIP)
6644 return AfterIP.takeError();
6645 Builder.restoreIP(IP: *AfterIP);
6646 }
6647 return Builder.saveIP();
6648 }
6649
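  // For the host with a body callback (non-standalone), the emitted structure
  // is roughly the following sketch (the runtime calls take more arguments
  // than shown):
  //
  //   call void @__tgt_target_data_begin_mapper(%ident, %device_id, ...)
  //   ;; body, device-pointer-privatized variant (BodyGenTy::Priv)
  //   ;; body, non-privatized variant (BodyGenTy::NoPriv)
  //   call void @__tgt_target_data_end_mapper(%ident, %device_id, ...)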
6650 bool IsStandAlone = !BodyGenCB;
6651 MapInfosTy *MapInfo;
6652 // Generate the code for the opening of the data environment. Capture all the
6653 // arguments of the runtime call by reference because they are used in the
6654 // closing of the region.
6655 auto BeginThenGen = [&](InsertPointTy AllocaIP,
6656 InsertPointTy CodeGenIP) -> Error {
6657 MapInfo = &GenMapInfoCB(Builder.saveIP());
6658 if (Error Err = emitOffloadingArrays(
6659 AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info, CustomMapperCB,
6660 /*IsNonContiguous=*/true, DeviceAddrCB))
6661 return Err;
6662
6663 TargetDataRTArgs RTArgs;
6664 emitOffloadingArraysArgument(Builder, RTArgs, Info);
6665
6666 // Emit the number of elements in the offloading arrays.
6667 Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);
6668
6669 // Source location for the ident struct
6670 if (!SrcLocInfo) {
6671 uint32_t SrcLocStrSize;
6672 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6673 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6674 }
6675
6676 SmallVector<llvm::Value *, 13> OffloadingArgs = {
6677 SrcLocInfo, DeviceID,
6678 PointerNum, RTArgs.BasePointersArray,
6679 RTArgs.PointersArray, RTArgs.SizesArray,
6680 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6681 RTArgs.MappersArray};
6682
6683 if (IsStandAlone) {
6684 assert(MapperFunc && "MapperFunc missing for standalone target data");
6685
6686 auto TaskBodyCB = [&](Value *, Value *,
6687 IRBuilderBase::InsertPoint) -> Error {
6688 if (Info.HasNoWait) {
6689 OffloadingArgs.append(IL: {llvm::Constant::getNullValue(Ty: Int32),
6690 llvm::Constant::getNullValue(Ty: VoidPtr),
6691 llvm::Constant::getNullValue(Ty: Int32),
6692 llvm::Constant::getNullValue(Ty: VoidPtr)});
6693 }
6694
6695 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
6696 Args: OffloadingArgs);
6697
6698 if (Info.HasNoWait) {
6699 BasicBlock *OffloadContBlock =
6700 BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
6701 Function *CurFn = Builder.GetInsertBlock()->getParent();
6702 emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
6703 Builder.restoreIP(IP: Builder.saveIP());
6704 }
6705 return Error::success();
6706 };
6707
6708 bool RequiresOuterTargetTask = Info.HasNoWait;
6709 if (!RequiresOuterTargetTask)
6710 cantFail(Err: TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
6711 /*TargetTaskAllocaIP=*/{}));
6712 else
6713 cantFail(ValOrErr: emitTargetTask(TaskBodyCB, DeviceID, RTLoc: SrcLocInfo, AllocaIP,
6714 /*Dependencies=*/{}, RTArgs, HasNoWait: Info.HasNoWait));
6715 } else {
6716 Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6717 FnID: omp::OMPRTL___tgt_target_data_begin_mapper);
6718
6719 Builder.CreateCall(Callee: BeginMapperFunc, Args: OffloadingArgs);
6720
6721 for (auto DeviceMap : Info.DevicePtrInfoMap) {
6722 if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
6723 auto *LI =
6724 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
6725 Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
6726 }
6727 }
6728
6729 // If device pointer privatization is required, emit the body of the
6730 // region here. It will have to be duplicated: with and without
6731 // privatization.
6732 InsertPointOrErrorTy AfterIP =
6733 BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
6734 if (!AfterIP)
6735 return AfterIP.takeError();
6736 Builder.restoreIP(IP: *AfterIP);
6737 }
6738 return Error::success();
6739 };
6740
6741 // If we need device pointer privatization, we need to emit the body of the
6742 // region with no privatization in the 'else' branch of the conditional.
6743 // Otherwise, we don't have to do anything.
6744 auto BeginElseGen = [&](InsertPointTy AllocaIP,
6745 InsertPointTy CodeGenIP) -> Error {
6746 InsertPointOrErrorTy AfterIP =
6747 BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
6748 if (!AfterIP)
6749 return AfterIP.takeError();
6750 Builder.restoreIP(IP: *AfterIP);
6751 return Error::success();
6752 };
6753
6754 // Generate code for the closing of the data region.
6755 auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6756 TargetDataRTArgs RTArgs;
6757 Info.EmitDebug = !MapInfo->Names.empty();
6758 emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);
6759
6760 // Emit the number of elements in the offloading arrays.
6761 Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);
6762
6763 // Source location for the ident struct
6764 if (!SrcLocInfo) {
6765 uint32_t SrcLocStrSize;
6766 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6767 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6768 }
6769
6770 Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
6771 PointerNum, RTArgs.BasePointersArray,
6772 RTArgs.PointersArray, RTArgs.SizesArray,
6773 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6774 RTArgs.MappersArray};
6775 Function *EndMapperFunc =
6776 getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);
6777
6778 Builder.CreateCall(Callee: EndMapperFunc, Args: OffloadingArgs);
6779 return Error::success();
6780 };
6781
6782 // We don't have to do anything to close the region if the if clause evaluates
6783 // to false.
6784 auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6785 return Error::success();
6786 };
6787
6788 Error Err = [&]() -> Error {
6789 if (BodyGenCB) {
6790 Error Err = [&]() {
6791 if (IfCond)
6792 return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
6793 return BeginThenGen(AllocaIP, Builder.saveIP());
6794 }();
6795
6796 if (Err)
6797 return Err;
6798
6799 // If we don't require privatization of device pointers, we emit the body
6800 // in between the runtime calls. This avoids duplicating the body code.
6801 InsertPointOrErrorTy AfterIP =
6802 BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
6803 if (!AfterIP)
6804 return AfterIP.takeError();
6805 Builder.restoreIP(IP: *AfterIP);
6806
6807 if (IfCond)
6808 return emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
6809 return EndThenGen(AllocaIP, Builder.saveIP());
6810 }
6811 if (IfCond)
6812 return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
6813 return BeginThenGen(AllocaIP, Builder.saveIP());
6814 }();
6815
6816 if (Err)
6817 return Err;
6818
6819 return Builder.saveIP();
6820}
6821
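/// Select the static-init runtime function matching the IV width and
/// signedness. For example, IVSize == 32 with IVSigned selects
/// __kmpc_for_static_init_4 (or __kmpc_distribute_static_init_4 when
/// IsGPUDistribute is set), whereas IVSize == 64 without IVSigned selects
/// __kmpc_for_static_init_8u.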
6822FunctionCallee
6823OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
6824 bool IsGPUDistribute) {
6825 assert((IVSize == 32 || IVSize == 64) &&
6826 "IV size is not compatible with the omp runtime");
6827 RuntimeFunction Name;
6828 if (IsGPUDistribute)
6829 Name = IVSize == 32
6830 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
6831 : omp::OMPRTL___kmpc_distribute_static_init_4u)
6832 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
6833 : omp::OMPRTL___kmpc_distribute_static_init_8u);
6834 else
6835 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
6836 : omp::OMPRTL___kmpc_for_static_init_4u)
6837 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
6838 : omp::OMPRTL___kmpc_for_static_init_8u);
6839
6840 return getOrCreateRuntimeFunction(M, FnID: Name);
6841}
6842
6843FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
6844 bool IVSigned) {
6845 assert((IVSize == 32 || IVSize == 64) &&
6846 "IV size is not compatible with the omp runtime");
6847 RuntimeFunction Name = IVSize == 32
6848 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
6849 : omp::OMPRTL___kmpc_dispatch_init_4u)
6850 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
6851 : omp::OMPRTL___kmpc_dispatch_init_8u);
6852
6853 return getOrCreateRuntimeFunction(M, FnID: Name);
6854}
6855
6856FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
6857 bool IVSigned) {
6858 assert((IVSize == 32 || IVSize == 64) &&
6859 "IV size is not compatible with the omp runtime");
6860 RuntimeFunction Name = IVSize == 32
6861 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
6862 : omp::OMPRTL___kmpc_dispatch_next_4u)
6863 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
6864 : omp::OMPRTL___kmpc_dispatch_next_8u);
6865
6866 return getOrCreateRuntimeFunction(M, FnID: Name);
6867}
6868
6869FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
6870 bool IVSigned) {
6871 assert((IVSize == 32 || IVSize == 64) &&
6872 "IV size is not compatible with the omp runtime");
6873 RuntimeFunction Name = IVSize == 32
6874 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
6875 : omp::OMPRTL___kmpc_dispatch_fini_4u)
6876 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
6877 : omp::OMPRTL___kmpc_dispatch_fini_8u);
6878
6879 return getOrCreateRuntimeFunction(M, FnID: Name);
6880}
6881
6882FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6883 return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
6884}
6885
6886static void FixupDebugInfoForOutlinedFunction(
6887 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
6888 DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {
6889
6890 DISubprogram *NewSP = Func->getSubprogram();
6891 if (!NewSP)
6892 return;
6893
6894 SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;
6895
6896 auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
6897 DILocalVariable *&NewVar = RemappedVariables[OldVar];
    // Only use the cached variable if the arg number matches. This is
    // important so that the DIVariables created for privatized variables are
    // not discarded.
6900 if (NewVar && (arg == NewVar->getArg()))
6901 return NewVar;
6902
6903 NewVar = llvm::DILocalVariable::get(
6904 Context&: Builder.getContext(), Scope: OldVar->getScope(), Name: OldVar->getName(),
6905 File: OldVar->getFile(), Line: OldVar->getLine(), Type: OldVar->getType(), Arg: arg,
6906 Flags: OldVar->getFlags(), AlignInBits: OldVar->getAlignInBits(), Annotations: OldVar->getAnnotations());
6907 return NewVar;
6908 };
6909
6910 auto UpdateDebugRecord = [&](auto *DR) {
6911 DILocalVariable *OldVar = DR->getVariable();
6912 unsigned ArgNo = 0;
6913 for (auto Loc : DR->location_ops()) {
6914 auto Iter = ValueReplacementMap.find(Loc);
6915 if (Iter != ValueReplacementMap.end()) {
6916 DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
6917 ArgNo = std::get<1>(Iter->second) + 1;
6918 }
6919 }
6920 if (ArgNo != 0)
6921 DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
6922 };
6923
6924 // The location and scope of variable intrinsics and records still point to
6925 // the parent function of the target region. Update them.
6926 for (Instruction &I : instructions(F: Func)) {
6927 if (auto *DDI = dyn_cast<llvm::DbgVariableIntrinsic>(Val: &I))
6928 UpdateDebugRecord(DDI);
6929
6930 for (DbgVariableRecord &DVR : filterDbgVars(R: I.getDbgRecordRange()))
6931 UpdateDebugRecord(&DVR);
6932 }
6933 // An extra argument is passed to the device. Create the debug data for it.
6934 if (OMPBuilder.Config.isTargetDevice()) {
6935 DICompileUnit *CU = NewSP->getUnit();
6936 Module *M = Func->getParent();
6937 DIBuilder DB(*M, true, CU);
6938 DIType *VoidPtrTy =
6939 DB.createQualifiedType(Tag: dwarf::DW_TAG_pointer_type, FromTy: nullptr);
6940 DILocalVariable *Var = DB.createParameterVariable(
6941 Scope: NewSP, Name: "dyn_ptr", /*ArgNo*/ 1, File: NewSP->getFile(), /*LineNo=*/0,
6942 Ty: VoidPtrTy, /*AlwaysPreserve=*/false, Flags: DINode::DIFlags::FlagArtificial);
6943 auto Loc = DILocation::get(Context&: Func->getContext(), Line: 0, Column: 0, Scope: NewSP, InlinedAt: 0);
6944 DB.insertDeclare(Storage: &(*Func->arg_begin()), VarInfo: Var, Expr: DB.createExpression(), DL: Loc,
6945 InsertAtEnd: &(*Func->begin()));
6946 }
6947}
6948
6949static Expected<Function *> createOutlinedFunction(
6950 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6951 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6952 StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6953 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6954 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6955 SmallVector<Type *> ParameterTypes;
6956 if (OMPBuilder.Config.isTargetDevice()) {
6957 // Add the "implicit" runtime argument we use to provide launch specific
6958 // information for target devices.
6959 auto *Int8PtrTy = PointerType::getUnqual(C&: Builder.getContext());
6960 ParameterTypes.push_back(Elt: Int8PtrTy);
6961
6962 // All parameters to target devices are passed as pointers
6963 // or i64. This assumes 64-bit address spaces/pointers.
6964 for (auto &Arg : Inputs)
6965 ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
6966 ? Arg->getType()
6967 : Type::getInt64Ty(C&: Builder.getContext()));
6968 } else {
6969 for (auto &Arg : Inputs)
6970 ParameterTypes.push_back(Elt: Arg->getType());
6971 }
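  // For illustration (input types are hypothetical): with Inputs {i32 %n,
  // ptr %a}, the device-side signature becomes
  //   void @<kernel>(ptr %dyn_ptr, i64 %n, ptr %a)
  // whereas the host-side outlined function keeps the original types.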
6972
6973 auto BB = Builder.GetInsertBlock();
6974 auto M = BB->getModule();
6975 auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
6976 /*isVarArg*/ false);
6977 auto Func =
6978 Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
6979
6980 // Forward target-cpu and target-features function attributes from the
6981 // original function to the new outlined function.
6982 Function *ParentFn = Builder.GetInsertBlock()->getParent();
6983
6984 auto TargetCpuAttr = ParentFn->getFnAttribute(Kind: "target-cpu");
6985 if (TargetCpuAttr.isStringAttribute())
6986 Func->addFnAttr(Attr: TargetCpuAttr);
6987
6988 auto TargetFeaturesAttr = ParentFn->getFnAttribute(Kind: "target-features");
6989 if (TargetFeaturesAttr.isStringAttribute())
6990 Func->addFnAttr(Attr: TargetFeaturesAttr);
6991
6992 if (OMPBuilder.Config.isTargetDevice()) {
6993 Value *ExecMode =
6994 OMPBuilder.emitKernelExecutionMode(KernelName: FuncName, Mode: DefaultAttrs.ExecFlags);
6995 OMPBuilder.emitUsed(Name: "llvm.compiler.used", List: {ExecMode});
6996 }
6997
6998 // Save insert point.
6999 IRBuilder<>::InsertPointGuard IPG(Builder);
  // We will generate the entry code in the outlined function, but the debug
  // location may still be pointing to the parent function. Reset it now.
7002 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
7003
7004 // Generate the region into the function.
7005 BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
7006 Builder.SetInsertPoint(EntryBB);
7007
7008 // Insert target init call in the device compilation pass.
7009 if (OMPBuilder.Config.isTargetDevice())
7010 Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, Attrs: DefaultAttrs));
7011
7012 BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
7013
  // As we embed the user code in the middle of our target region after we
  // generate entry code, we must move what allocas we can into the entry
  // block to avoid possibly breaking optimisations for the device.
7017 if (OMPBuilder.Config.isTargetDevice())
7018 OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);
7019
7020 // Insert target deinit call in the device compilation pass.
7021 BasicBlock *OutlinedBodyBB =
7022 splitBB(Builder, /*CreateBranch=*/true, Name: "outlined.body");
7023 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
7024 Builder.saveIP(),
7025 OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
7026 if (!AfterIP)
7027 return AfterIP.takeError();
7028 Builder.restoreIP(IP: *AfterIP);
7029 if (OMPBuilder.Config.isTargetDevice())
7030 OMPBuilder.createTargetDeinit(Loc: Builder);
7031
7032 // Insert return instruction.
7033 Builder.CreateRetVoid();
7034
7035 // New Alloca IP at entry point of created device function.
7036 Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
7037 auto AllocaIP = Builder.saveIP();
7038
7039 Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
7040
7041 // Skip the artificial dyn_ptr on the device.
7042 const auto &ArgRange =
7043 OMPBuilder.Config.isTargetDevice()
7044 ? make_range(x: Func->arg_begin() + 1, y: Func->arg_end())
7045 : Func->args();
7046
7047 DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;
7048
7049 auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
    // Things like GEPs can come in the form of Constants. Constants and
    // ConstantExprs do not know what they are contained in, so we must dig a
    // little to find an instruction so we can tell whether they are used
    // inside of the function we are outlining. We also replace the original
    // constant expression with an equivalent instruction: an instruction is
    // owned by our target function, so replaceUsesOfWith can be invoked on it
    // in the following loop (this cannot be done with constants). A brand new
    // instruction also lets us be cautious, as the old expression may still
    // be used outside of the function (unlikely by the nature of a Constant,
    // but possible).
    // NOTE: We cannot remove dead constants that have been rewritten to
    // instructions at this stage; doing so would risk breaking later
    // lowering, because the module may still be in the process of being
    // lowered from MLIR to LLVM-IR and the MLIR lowering may still require
    // the original constants we have created rewritten versions of.
7067 if (auto *Const = dyn_cast<Constant>(Val: Input))
7068 convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);
7069
7070 // Collect users before iterating over them to avoid invalidating the
7071 // iteration in case a user uses Input more than once (e.g. a call
7072 // instruction).
7073 SetVector<User *> Users(Input->users().begin(), Input->users().end());
    // Replace the uses of Input within the outlined function.
7075 for (User *User : make_early_inc_range(Range&: Users))
7076 if (auto *Instr = dyn_cast<Instruction>(Val: User))
7077 if (Instr->getFunction() == Func)
7078 Instr->replaceUsesOfWith(From: Input, To: InputCopy);
7079 };
7080
7081 SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
7082
  // Rewrite uses of input values to parameters.
7084 for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
7085 Value *Input = std::get<0>(t&: InArg);
7086 Argument &Arg = std::get<1>(t&: InArg);
7087 Value *InputCopy = nullptr;
7088
7089 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
7090 ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
7091 if (!AfterIP)
7092 return AfterIP.takeError();
7093 Builder.restoreIP(IP: *AfterIP);
7094 ValueReplacementMap[Input] = std::make_tuple(args&: InputCopy, args: Arg.getArgNo());
7095
    // In certain cases a Global may be set up for replacement; however, this
    // Global may be used in multiple arguments to the kernel, just segmented
    // apart. For example, if we have a global array that is sectioned into
    // multiple mappings (technically not legal in OpenMP, but necessary for
    // Fortran Common Blocks), we will end up with GEPs into this array inside
    // the kernel that refer to the Global but are, for all intents and
    // purposes, technically separate arguments to the kernel. If we have
    // mapped a segment that requires a GEP into the 0-th index, it will fold
    // into a reference to the Global; if we then encounter this folded GEP
    // during replacement, all of the references to the Global in the kernel
    // will be replaced with the argument we have generated that corresponds
    // to it, including any other GEPs that refer to the Global and may be
    // other arguments. This would invalidate all of the other preceding
    // mapped arguments that refer to separate segments of the same global.
    // To prevent this, we defer global processing until all other processing
    // has been performed.
7112 if (isa<GlobalValue>(Val: Input)) {
7113 DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
7114 continue;
7115 }
7116
7117 if (isa<ConstantData>(Val: Input))
7118 continue;
7119
7120 ReplaceValue(Input, InputCopy, Func);
7121 }
7122
7123 // Replace all of our deferred Input values, currently just Globals.
7124 for (auto Deferred : DeferredReplacement)
7125 ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);
7126
7127 FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
7128 ValueReplacementMap);
7129 return Func;
7130}
7131/// Given a task descriptor, TaskWithPrivates, return the pointer to the block
7132/// of pointers containing shared data between the parent task and the created
7133/// task.
7134static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
7135 IRBuilderBase &Builder,
7136 Value *TaskWithPrivates,
7137 Type *TaskWithPrivatesTy) {
7138
7139 Type *TaskTy = OMPIRBuilder.Task;
7140 LLVMContext &Ctx = Builder.getContext();
7141 Value *TaskT =
7142 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 0);
7143 Value *Shareds = TaskT;
7144 // TaskWithPrivatesTy can be one of the following
7145 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
7146 // %struct.privates }
7147 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
7148 //
7149 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
7150 // its first member has to be the task descriptor. TaskTy is the type of the
7151 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
// first member of TaskT gives us the pointer to shared data.
7153 if (TaskWithPrivatesTy != TaskTy)
7154 Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
7155 return Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
7156}
/// Create an entry point for a target task. It will have the following
/// signature:
/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
/// This function is called from emitTargetTask once the
/// code to launch the target kernel has been outlined already.
/// NumOffloadingArrays is the number of offloading arrays that we need to copy
/// into the task structure so that the deferred target task can access this
/// data even after the stack frame of the generating task has been rolled
/// back. Offloading arrays contain base pointers, pointers, sizes, etc., of
/// the data that the target kernel will access. These, in effect, are the
/// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
7168static Function *emitTargetTaskProxyFunction(
7169 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
7170 StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
7171 const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {
7172
7173 // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
7174 // This is because PrivatesTy is the type of the structure in which
7175 // we pass the offloading arrays to the deferred target task.
7176 assert((!NumOffloadingArrays || PrivatesTy) &&
7177 "PrivatesTy cannot be nullptr when there are offloadingArrays"
7178 "to privatize");
7179
7180 Module &M = OMPBuilder.M;
7181 // KernelLaunchFunction is the target launch function, i.e.
7182 // the function that sets up kernel arguments and calls
7183 // __tgt_target_kernel to launch the kernel on the device.
7184 //
7185 Function *KernelLaunchFunction = StaleCI->getCalledFunction();
7186
7187 // StaleCI is the CallInst which is the call to the outlined
7188 // target kernel launch function. If there are local live-in values
7189 // that the outlined function uses then these are aggregated into a structure
7190 // which is passed as the second argument. If there are no local live-in
7191 // values or if all values used by the outlined kernel are global variables,
7192 // then there's only one argument, the threadID. So, StaleCI can be
7193 //
7194 // %structArg = alloca { ptr, ptr }, align 8
7195 // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
7196 // store ptr %20, ptr %gep_, align 8
7197 // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
7198 // store ptr %21, ptr %gep_8, align 8
7199 // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
7200 //
7201 // OR
7202 //
7203 // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
7204 OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
7205 StaleCI->getIterator());
7206
7207 LLVMContext &Ctx = StaleCI->getParent()->getContext();
7208
7209 Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
7210 Type *TaskPtrTy = OMPBuilder.TaskPtr;
7211 [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;
7212
7213 auto ProxyFnTy =
7214 FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
7215 /* isVarArg */ false);
7216 auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
7217 N: ".omp_target_task_proxy_func",
7218 M: Builder.GetInsertBlock()->getModule());
7219 Value *ThreadId = ProxyFn->getArg(i: 0);
7220 Value *TaskWithPrivates = ProxyFn->getArg(i: 1);
7221 ThreadId->setName("thread.id");
7222 TaskWithPrivates->setName("task");
7223
7224 bool HasShareds = SharedArgsOperandNo > 0;
7225 bool HasOffloadingArrays = NumOffloadingArrays > 0;
7226 BasicBlock *EntryBB =
7227 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
7228 Builder.SetInsertPoint(EntryBB);
7229
7230 SmallVector<Value *> KernelLaunchArgs;
7231 KernelLaunchArgs.reserve(N: StaleCI->arg_size());
7232 KernelLaunchArgs.push_back(Elt: ThreadId);
7233
7234 if (HasOffloadingArrays) {
7235 assert(TaskTy != TaskWithPrivatesTy &&
7236 "If there are offloading arrays to pass to the target"
7237 "TaskTy cannot be the same as TaskWithPrivatesTy");
7238 (void)TaskTy;
7239 Value *Privates =
7240 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 1);
7241 for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
7242 KernelLaunchArgs.push_back(
7243 Elt: Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i));
7244 }
7245
7246 if (HasShareds) {
7247 auto *ArgStructAlloca =
7248 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgsOperandNo));
7249 assert(ArgStructAlloca &&
7250 "Unable to find the alloca instruction corresponding to arguments "
7251 "for extracted function");
7252 auto *ArgStructType = cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
7253
7254 AllocaInst *NewArgStructAlloca =
7255 Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");
7256
7257 Value *SharedsSize =
7258 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
7259
7260 LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
7261 OMPIRBuilder&: OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);
7262
7263 Builder.CreateMemCpy(
7264 Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
7265 SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
7266 KernelLaunchArgs.push_back(Elt: NewArgStructAlloca);
7267 }
7268 Builder.CreateCall(Callee: KernelLaunchFunction, Args: KernelLaunchArgs);
7269 Builder.CreateRetVoid();
7270 return ProxyFn;
7271}
7272static Type *getOffloadingArrayType(Value *V) {
7273
7274 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V))
7275 return GEP->getSourceElementType();
7276 if (auto *Alloca = dyn_cast<AllocaInst>(Val: V))
7277 return Alloca->getAllocatedType();
7278
7279 llvm_unreachable("Unhandled Instruction type");
7280 return nullptr;
7281}
7282// This function returns a struct that has at most two members.
7283// The first member is always %struct.kmp_task_ompbuilder_t, that is the task
7284// descriptor. The second member, if needed, is a struct containing arrays
7285// that need to be passed to the offloaded target kernel. For example,
7286// if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
7287// the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
7288// respectively, then the types created by this function are
7289//
7290// %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
7291// %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
7292// %struct.privates }
7293// %struct.task_with_privates is returned by this function.
7294// If there aren't any offloading arrays to pass to the target kernel,
7295// %struct.kmp_task_ompbuilder_t is returned.
7296static StructType *
7297createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
7298 ArrayRef<Value *> OffloadingArraysToPrivatize) {
7299
7300 if (OffloadingArraysToPrivatize.empty())
7301 return OMPIRBuilder.Task;
7302
7303 SmallVector<Type *, 4> StructFieldTypes;
7304 for (Value *V : OffloadingArraysToPrivatize) {
7305 assert(V->getType()->isPointerTy() &&
7306 "Expected pointer to array to privatize. Got a non-pointer value "
7307 "instead");
7308 Type *ArrayTy = getOffloadingArrayType(V);
7309 assert(ArrayTy && "ArrayType cannot be nullptr");
7310 StructFieldTypes.push_back(Elt: ArrayTy);
7311 }
7312 StructType *PrivatesStructTy =
7313 StructType::create(Elements: StructFieldTypes, Name: "struct.privates");
7314 return StructType::create(Elements: {OMPIRBuilder.Task, PrivatesStructTy},
7315 Name: "struct.task_with_privates");
7316}
7317static Error emitTargetOutlinedFunction(
7318 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7319 TargetRegionEntryInfo &EntryInfo,
7320 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7321 Function *&OutlinedFn, Constant *&OutlinedFnID,
7322 SmallVectorImpl<Value *> &Inputs,
7323 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
7324 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
7325
7326 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7327 [&](StringRef EntryFnName) {
7328 return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7329 FuncName: EntryFnName, Inputs, CBFunc,
7330 ArgAccessorFuncCB);
7331 };
7332
7333 return OMPBuilder.emitTargetRegionFunction(
7334 EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
7335 OutlinedFnID);
7336}
7337
7338OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
7339 TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
7340 OpenMPIRBuilder::InsertPointTy AllocaIP,
7341 const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
7342 const TargetDataRTArgs &RTArgs, bool HasNoWait) {
7343
  // The following explains the code-gen scenario for the `target` directive. A
  // similar scenario applies to other device-related directives (e.g.
  // `target enter data`), since we only need to emit a task that encapsulates
  // the proper runtime call.
  //
  // When we arrive at this function, the target region itself has been
  // outlined into the function OutlinedFn.
  // So at this point, for
7352 // --------------------------------------------------------------
7353 // void user_code_that_offloads(...) {
7354 // omp target depend(..) map(from:a) map(to:b) private(i)
7355 // do i = 1, 10
7356 // a(i) = b(i) + n
7357 // }
7358 //
7359 // --------------------------------------------------------------
7360 //
7361 // we have
7362 //
7363 // --------------------------------------------------------------
7364 //
7365 // void user_code_that_offloads(...) {
7366 // %.offload_baseptrs = alloca [2 x ptr], align 8
7367 // %.offload_ptrs = alloca [2 x ptr], align 8
7368 // %.offload_mappers = alloca [2 x ptr], align 8
7369 // ;; target region has been outlined and now we need to
7370 // ;; offload to it via a target task.
7371 // }
7372 // void outlined_device_function(ptr a, ptr b, ptr n) {
7373 // n = *n_ptr;
7374 // do i = 1, 10
7375 // a(i) = b(i) + n
7376 // }
7377 //
7378 // We have to now do the following
7379 // (i) Make an offloading call to outlined_device_function using the OpenMP
7380 // RTL. See 'kernel_launch_function' in the pseudo code below. This is
7381 // emitted by emitKernelLaunch
7382 // (ii) Create a task entry point function that calls kernel_launch_function
7383 // and is the entry point for the target task. See
7384 // '@.omp_target_task_proxy_func in the pseudocode below.
7385 // (iii) Create a task with the task entry point created in (ii)
7386 //
7387 // That is we create the following
7388 // struct task_with_privates {
7389 // struct kmp_task_ompbuilder_t task_struct;
7390 // struct privates {
7391 // [2 x ptr] ; baseptrs
7392 // [2 x ptr] ; ptrs
7393 // [2 x i64] ; sizes
7394 // }
7395 // }
7396 // void user_code_that_offloads(...) {
7397 // %.offload_baseptrs = alloca [2 x ptr], align 8
7398 // %.offload_ptrs = alloca [2 x ptr], align 8
7399 // %.offload_sizes = alloca [2 x i64], align 8
7400 //
7401 // %structArg = alloca { ptr, ptr, ptr }, align 8
7402 // %strucArg[0] = a
7403 // %strucArg[1] = b
7404 // %strucArg[2] = &n
7405 //
7406 // target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
7407 // sizeof(kmp_task_ompbuilder_t),
7408 // sizeof(structArg),
7409 // @.omp_target_task_proxy_func,
7410 // ...)
7411 // memcpy(target_task_with_privates->task_struct->shareds, %structArg,
7412 // sizeof(structArg))
7413 // memcpy(target_task_with_privates->privates->baseptrs,
7414 // offload_baseptrs, sizeof(offload_baseptrs)
7415 // memcpy(target_task_with_privates->privates->ptrs,
7416 // offload_ptrs, sizeof(offload_ptrs)
7417 // memcpy(target_task_with_privates->privates->sizes,
7418 // offload_sizes, sizeof(offload_sizes)
7419 // dependencies_array = ...
7420 // ;; if nowait not present
7421 // call @__kmpc_omp_wait_deps(..., dependencies_array)
7422 // call @__kmpc_omp_task_begin_if0(...)
7423 // call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
7424 // %target_task_with_privates)
7425 // call @__kmpc_omp_task_complete_if0(...)
7426 // }
7427 //
7428 // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
7429 // ptr %task) {
7430 // %structArg = alloca {ptr, ptr, ptr}
7431 // %task_ptr = getelementptr(%task, 0, 0)
7432 // %shared_data = load (getelementptr %task_ptr, 0, 0)
  //    memcpy(%structArg, %shared_data, sizeof(%structArg))
7434 //
7435 // %offloading_arrays = getelementptr(%task, 0, 1)
7436 // %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
7437 // %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
7438 // %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
7439 // kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
7440 // %offload_sizes, %structArg)
7441 // }
7442 //
7443 // We need the proxy function because the signature of the task entry point
7444 // expected by kmpc_omp_task is always the same and will be different from
7445 // that of the kernel_launch function.
7446 //
7447 // kernel_launch_function is generated by emitKernelLaunch and has the
7448 // always_inline attribute. For this example, it'll look like so:
7449 // void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
7450 // %offload_sizes, %structArg) alwaysinline {
7451 // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
7452 // ; load aggregated data from %structArg
7453 // ; setup kernel_args using offload_baseptrs, offload_ptrs and
7454 // ; offload_sizes
7455 // call i32 @__tgt_target_kernel(...,
7456 // outlined_device_function,
7457 // ptr %kernel_args)
7458 // }
7459 // void outlined_device_function(ptr a, ptr b, ptr n) {
7460 // n = *n_ptr;
7461 // do i = 1, 10
7462 // a(i) = b(i) + n
7463 // }
7464 //
7465 BasicBlock *TargetTaskBodyBB =
7466 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
7467 BasicBlock *TargetTaskAllocaBB =
7468 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");
7469
7470 InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
7471 TargetTaskAllocaBB->begin());
7472 InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
7473
7474 OutlineInfo OI;
7475 OI.EntryBB = TargetTaskAllocaBB;
7476 OI.OuterAllocaBB = AllocaIP.getBlock();
7477
7478 // Add the thread ID argument.
7479 SmallVector<Instruction *, 4> ToBeDeleted;
7480 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
7481 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));
7482
7483 // Generate the task body which will subsequently be outlined.
7484 Builder.restoreIP(IP: TargetTaskBodyIP);
7485 if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
7486 return Err;
7487
  // The outliner (CodeExtractor) extracts a sequence or vector of blocks that
  // it is given. These blocks are enumerated by
  // OpenMPIRBuilder::OutlineInfo::collectBlocks, which expects OI.ExitBlock
  // to be outside the region. In other words, OI.ExitBlock is expected to be
  // the start of the region after the outlining. We used to set OI.ExitBlock
  // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
  // except when the task body is a single basic block. In that case,
  // OI.ExitBlock is set to the single task body block and will get left out of
  // the outlining process. So, simply create a new empty block to which we
  // unconditionally branch from where TaskBodyCB left off.
7498 OI.ExitBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "target.task.cont");
7499 emitBlock(BB: OI.ExitBB, CurFn: Builder.GetInsertBlock()->getParent(),
7500 /*IsFinished=*/true);
7501
7502 SmallVector<Value *, 2> OffloadingArraysToPrivatize;
7503 bool NeedsTargetTask = HasNoWait && DeviceID;
7504 if (NeedsTargetTask) {
7505 for (auto *V :
7506 {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
7507 RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
7508 RTArgs.SizesArray}) {
7509 if (V && !isa<ConstantPointerNull, GlobalVariable>(Val: V)) {
7510 OffloadingArraysToPrivatize.push_back(Elt: V);
7511 OI.ExcludeArgsFromAggregate.push_back(Elt: V);
7512 }
7513 }
7514 }
7515 OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
7516 DeviceID, OffloadingArraysToPrivatize](
7517 Function &OutlinedFn) mutable {
7518 assert(OutlinedFn.hasOneUse() &&
7519 "there must be a single user for the outlined function");
7520
7521 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
7522
    // The first argument of StaleCI is always the thread id.
    // The next few arguments are the pointers to offloading arrays, if any
    // (see OffloadingArraysToPrivatize).
    // Finally, all other local values that are live-in into the outlined
    // region end up in a structure whose pointer is passed as the last
    // argument. This piece of data is passed in the "shared" field of the task
    // structure. So, we know we have to pass shareds to the task if the number
    // of arguments is greater than OffloadingArraysToPrivatize.size() + 1,
    // where the 1 accounts for the thread id. Further, for safety, we assert
    // that when shareds are present the number of arguments of StaleCI is
    // exactly OffloadingArraysToPrivatize.size() + 2.
7533 const unsigned int NumStaleCIArgs = StaleCI->arg_size();
7534 bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
7535 assert((!HasShareds ||
7536 NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
7537 "Wrong number of arguments for StaleCI when shareds are present");
7538 int SharedArgOperandNo =
7539 HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;
7540
7541 StructType *TaskWithPrivatesTy =
7542 createTaskWithPrivatesTy(OMPIRBuilder&: *this, OffloadingArraysToPrivatize);
7543 StructType *PrivatesTy = nullptr;
7544
7545 if (!OffloadingArraysToPrivatize.empty())
7546 PrivatesTy =
7547 static_cast<StructType *>(TaskWithPrivatesTy->getElementType(N: 1));
7548
7549 Function *ProxyFn = emitTargetTaskProxyFunction(
7550 OMPBuilder&: *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
7551 NumOffloadingArrays: OffloadingArraysToPrivatize.size(), SharedArgsOperandNo: SharedArgOperandNo);
7552
7553 LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
7554 << "\n");
7555
7556 Builder.SetInsertPoint(StaleCI);
7557
7558 // Gather the arguments for emitting the runtime call.
7559 uint32_t SrcLocStrSize;
7560 Constant *SrcLocStr =
7561 getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
7562 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7563
    // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
    //
    // If a target task is needed (`NeedsTargetTask == true`), we call
    // @__kmpc_omp_target_task_alloc to provide the DeviceID to the deferred
    // task, and because @__kmpc_omp_target_task_alloc creates an untied/async
    // task.
7569 Function *TaskAllocFn =
7570 !NeedsTargetTask
7571 ? getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc)
7572 : getOrCreateRuntimeFunctionPtr(
7573 FnID: OMPRTL___kmpc_omp_target_task_alloc);
7574
7575 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
7576 // call.
7577 Value *ThreadID = getOrCreateThreadID(Ident);
7578
    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // TaskSize refers to the size in bytes of the kmp_task_t data structure
    // plus any other data to be passed to the target task, which is packed
    // into a struct. kmp_task_t and the struct so created are packed into a
    // wrapper struct whose type is TaskWithPrivatesTy.
7584 Value *TaskSize = Builder.getInt64(
7585 C: M.getDataLayout().getTypeStoreSize(Ty: TaskWithPrivatesTy));
7586
7587 // Argument - `sizeof_shareds` (SharedsSize)
7588 // SharedsSize refers to the shareds array size in the kmp_task_t data
7589 // structure.
7590 Value *SharedsSize = Builder.getInt64(C: 0);
7591 if (HasShareds) {
7592 auto *ArgStructAlloca =
7593 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgOperandNo));
7594 assert(ArgStructAlloca &&
7595 "Unable to find the alloca instruction corresponding to arguments "
7596 "for extracted function");
7597 auto *ArgStructType =
7598 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
7599 assert(ArgStructType && "Unable to find struct type corresponding to "
7600 "arguments for extracted function");
7601 SharedsSize =
7602 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
7603 }
7604
7605 // Argument - `flags`
7606 // Task is tied iff (Flags & 1) == 1.
7607 // Task is untied iff (Flags & 1) == 0.
7608 // Task is final iff (Flags & 2) == 2.
7609 // Task is not final iff (Flags & 2) == 0.
7610 // A target task is not final and is untied.
7611 Value *Flags = Builder.getInt32(C: 0);
7612
    // Emit the @__kmpc_omp_task_alloc runtime call.
    // The runtime call returns a pointer to an area where the task-captured
    // variables must be copied before the task is run (TaskData).
7616 CallInst *TaskData = nullptr;
7617
7618 SmallVector<llvm::Value *> TaskAllocArgs = {
7619 /*loc_ref=*/Ident, /*gtid=*/ThreadID,
7620 /*flags=*/Flags,
7621 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
7622 /*task_func=*/ProxyFn};
7623
7624 if (NeedsTargetTask) {
7625 assert(DeviceID && "Expected non-empty device ID.");
7626 TaskAllocArgs.push_back(Elt: DeviceID);
7627 }
7628
7629 TaskData = Builder.CreateCall(Callee: TaskAllocFn, Args: TaskAllocArgs);
7630
7631 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
7632 if (HasShareds) {
7633 Value *Shareds = StaleCI->getArgOperand(i: SharedArgOperandNo);
7634 Value *TaskShareds = loadSharedDataFromTaskDescriptor(
7635 OMPIRBuilder&: *this, Builder, TaskWithPrivates: TaskData, TaskWithPrivatesTy);
7636 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
7637 Size: SharedsSize);
7638 }
7639 if (!OffloadingArraysToPrivatize.empty()) {
7640 Value *Privates =
7641 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskData, Idx: 1);
7642 for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
7643 Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
7644 [[maybe_unused]] Type *ArrayType =
7645 getOffloadingArrayType(V: PtrToPrivatize);
7646 assert(ArrayType && "ArrayType cannot be nullptr");
7647
7648 Type *ElementType = PrivatesTy->getElementType(N: i);
7649 assert(ElementType == ArrayType &&
7650 "ElementType should match ArrayType");
7651 (void)ArrayType;
7652
7653 Value *Dst = Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i);
7654 Builder.CreateMemCpy(
7655 Dst, DstAlign: Alignment, Src: PtrToPrivatize, SrcAlign: Alignment,
7656 Size: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ElementType)));
7657 }
7658 }
7659
7660 Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);
7661
7662 // ---------------------------------------------------------------
7663 // V5.2 13.8 target construct
7664 // If the nowait clause is present, execution of the target task
7665 // may be deferred. If the nowait clause is not present, the target task is
7666 // an included task.
7667 // ---------------------------------------------------------------
7668 // The above means that the lack of a nowait on the target construct
7669 // translates to '#pragma omp task if(0)'
7670 if (!NeedsTargetTask) {
7671 if (DepArray) {
7672 Function *TaskWaitFn =
7673 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
7674 Builder.CreateCall(
7675 Callee: TaskWaitFn,
7676 Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
7677 /*ndeps=*/Builder.getInt32(C: Dependencies.size()),
7678 /*dep_list=*/DepArray,
7679 /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
7680 /*noalias_dep_list=*/
7681 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
7682 }
7683 // Included task.
7684 Function *TaskBeginFn =
7685 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
7686 Function *TaskCompleteFn =
7687 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
7688 Builder.CreateCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
7689 CallInst *CI = Builder.CreateCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
7690 CI->setDebugLoc(StaleCI->getDebugLoc());
7691 Builder.CreateCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
7692 } else if (DepArray) {
      // The task may be deferred (nowait). Since there are dependencies, call
      // __kmpc_omp_task_with_deps; without dependencies, __kmpc_omp_task is
      // called below instead.
7696 Function *TaskFn =
7697 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
7698 Builder.CreateCall(
7699 Callee: TaskFn,
7700 Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
7701 DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
7702 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
7703 } else {
7704 // Emit the @__kmpc_omp_task runtime call to spawn the task
7705 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
7706 Builder.CreateCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
7707 }
7708
7709 StaleCI->eraseFromParent();
7710 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
7711 I->eraseFromParent();
7712 };
7713 addOutlineInfo(OI: std::move(OI));
7714
7715 LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
7716 << *(Builder.GetInsertBlock()) << "\n");
7717 LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
7718 << *(Builder.GetInsertBlock()->getParent()->getParent())
7719 << "\n");
7720 return Builder.saveIP();
7721}
7722
7723Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7724 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
7725 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7726 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
7727 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7728 if (Error Err =
7729 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
7730 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
7731 return Err;
7732 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
7733 return Error::success();
7734}
7735
7736static void emitTargetCall(
7737 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7738 OpenMPIRBuilder::InsertPointTy AllocaIP,
7739 OpenMPIRBuilder::TargetDataInfo &Info,
7740 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7741 const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7742 Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
7743 SmallVectorImpl<Value *> &Args,
7744 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7745 OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
7746 const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
7747 bool HasNoWait) {
  // Generate a function call to the host fallback implementation of the target
  // region. This is called by the host when no offload entry was generated for
  // the target region, or when the offloading call fails at runtime.
7751 auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
7752 -> OpenMPIRBuilder::InsertPointOrErrorTy {
7753 Builder.restoreIP(IP);
7754 Builder.CreateCall(Callee: OutlinedFn, Args);
7755 return Builder.saveIP();
7756 };
7757
7758 bool HasDependencies = Dependencies.size() > 0;
7759 bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7760
7761 OpenMPIRBuilder::TargetKernelArgs KArgs;
7762
7763 auto TaskBodyCB =
7764 [&](Value *DeviceID, Value *RTLoc,
7765 IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
7766 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
7767 // produce any.
7768 llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
7769 // emitKernelLaunch makes the necessary runtime call to offload the
7770 // kernel. We then outline all that code into a separate function
7771 // ('kernel_launch_function' in the pseudo code above). This function is
// then called by the target task proxy function (see
// '@.omp_target_task_proxy_func' in the pseudo code above).
// '@.omp_target_task_proxy_func' is generated by
// emitTargetTaskProxyFunction.
7776 if (OutlinedFnID && DeviceID)
7777 return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
7778 EmitTargetCallFallbackCB, Args&: KArgs,
7779 DeviceID, RTLoc, AllocaIP: TargetTaskAllocaIP);
7780
7781 // We only need to do the outlining if `DeviceID` is set to avoid calling
7782 // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
7783 // generating the `else` branch of an `if` clause.
7784 //
// When OutlinedFnID is nullptr, this is not an offloading call. In that
// case, we execute the host implementation directly.
7787 return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
7788 }());
7789
7790 OMPBuilder.Builder.restoreIP(IP: AfterIP);
7791 return Error::success();
7792 };
7793
7794 auto &&EmitTargetCallElse =
7795 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7796 OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7797 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
7798 // produce any.
7799 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
7800 if (RequiresOuterTargetTask) {
7801 // Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are passed as nullptr, since
7803 // OutlinedFnID=nullptr results in that call not being done.
7804 OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
7805 return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
7806 /*RTLoc=*/nullptr, AllocaIP,
7807 Dependencies, RTArgs: EmptyRTArgs, HasNoWait);
7808 }
7809 return EmitTargetCallFallbackCB(Builder.saveIP());
7810 }());
7811
7812 Builder.restoreIP(IP: AfterIP);
7813 return Error::success();
7814 };
7815
7816 auto &&EmitTargetCallThen =
7817 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7818 OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7819 Info.HasNoWait = HasNoWait;
7820 OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7821 OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7822 if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
7823 AllocaIP, CodeGenIP: Builder.saveIP(), Info, RTArgs, CombinedInfo&: MapInfo, CustomMapperCB,
7824 /*IsNonContiguous=*/true,
7825 /*ForEndCall=*/false))
7826 return Err;
7827
7828 SmallVector<Value *, 3> NumTeamsC;
7829 for (auto [DefaultVal, RuntimeVal] :
7830 zip_equal(t: DefaultAttrs.MaxTeams, u: RuntimeAttrs.MaxTeams))
7831 NumTeamsC.push_back(Elt: RuntimeVal ? RuntimeVal
7832 : Builder.getInt32(C: DefaultVal));
7833
// Calculate the number of threads: 0 if no clauses are specified, otherwise
// the minimum of the optional THREAD_LIMIT and NUM_THREADS clauses.
7836 auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7837 if (Clause)
7838 Clause = Builder.CreateIntCast(V: Clause, DestTy: Builder.getInt32Ty(),
7839 /*isSigned=*/false);
7840 return Clause;
7841 };
7842 auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7843 if (Clause)
7844 Result =
7845 Result ? Builder.CreateSelect(C: Builder.CreateICmpULT(LHS: Result, RHS: Clause),
7846 True: Result, False: Clause)
7847 : Clause;
7848 };
7849
7850 // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
// the NUM_THREADS clause is overridden by THREAD_LIMIT.
7852 SmallVector<Value *, 3> NumThreadsC;
7853 Value *MaxThreadsClause =
7854 RuntimeAttrs.TeamsThreadLimit.size() == 1
7855 ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7856 : nullptr;
7857
7858 for (auto [TeamsVal, TargetVal] : zip_equal(
7859 t: RuntimeAttrs.TeamsThreadLimit, u: RuntimeAttrs.TargetThreadLimit)) {
7860 Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7861 Value *NumThreads = InitMaxThreadsClause(TargetVal);
7862
7863 CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7864 CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7865
7866 NumThreadsC.push_back(Elt: NumThreads ? NumThreads : Builder.getInt32(C: 0));
7867 }
7868
7869 unsigned NumTargetItems = Info.NumberOfPtrs;
7870 // TODO: Use correct device ID
7871 Value *DeviceID = Builder.getInt64(C: OMP_DEVICEID_UNDEF);
7872 uint32_t SrcLocStrSize;
7873 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7874 Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7875 LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);
7876
7877 Value *TripCount = RuntimeAttrs.LoopTripCount
7878 ? Builder.CreateIntCast(V: RuntimeAttrs.LoopTripCount,
7879 DestTy: Builder.getInt64Ty(),
7880 /*isSigned=*/false)
7881 : Builder.getInt64(C: 0);
7882
7883 // TODO: Use correct DynCGGroupMem
7884 Value *DynCGGroupMem = Builder.getInt32(C: 0);
7885
7886 KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7887 NumTeamsC, NumThreadsC,
7888 DynCGGroupMem, HasNoWait);
7889
7890 // Assume no error was returned because TaskBodyCB and
7891 // EmitTargetCallFallbackCB don't produce any.
7892 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
// The presence of certain clauses on the target directive requires the
// explicit generation of the target task.
7895 if (RequiresOuterTargetTask)
7896 return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7897 Dependencies, RTArgs: KArgs.RTArgs,
7898 HasNoWait: Info.HasNoWait);
7899
7900 return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
7901 EmitTargetCallFallbackCB, Args&: KArgs,
7902 DeviceID, RTLoc, AllocaIP);
7903 }());
7904
7905 Builder.restoreIP(IP: AfterIP);
7906 return Error::success();
7907 };
7908
7909 // If we don't have an ID for the target region, it means an offload entry
7910 // wasn't created. In this case we just run the host fallback directly and
7911 // ignore any potential 'if' clauses.
7912 if (!OutlinedFnID) {
7913 cantFail(Err: EmitTargetCallElse(AllocaIP, Builder.saveIP()));
7914 return;
7915 }
7916
7917 // If there's no 'if' clause, only generate the kernel launch code path.
7918 if (!IfCond) {
7919 cantFail(Err: EmitTargetCallThen(AllocaIP, Builder.saveIP()));
7920 return;
7921 }
7922
7923 cantFail(Err: OMPBuilder.emitIfClause(Cond: IfCond, ThenGen: EmitTargetCallThen,
7924 ElseGen: EmitTargetCallElse, AllocaIP));
7925}
7926
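// Entry point for lowering an 'omp target' construct: outlines the target
// region into its own function and, when compiling for the host, emits the
// code that offloads to (or falls back from) that function.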
7927OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7928 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7929 InsertPointTy CodeGenIP, TargetDataInfo &Info,
7930 TargetRegionEntryInfo &EntryInfo,
7931 const TargetKernelDefaultAttrs &DefaultAttrs,
7932 const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7933 SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
7934 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7935 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7936 CustomMapperCallbackTy CustomMapperCB,
7937 const SmallVector<DependData> &Dependencies, bool HasNowait) {
7938
7939 if (!updateToLocation(Loc))
7940 return InsertPointTy();
7941
7942 Builder.restoreIP(IP: CodeGenIP);
7943
7944 Function *OutlinedFn;
7945 Constant *OutlinedFnID = nullptr;
7946 // The target region is outlined into its own function. The LLVM IR for
7947 // the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB.
7949 if (Error Err = emitTargetOutlinedFunction(
7950 OMPBuilder&: *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7951 OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
7952 return Err;
7953
7954 // If we are not on the target device, then we need to generate code
7955 // to make a remote call (offload) to the previously outlined function
7956 // that represents the target region. Do that now.
7957 if (!Config.isTargetDevice())
7958 emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
7959 IfCond, OutlinedFn, OutlinedFnID, Args&: Inputs, GenMapInfoCB,
7960 CustomMapperCB, Dependencies, HasNoWait: HasNowait);
7961 return Builder.saveIP();
7962}
7963
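// Joins the parts using FirstSeparator before the first part and Separator
// between the remaining ones; for example,
// getNameWithSeparators({"foo", "var"}, ".", ".") yields ".foo.var".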
7964std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
7965 StringRef FirstSeparator,
7966 StringRef Separator) {
7967 SmallString<128> Buffer;
7968 llvm::raw_svector_ostream OS(Buffer);
7969 StringRef Sep = FirstSeparator;
7970 for (StringRef Part : Parts) {
7971 OS << Sep << Part;
7972 Sep = Separator;
7973 }
7974 return OS.str().str();
7975}
7976
7977std::string
7978OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
7979 return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
7980 Separator: Config.separator());
7981}
7982
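// Returns the named internal OpenMP variable, creating it on first use; a
// cached entry must have been requested with the same type.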
7983GlobalVariable *
7984OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
7985 unsigned AddressSpace) {
7986 auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
7987 if (Elem.second) {
7988 assert(Elem.second->getValueType() == Ty &&
7989 "OMP internal variable has different type than requested");
7990 } else {
7991 // TODO: investigate the appropriate linkage type used for the global
7992 // variable for possibly changing that to internal or private, or maybe
7993 // create different versions of the function for different OMP internal
7994 // variables.
7995 auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
7996 ? GlobalValue::InternalLinkage
7997 : GlobalValue::CommonLinkage;
7998 auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
7999 Constant::getNullValue(Ty), Elem.first(),
8000 /*InsertBefore=*/nullptr,
8001 GlobalValue::NotThreadLocal, AddressSpace);
8002 const DataLayout &DL = M.getDataLayout();
8003 const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
8004 const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpace);
8005 GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
8006 Elem.second = GV;
8007 }
8008
8009 return Elem.second;
8010}
8011
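// Returns the lock variable backing a named 'critical' region; for example,
// the region name "foo" maps to the internal variable
// ".gomp_critical_user_foo.var".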
8012Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
8013 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
8014 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
8015 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
8016}
8017
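// Computes a size in bytes using the classic sizeof idiom: GEP to element 1
// from a null pointer of the given type, then ptrtoint the result.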
8018Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
8019 LLVMContext &Ctx = Builder.getContext();
8020 Value *Null =
8021 Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
8022 Value *SizeGep =
8023 Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
8024 Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
8025 return SizePtrToInt;
8026}
8027
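// Materializes the constant map-type flags as a private unnamed_addr global
// array that the offloading runtime can read directly.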
8028GlobalVariable *
8029OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
8030 std::string VarName) {
8031 llvm::Constant *MaptypesArrayInit =
8032 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
8033 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
8034 M, MaptypesArrayInit->getType(),
8035 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
8036 VarName);
8037 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
8038 return MaptypesArrayGlobal;
8039}
8040
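// Creates the stack arrays (.offload_baseptrs, .offload_ptrs,
// .offload_sizes) through which mapped variables are handed to the runtime.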
8041void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
8042 InsertPointTy AllocaIP,
8043 unsigned NumOperands,
8044 struct MapperAllocas &MapperAllocas) {
8045 if (!updateToLocation(Loc))
8046 return;
8047
8048 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
8049 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
8050 Builder.restoreIP(IP: AllocaIP);
8051 AllocaInst *ArgsBase = Builder.CreateAlloca(
8052 Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
8053 AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
8054 Name: ".offload_ptrs");
8055 AllocaInst *ArgSizes = Builder.CreateAlloca(
8056 Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
8057 Builder.restoreIP(IP: Loc.IP);
8058 MapperAllocas.ArgsBase = ArgsBase;
8059 MapperAllocas.Args = Args;
8060 MapperAllocas.ArgSizes = ArgSizes;
8061}
8062
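// Emits a call to the given mapper runtime function, passing GEPs to the
// first element of each previously created offload array.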
8063void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
8064 Function *MapperFunc, Value *SrcLocInfo,
8065 Value *MaptypesArg, Value *MapnamesArg,
8066 struct MapperAllocas &MapperAllocas,
8067 int64_t DeviceID, unsigned NumOperands) {
8068 if (!updateToLocation(Loc))
8069 return;
8070
8071 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
8072 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
8073 Value *ArgsBaseGEP =
8074 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
8075 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
8076 Value *ArgsGEP =
8077 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
8078 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
8079 Value *ArgSizesGEP =
8080 Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
8081 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
8082 Value *NullPtr =
8083 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
8084 Builder.CreateCall(Callee: MapperFunc,
8085 Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
8086 Builder.getInt32(C: NumOperands), ArgsBaseGEP, ArgsGEP,
8087 ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
8088}
8089
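// Fills \p RTArgs with pointers to the offloading arrays recorded in
// \p Info, or with null pointers when nothing is mapped.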
8090void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
8091 TargetDataRTArgs &RTArgs,
8092 TargetDataInfo &Info,
8093 bool ForEndCall) {
8094 assert((!ForEndCall || Info.separateBeginEndCalls()) &&
8095 "expected region end call to runtime only when end call is separate");
8096 auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
8097 auto VoidPtrTy = UnqualPtrTy;
8098 auto VoidPtrPtrTy = UnqualPtrTy;
8099 auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
8100 auto Int64PtrTy = UnqualPtrTy;
8101
8102 if (!Info.NumberOfPtrs) {
8103 RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
8104 RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
8105 RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
8106 RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
8107 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
8108 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
8109 return;
8110 }
8111
8112 RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
8113 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
8114 Ptr: Info.RTArgs.BasePointersArray,
8115 /*Idx0=*/0, /*Idx1=*/0);
8116 RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
8117 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
8118 /*Idx0=*/0,
8119 /*Idx1=*/0);
8120 RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
8121 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
8122 /*Idx0=*/0, /*Idx1=*/0);
8123 RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
8124 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
8125 Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
8126 : Info.RTArgs.MapTypesArray,
8127 /*Idx0=*/0,
8128 /*Idx1=*/0);
8129
8130 // Only emit the mapper information arrays if debug information is
8131 // requested.
8132 if (!Info.EmitDebug)
8133 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
8134 else
8135 RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
8136 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
8137 /*Idx0=*/0,
8138 /*Idx1=*/0);
8139 // If there is no user-defined mapper, set the mapper array to nullptr to
// avoid an unnecessary data privatization.
8141 if (!Info.HasMapper)
8142 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
8143 else
8144 RTArgs.MappersArray =
8145 Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
8146}
8147
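// Emits the per-dimension descriptor arrays (offset/count/stride) for
// non-contiguous maps and stores their addresses into the offload pointers
// array.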
8148void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
8149 InsertPointTy CodeGenIP,
8150 MapInfosTy &CombinedInfo,
8151 TargetDataInfo &Info) {
8152 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
8153 CombinedInfo.NonContigInfo;
8154
8155 // Build an array of struct descriptor_dim and then assign it to
8156 // offload_args.
8157 //
8158 // struct descriptor_dim {
8159 // uint64_t offset;
8160 // uint64_t count;
8161 // uint64_t stride
8162 // };
8163 Type *Int64Ty = Builder.getInt64Ty();
8164 StructType *DimTy = StructType::create(
8165 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
8166 Name: "struct.descriptor_dim");
8167
8168 enum { OffsetFD = 0, CountFD, StrideFD };
// We need two index variables here since the size of "Dims" is the same as
// the size of Components; however, the size of offset, count, and stride is
// equal to the number of base declarations that are non-contiguous.
8172 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
// Skip emitting IR if the dimension size is 1, since it cannot be
// non-contiguous.
8175 if (NonContigInfo.Dims[I] == 1)
8176 continue;
8177 Builder.restoreIP(IP: AllocaIP);
8178 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
8179 AllocaInst *DimsAddr =
8180 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
8181 Builder.restoreIP(IP: CodeGenIP);
8182 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
8183 unsigned RevIdx = EE - II - 1;
8184 Value *DimsLVal = Builder.CreateInBoundsGEP(
8185 Ty: DimsAddr->getAllocatedType(), Ptr: DimsAddr,
8186 IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
8187 // Offset
8188 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
8189 Builder.CreateAlignedStore(
8190 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
8191 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
8192 // Count
8193 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
8194 Builder.CreateAlignedStore(
8195 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
8196 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
8197 // Stride
8198 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
8199 Builder.CreateAlignedStore(
8200 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
Align: M.getDataLayout().getPrefTypeAlign(Ty: StrideLVal->getType()));
8202 }
8203 // args[I] = &dims
8204 Builder.restoreIP(IP: CodeGenIP);
8205 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
8206 V: DimsAddr, DestTy: Builder.getPtrTy());
8207 Value *P = Builder.CreateConstInBoundsGEP2_32(
8208 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
8209 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
8210 Builder.CreateAlignedStore(
8211 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
8212 ++L;
8213 }
8214}
8215
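// Emits the guarded runtime call that, inside a user-defined mapper,
// allocates (IsInit) or deletes the storage of an array section.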
8216void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
8217 Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
8218 Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
8219 BasicBlock *ExitBB, bool IsInit) {
8220 StringRef Prefix = IsInit ? ".init" : ".del";
8221
8222 // Evaluate if this is an array section.
8223 BasicBlock *BodyBB = BasicBlock::Create(
8224 Context&: M.getContext(), Name: createPlatformSpecificName(Parts: {"omp.array", Prefix}));
8225 Value *IsArray =
8226 Builder.CreateICmpSGT(LHS: Size, RHS: Builder.getInt64(C: 1), Name: "omp.arrayinit.isarray");
8227 Value *DeleteBit = Builder.CreateAnd(
8228 LHS: MapType,
8229 RHS: Builder.getInt64(
8230 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8231 OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
8232 Value *DeleteCond;
8233 Value *Cond;
8234 if (IsInit) {
8235 // base != begin?
8236 Value *BaseIsBegin = Builder.CreateICmpNE(LHS: Base, RHS: Begin);
8237 // IsPtrAndObj?
8238 Value *PtrAndObjBit = Builder.CreateAnd(
8239 LHS: MapType,
8240 RHS: Builder.getInt64(
8241 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8242 OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ)));
8243 PtrAndObjBit = Builder.CreateIsNotNull(Arg: PtrAndObjBit);
8244 BaseIsBegin = Builder.CreateAnd(LHS: BaseIsBegin, RHS: PtrAndObjBit);
8245 Cond = Builder.CreateOr(LHS: IsArray, RHS: BaseIsBegin);
8246 DeleteCond = Builder.CreateIsNull(
8247 Arg: DeleteBit,
8248 Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
8249 } else {
8250 Cond = IsArray;
8251 DeleteCond = Builder.CreateIsNotNull(
8252 Arg: DeleteBit,
8253 Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
8254 }
8255 Cond = Builder.CreateAnd(LHS: Cond, RHS: DeleteCond);
8256 Builder.CreateCondBr(Cond, True: BodyBB, False: ExitBB);
8257
8258 emitBlock(BB: BodyBB, CurFn: MapperFn);
// Get the array size by multiplying the element size and the element
// number (i.e., \p Size).
8261 Value *ArraySize = Builder.CreateNUWMul(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
8262 // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
8263 // memory allocation/deletion purpose only.
8264 Value *MapTypeArg = Builder.CreateAnd(
8265 LHS: MapType,
8266 RHS: Builder.getInt64(
8267 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8268 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8269 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8270 MapTypeArg = Builder.CreateOr(
8271 LHS: MapTypeArg,
8272 RHS: Builder.getInt64(
8273 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8274 OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));
8275
8276 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
8277 // data structure.
8278 Value *OffloadingArgs[] = {MapperHandle, Base, Begin,
8279 ArraySize, MapTypeArg, MapName};
8280 Builder.CreateCall(
8281 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
8282 Args: OffloadingArgs);
8283}
8284
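// Emits a user-defined mapper function with the signature the offloading
// runtime expects; as a sketch (parameter names below are illustrative
// only), the generated function looks like:
//
//   void <FuncName>(void *rt_mapper_handle, void *base, void *begin,
//                   int64_t size, int64_t map_type, void *map_name);
//
// The body converts the byte size into an element count and loops over the
// elements, pushing one runtime component per map clause.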
8285Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
8286 function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
8287 llvm::Value *BeginArg)>
8288 GenMapInfoCB,
8289 Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
8290 SmallVector<Type *> Params;
8291 Params.emplace_back(Args: Builder.getPtrTy());
8292 Params.emplace_back(Args: Builder.getPtrTy());
8293 Params.emplace_back(Args: Builder.getPtrTy());
8294 Params.emplace_back(Args: Builder.getInt64Ty());
8295 Params.emplace_back(Args: Builder.getInt64Ty());
8296 Params.emplace_back(Args: Builder.getPtrTy());
8297
8298 auto *FnTy =
8299 FunctionType::get(Result: Builder.getVoidTy(), Params, /* IsVarArg */ isVarArg: false);
8300
8301 SmallString<64> TyStr;
8302 raw_svector_ostream Out(TyStr);
8303 Function *MapperFn =
8304 Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
8305 MapperFn->addFnAttr(Kind: Attribute::NoInline);
8306 MapperFn->addFnAttr(Kind: Attribute::NoUnwind);
8307 MapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
8308 MapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
8309 MapperFn->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
8310 MapperFn->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
8311 MapperFn->addParamAttr(ArgNo: 4, Kind: Attribute::NoUndef);
8312 MapperFn->addParamAttr(ArgNo: 5, Kind: Attribute::NoUndef);
8313
8314 // Start the mapper function code generation.
8315 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: MapperFn);
8316 auto SavedIP = Builder.saveIP();
8317 Builder.SetInsertPoint(EntryBB);
8318
8319 Value *MapperHandle = MapperFn->getArg(i: 0);
8320 Value *BaseIn = MapperFn->getArg(i: 1);
8321 Value *BeginIn = MapperFn->getArg(i: 2);
8322 Value *Size = MapperFn->getArg(i: 3);
8323 Value *MapType = MapperFn->getArg(i: 4);
8324 Value *MapName = MapperFn->getArg(i: 5);
8325
// Compute the start and end addresses of the array elements.
// Prepare common arguments for array initialization and deletion.
8328 // Convert the size in bytes into the number of array elements.
8329 TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(Ty: ElemTy);
8330 Size = Builder.CreateExactUDiv(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
8331 Value *PtrBegin = BeginIn;
8332 Value *PtrEnd = Builder.CreateGEP(Ty: ElemTy, Ptr: PtrBegin, IdxList: Size);
8333
// Emit array initialization if this is an array section and \p MapType
// indicates that memory allocation is required.
8336 BasicBlock *HeadBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.head");
8337 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
8338 MapType, MapName, ElementSize, ExitBB: HeadBB,
8339 /*IsInit=*/true);
8340
// Emit a for loop to iterate over Size elements and map each of them.
8342
8343 // Emit the loop header block.
8344 emitBlock(BB: HeadBB, CurFn: MapperFn);
8345 BasicBlock *BodyBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.body");
8346 BasicBlock *DoneBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.done");
8347 // Evaluate whether the initial condition is satisfied.
8348 Value *IsEmpty =
8349 Builder.CreateICmpEQ(LHS: PtrBegin, RHS: PtrEnd, Name: "omp.arraymap.isempty");
8350 Builder.CreateCondBr(Cond: IsEmpty, True: DoneBB, False: BodyBB);
8351
8352 // Emit the loop body block.
8353 emitBlock(BB: BodyBB, CurFn: MapperFn);
8354 BasicBlock *LastBB = BodyBB;
8355 PHINode *PtrPHI =
8356 Builder.CreatePHI(Ty: PtrBegin->getType(), NumReservedValues: 2, Name: "omp.arraymap.ptrcurrent");
8357 PtrPHI->addIncoming(V: PtrBegin, BB: HeadBB);
8358
8359 // Get map clause information. Fill up the arrays with all mapped variables.
8360 MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8361 if (!Info)
8362 return Info.takeError();
8363
8364 // Call the runtime API __tgt_mapper_num_components to get the number of
8365 // pre-existing components.
8366 Value *OffloadingArgs[] = {MapperHandle};
8367 Value *PreviousSize = Builder.CreateCall(
8368 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_mapper_num_components),
8369 Args: OffloadingArgs);
8370 Value *ShiftedPreviousSize =
8371 Builder.CreateShl(LHS: PreviousSize, RHS: Builder.getInt64(C: getFlagMemberOffset()));
8372
8373 // Fill up the runtime mapper handle for all components.
8374 for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
8375 Value *CurBaseArg = Info->BasePointers[I];
8376 Value *CurBeginArg = Info->Pointers[I];
8377 Value *CurSizeArg = Info->Sizes[I];
8378 Value *CurNameArg = Info->Names.size()
8379 ? Info->Names[I]
8380 : Constant::getNullValue(Ty: Builder.getPtrTy());
8381
8382 // Extract the MEMBER_OF field from the map type.
8383 Value *OriMapType = Builder.getInt64(
8384 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8385 Info->Types[I]));
8386 Value *MemberMapType =
8387 Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);
8388
8389 // Combine the map type inherited from user-defined mapper with that
8390 // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
8391 // bits of the \a MapType, which is the input argument of the mapper
8392 // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
8393 // bits of MemberMapType.
8394 // [OpenMP 5.0], 1.2.6. map-type decay.
8395 // | alloc | to | from | tofrom | release | delete
8396 // ----------------------------------------------------------
8397 // alloc | alloc | alloc | alloc | alloc | release | delete
8398 // to | alloc | to | alloc | to | release | delete
8399 // from | alloc | alloc | from | from | release | delete
8400 // tofrom | alloc | to | from | tofrom | release | delete
8401 Value *LeftToFrom = Builder.CreateAnd(
8402 LHS: MapType,
8403 RHS: Builder.getInt64(
8404 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8405 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8406 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8407 BasicBlock *AllocBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc");
8408 BasicBlock *AllocElseBB =
8409 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc.else");
8410 BasicBlock *ToBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to");
8411 BasicBlock *ToElseBB =
8412 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to.else");
8413 BasicBlock *FromBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.from");
8414 BasicBlock *EndBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.end");
8415 Value *IsAlloc = Builder.CreateIsNull(Arg: LeftToFrom);
8416 Builder.CreateCondBr(Cond: IsAlloc, True: AllocBB, False: AllocElseBB);
8417 // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
8418 emitBlock(BB: AllocBB, CurFn: MapperFn);
8419 Value *AllocMapType = Builder.CreateAnd(
8420 LHS: MemberMapType,
8421 RHS: Builder.getInt64(
8422 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8423 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8424 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8425 Builder.CreateBr(Dest: EndBB);
8426 emitBlock(BB: AllocElseBB, CurFn: MapperFn);
8427 Value *IsTo = Builder.CreateICmpEQ(
8428 LHS: LeftToFrom,
8429 RHS: Builder.getInt64(
8430 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8431 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
8432 Builder.CreateCondBr(Cond: IsTo, True: ToBB, False: ToElseBB);
8433 // In case of to, clear OMP_MAP_FROM.
8434 emitBlock(BB: ToBB, CurFn: MapperFn);
8435 Value *ToMapType = Builder.CreateAnd(
8436 LHS: MemberMapType,
8437 RHS: Builder.getInt64(
8438 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8439 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8440 Builder.CreateBr(Dest: EndBB);
8441 emitBlock(BB: ToElseBB, CurFn: MapperFn);
8442 Value *IsFrom = Builder.CreateICmpEQ(
8443 LHS: LeftToFrom,
8444 RHS: Builder.getInt64(
8445 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8446 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8447 Builder.CreateCondBr(Cond: IsFrom, True: FromBB, False: EndBB);
8448 // In case of from, clear OMP_MAP_TO.
8449 emitBlock(BB: FromBB, CurFn: MapperFn);
8450 Value *FromMapType = Builder.CreateAnd(
8451 LHS: MemberMapType,
8452 RHS: Builder.getInt64(
8453 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8454 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
8455 // In case of tofrom, do nothing.
8456 emitBlock(BB: EndBB, CurFn: MapperFn);
8457 LastBB = EndBB;
8458 PHINode *CurMapType =
8459 Builder.CreatePHI(Ty: Builder.getInt64Ty(), NumReservedValues: 4, Name: "omp.maptype");
8460 CurMapType->addIncoming(V: AllocMapType, BB: AllocBB);
8461 CurMapType->addIncoming(V: ToMapType, BB: ToBB);
8462 CurMapType->addIncoming(V: FromMapType, BB: FromBB);
8463 CurMapType->addIncoming(V: MemberMapType, BB: ToElseBB);
8464
8465 Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
8466 CurSizeArg, CurMapType, CurNameArg};
8467
8468 auto ChildMapperFn = CustomMapperCB(I);
8469 if (!ChildMapperFn)
8470 return ChildMapperFn.takeError();
8471 if (*ChildMapperFn) {
8472 // Call the corresponding mapper function.
8473 Builder.CreateCall(Callee: *ChildMapperFn, Args: OffloadingArgs)->setDoesNotThrow();
8474 } else {
8475 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
8476 // data structure.
8477 Builder.CreateCall(
8478 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
8479 Args: OffloadingArgs);
8480 }
8481 }
8482
8483 // Update the pointer to point to the next element that needs to be mapped,
8484 // and check whether we have mapped all elements.
8485 Value *PtrNext = Builder.CreateConstGEP1_32(Ty: ElemTy, Ptr: PtrPHI, /*Idx0=*/1,
8486 Name: "omp.arraymap.next");
8487 PtrPHI->addIncoming(V: PtrNext, BB: LastBB);
8488 Value *IsDone = Builder.CreateICmpEQ(LHS: PtrNext, RHS: PtrEnd, Name: "omp.arraymap.isdone");
8489 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.exit");
8490 Builder.CreateCondBr(Cond: IsDone, True: ExitBB, False: BodyBB);
8491
8492 emitBlock(BB: ExitBB, CurFn: MapperFn);
8493 // Emit array deletion if this is an array section and \p MapType indicates
8494 // that deletion is required.
8495 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
8496 MapType, MapName, ElementSize, ExitBB: DoneBB,
8497 /*IsInit=*/false);
8498
8499 // Emit the function exit block.
8500 emitBlock(BB: DoneBB, CurFn: MapperFn, /*IsFinished=*/true);
8501
8502 Builder.CreateRetVoid();
8503 Builder.restoreIP(IP: SavedIP);
8504 return MapperFn;
8505}
8506
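// Populates the base-pointer, pointer, size, map-type, map-name, and mapper
// arrays describing every mapped variable. Sizes known at compile time are
// placed in a constant global; runtime-evaluated sizes are stored
// individually.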
8507Error OpenMPIRBuilder::emitOffloadingArrays(
8508 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
8509 TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
8510 bool IsNonContiguous,
8511 function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
8512
8513 // Reset the array information.
8514 Info.clearArrayInfo();
8515 Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
8516
8517 if (Info.NumberOfPtrs == 0)
8518 return Error::success();
8519
8520 Builder.restoreIP(IP: AllocaIP);
// Detect if we have any capture size requiring runtime evaluation of the
// size, so that a constant array can eventually be used.
8523 ArrayType *PointerArrayType =
8524 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);
8525
8526 Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
8527 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
8528
8529 Info.RTArgs.PointersArray = Builder.CreateAlloca(
8530 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
8531 AllocaInst *MappersArray = Builder.CreateAlloca(
8532 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
8533 Info.RTArgs.MappersArray = MappersArray;
8534
8535 // If we don't have any VLA types or other types that require runtime
8536 // evaluation, we can use a constant array for the map sizes, otherwise we
8537 // need to fill up the arrays as we do for the pointers.
8538 Type *Int64Ty = Builder.getInt64Ty();
8539 SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
8540 ConstantInt::get(Ty: Int64Ty, V: 0));
8541 SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
8542 for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
8543 if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
8544 if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
8545 if (IsNonContiguous &&
8546 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8547 CombinedInfo.Types[I] &
8548 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
8549 ConstSizes[I] =
8550 ConstantInt::get(Ty: Int64Ty, V: CombinedInfo.NonContigInfo.Dims[I]);
8551 else
8552 ConstSizes[I] = CI;
8553 continue;
8554 }
8555 }
8556 RuntimeSizes.set(I);
8557 }
8558
8559 if (RuntimeSizes.all()) {
8560 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
8561 Info.RTArgs.SizesArray = Builder.CreateAlloca(
8562 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
8563 Builder.restoreIP(IP: CodeGenIP);
8564 } else {
8565 auto *SizesArrayInit = ConstantArray::get(
8566 T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
8567 std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
8568 auto *SizesArrayGbl =
8569 new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
8570 GlobalValue::PrivateLinkage, SizesArrayInit, Name);
8571 SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
8572
8573 if (!RuntimeSizes.any()) {
8574 Info.RTArgs.SizesArray = SizesArrayGbl;
8575 } else {
8576 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
8577 Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
8578 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
8579 AllocaInst *Buffer = Builder.CreateAlloca(
8580 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
8581 Buffer->setAlignment(OffloadSizeAlign);
8582 Builder.restoreIP(IP: CodeGenIP);
8583 Builder.CreateMemCpy(
8584 Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
8585 Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
8586 Size: Builder.getIntN(
8587 N: IndexSize,
8588 C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));
8589
8590 Info.RTArgs.SizesArray = Buffer;
8591 }
8592 Builder.restoreIP(IP: CodeGenIP);
8593 }
8594
8595 // The map types are always constant so we don't need to generate code to
8596 // fill arrays. Instead, we create an array constant.
8597 SmallVector<uint64_t, 4> Mapping;
8598 for (auto mapFlag : CombinedInfo.Types)
8599 Mapping.push_back(
8600 Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8601 mapFlag));
8602 std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
8603 auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
8604 Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
8605
// The map name information is only built if names were provided.
8607 if (!CombinedInfo.Names.empty()) {
8608 auto *MapNamesArrayGbl = createOffloadMapnames(
8609 Names&: CombinedInfo.Names, VarName: createPlatformSpecificName(Parts: {"offload_mapnames"}));
8610 Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
8611 Info.EmitDebug = true;
8612 } else {
8613 Info.RTArgs.MapNamesArray =
8614 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
8615 Info.EmitDebug = false;
8616 }
8617
8618 // If there's a present map type modifier, it must not be applied to the end
8619 // of a region, so generate a separate map type array in that case.
8620 if (Info.separateBeginEndCalls()) {
8621 bool EndMapTypesDiffer = false;
8622 for (uint64_t &Type : Mapping) {
8623 if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8624 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
8625 Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8626 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
8627 EndMapTypesDiffer = true;
8628 }
8629 }
8630 if (EndMapTypesDiffer) {
8631 MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
8632 Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
8633 }
8634 }
8635
8636 PointerType *PtrTy = Builder.getPtrTy();
8637 for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
8638 Value *BPVal = CombinedInfo.BasePointers[I];
8639 Value *BP = Builder.CreateConstInBoundsGEP2_32(
8640 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
8641 Idx0: 0, Idx1: I);
8642 Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
8643 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
8644
8645 if (Info.requiresDevicePointerInfo()) {
8646 if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
8647 CodeGenIP = Builder.saveIP();
8648 Builder.restoreIP(IP: AllocaIP);
8649 Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
8650 Builder.restoreIP(IP: CodeGenIP);
8651 if (DeviceAddrCB)
8652 DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
8653 } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
8654 Info.DevicePtrInfoMap[BPVal] = {BP, BP};
8655 if (DeviceAddrCB)
8656 DeviceAddrCB(I, BP);
8657 }
8658 }
8659
8660 Value *PVal = CombinedInfo.Pointers[I];
8661 Value *P = Builder.CreateConstInBoundsGEP2_32(
8662 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
8663 Idx1: I);
// TODO: Check that the alignment is correct.
8665 Builder.CreateAlignedStore(Val: PVal, Ptr: P,
8666 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
8667
8668 if (RuntimeSizes.test(Idx: I)) {
8669 Value *S = Builder.CreateConstInBoundsGEP2_32(
8670 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
8671 /*Idx0=*/0,
8672 /*Idx1=*/I);
8673 Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
8674 DestTy: Int64Ty,
8675 /*isSigned=*/true),
8676 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
8677 }
8678 // Fill up the mapper array.
8679 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
8680 Value *MFunc = ConstantPointerNull::get(T: PtrTy);
8681
8682 auto CustomMFunc = CustomMapperCB(I);
8683 if (!CustomMFunc)
8684 return CustomMFunc.takeError();
8685 if (*CustomMFunc)
8686 MFunc = Builder.CreatePointerCast(V: *CustomMFunc, DestTy: PtrTy);
8687
8688 Value *MAddr = Builder.CreateInBoundsGEP(
8689 Ty: MappersArray->getAllocatedType(), Ptr: MappersArray,
8690 IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
8691 Builder.CreateAlignedStore(
8692 Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
8693 }
8694
8695 if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
8696 Info.NumberOfPtrs == 0)
8697 return Error::success();
8698 emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
8699 return Error::success();
8700}
8701
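// Emits a fall-through branch to \p Target unless the current block is
// already terminated (or there is no insertion point), then clears the
// insertion point.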
8702void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
8703 BasicBlock *CurBB = Builder.GetInsertBlock();
8704
8705 if (!CurBB || CurBB->getTerminator()) {
8706 // If there is no insert point or the previous block is already
8707 // terminated, don't touch it.
8708 } else {
8709 // Otherwise, create a fall-through branch.
8710 Builder.CreateBr(Dest: Target);
8711 }
8712
8713 Builder.ClearInsertionPoint();
8714}
8715
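// Makes the current block fall through to \p BB, inserts \p BB into
// \p CurFn (after the current block when possible), and moves the insertion
// point there; a finished, unused block is erased instead.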
8716void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
8717 bool IsFinished) {
8718 BasicBlock *CurBB = Builder.GetInsertBlock();
8719
8720 // Fall out of the current block (if necessary).
8721 emitBranch(Target: BB);
8722
8723 if (IsFinished && BB->use_empty()) {
8724 BB->eraseFromParent();
8725 return;
8726 }
8727
8728 // Place the block after the current block, if possible, or else at
8729 // the end of the function.
8730 if (CurBB && CurBB->getParent())
8731 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
8732 else
8733 CurFn->insert(Position: CurFn->end(), BB);
8734 Builder.SetInsertPoint(BB);
8735}
8736
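// Emits 'if (Cond) { ThenGen } else { ElseGen }', folding away the dead arm
// when the condition is a compile-time constant.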
8737Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
8738 BodyGenCallbackTy ElseGen,
8739 InsertPointTy AllocaIP) {
8740 // If the condition constant folds and can be elided, try to avoid emitting
8741 // the condition and the dead arm of the if/else.
8742 if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
8743 auto CondConstant = CI->getSExtValue();
8744 if (CondConstant)
8745 return ThenGen(AllocaIP, Builder.saveIP());
8746
8747 return ElseGen(AllocaIP, Builder.saveIP());
8748 }
8749
8750 Function *CurFn = Builder.GetInsertBlock()->getParent();
8751
8752 // Otherwise, the condition did not fold, or we couldn't elide it. Just
8753 // emit the conditional branch.
8754 BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
8755 BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
8756 BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
8757 Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
8758 // Emit the 'then' code.
8759 emitBlock(BB: ThenBlock, CurFn);
8760 if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
8761 return Err;
8762 emitBranch(Target: ContBlock);
8763 // Emit the 'else' code if present.
8764 // There is no need to emit line number for unconditional branch.
8765 emitBlock(BB: ElseBlock, CurFn);
8766 if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
8767 return Err;
8768 // There is no need to emit line number for unconditional branch.
8769 emitBranch(Target: ContBlock);
8770 // Emit the continuation block for code after the if.
8771 emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
8772 return Error::success();
8773}
8774
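// Decides from the atomic kind and the requested ordering whether an OpenMP
// flush must accompany the atomic operation, emits it if so, and returns
// whether a flush was emitted.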
8775bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
8776 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
8777 assert(!(AO == AtomicOrdering::NotAtomic ||
8778 AO == llvm::AtomicOrdering::Unordered) &&
8779 "Unexpected Atomic Ordering.");
8780
8781 bool Flush = false;
8782 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
8783
8784 switch (AK) {
8785 case Read:
8786 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
8787 AO == AtomicOrdering::SequentiallyConsistent) {
8788 FlushAO = AtomicOrdering::Acquire;
8789 Flush = true;
8790 }
8791 break;
8792 case Write:
8793 case Compare:
8794 case Update:
8795 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
8796 AO == AtomicOrdering::SequentiallyConsistent) {
8797 FlushAO = AtomicOrdering::Release;
8798 Flush = true;
8799 }
8800 break;
8801 case Capture:
8802 switch (AO) {
8803 case AtomicOrdering::Acquire:
8804 FlushAO = AtomicOrdering::Acquire;
8805 Flush = true;
8806 break;
8807 case AtomicOrdering::Release:
8808 FlushAO = AtomicOrdering::Release;
8809 Flush = true;
8810 break;
8811 case AtomicOrdering::AcquireRelease:
8812 case AtomicOrdering::SequentiallyConsistent:
8813 FlushAO = AtomicOrdering::AcquireRelease;
8814 Flush = true;
8815 break;
8816 default:
8817 // do nothing - leave silently.
8818 break;
8819 }
8820 }
8821
8822 if (Flush) {
// The flush runtime call does not take a memory ordering yet, so we only
// resolve which ordering would apply (FlushAO) and issue a plain flush.
// TODO: pass `FlushAO` after memory ordering support is added.
8827 (void)FlushAO;
8828 emitFlush(Loc);
8829 }
8830
// For AO == AtomicOrdering::Monotonic and all other combinations, nothing
// needs to be done.
8833 return Flush;
8834}
8835
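// Emits an atomic read of X into V: integers load directly, structs go
// through the __atomic_load libcall, and float/pointer values are loaded as
// same-sized integers and cast back.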
8836OpenMPIRBuilder::InsertPointTy
8837OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
8838 AtomicOpValue &X, AtomicOpValue &V,
8839 AtomicOrdering AO, InsertPointTy AllocaIP) {
8840 if (!updateToLocation(Loc))
8841 return Loc.IP;
8842
8843 assert(X.Var->getType()->isPointerTy() &&
8844 "OMP Atomic expects a pointer to target memory");
8845 Type *XElemTy = X.ElemTy;
8846 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8847 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
8848 "OMP atomic read expected a scalar type");
8849
8850 Value *XRead = nullptr;
8851
8852 if (XElemTy->isIntegerTy()) {
8853 LoadInst *XLD =
8854 Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
8855 XLD->setAtomic(Ordering: AO);
8856 XRead = cast<Value>(Val: XLD);
8857 } else if (XElemTy->isStructTy()) {
8858 // FIXME: Add checks to ensure __atomic_load is emitted iff the
8859 // target does not support `atomicrmw` of the size of the struct
8860 LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
8861 OldVal->setAtomic(Ordering: AO);
8862 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8863 unsigned LoadSize =
8864 LoadDL.getTypeStoreSize(Ty: OldVal->getPointerOperand()->getType());
8865 OpenMPIRBuilder::AtomicInfo atomicInfo(
8866 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8867 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8868 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
8869 XRead = AtomicLoadRes.first;
8870 OldVal->eraseFromParent();
8871 } else {
// We need to perform the atomic operation as an integer.
8873 IntegerType *IntCastTy =
8874 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
8875 LoadInst *XLoad =
8876 Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
8877 XLoad->setAtomic(Ordering: AO);
8878 if (XElemTy->isFloatingPointTy()) {
8879 XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
8880 } else {
8881 XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
8882 }
8883 }
8884 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
8885 Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
8886 return Builder.saveIP();
8887}
8888
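// Emits an atomic write of Expr into X, mirroring createAtomicRead: a direct
// store for integers, the __atomic_store libcall for structs, and a bitcast
// to integer otherwise.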
8889OpenMPIRBuilder::InsertPointTy
8890OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
8891 AtomicOpValue &X, Value *Expr,
8892 AtomicOrdering AO, InsertPointTy AllocaIP) {
8893 if (!updateToLocation(Loc))
8894 return Loc.IP;
8895
8896 assert(X.Var->getType()->isPointerTy() &&
8897 "OMP Atomic expects a pointer to target memory");
8898 Type *XElemTy = X.ElemTy;
8899 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8900 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
8901 "OMP atomic write expected a scalar type");
8902
8903 if (XElemTy->isIntegerTy()) {
8904 StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
8905 XSt->setAtomic(Ordering: AO);
8906 } else if (XElemTy->isStructTy()) {
8907 LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
8908 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8909 unsigned LoadSize =
8910 LoadDL.getTypeStoreSize(Ty: OldVal->getPointerOperand()->getType());
8911 OpenMPIRBuilder::AtomicInfo atomicInfo(
8912 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8913 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8914 atomicInfo.EmitAtomicStoreLibcall(AO, Source: Expr);
8915 OldVal->eraseFromParent();
8916 } else {
// We need to bitcast and perform the atomic operation as an integer.
8918 IntegerType *IntCastTy =
8919 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
8920 Value *ExprCast =
8921 Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
8922 StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
8923 XSt->setAtomic(Ordering: AO);
8924 }
8925
8926 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
8927 return Builder.saveIP();
8928}
8929
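// Emits an atomic update of X with Expr via UpdateOp (or a native atomicrmw
// where possible) and then emits any flush the ordering requires.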
8930OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
8931 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
8932 Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
8933 AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
8934 assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
8935 if (!updateToLocation(Loc))
8936 return Loc.IP;
8937
8938 LLVM_DEBUG({
8939 Type *XTy = X.Var->getType();
8940 assert(XTy->isPointerTy() &&
8941 "OMP Atomic expects a pointer to target memory");
8942 Type *XElemTy = X.ElemTy;
8943 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8944 XElemTy->isPointerTy()) &&
8945 "OMP atomic update expected a scalar type");
8946 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
8947 (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
8948 "OpenMP atomic does not support LT or GT operations");
8949 });
8950
8951 Expected<std::pair<Value *, Value *>> AtomicResult =
8952 emitAtomicUpdate(AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp,
8953 VolatileX: X.IsVolatile, IsXBinopExpr);
8954 if (!AtomicResult)
8955 return AtomicResult.takeError();
8956 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
8957 return Builder.saveIP();
8958}
8959
8960// FIXME: Duplicating AtomicExpand
8961Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
8962 AtomicRMWInst::BinOp RMWOp) {
8963 switch (RMWOp) {
8964 case AtomicRMWInst::Add:
8965 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
8966 case AtomicRMWInst::Sub:
8967 return Builder.CreateSub(LHS: Src1, RHS: Src2);
8968 case AtomicRMWInst::And:
8969 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
8970 case AtomicRMWInst::Nand:
8971 return Builder.CreateNeg(V: Builder.CreateAnd(LHS: Src1, RHS: Src2));
8972 case AtomicRMWInst::Or:
8973 return Builder.CreateOr(LHS: Src1, RHS: Src2);
8974 case AtomicRMWInst::Xor:
8975 return Builder.CreateXor(LHS: Src1, RHS: Src2);
8976 case AtomicRMWInst::Xchg:
8977 case AtomicRMWInst::FAdd:
8978 case AtomicRMWInst::FSub:
8979 case AtomicRMWInst::BAD_BINOP:
8980 case AtomicRMWInst::Max:
8981 case AtomicRMWInst::Min:
8982 case AtomicRMWInst::UMax:
8983 case AtomicRMWInst::UMin:
8984 case AtomicRMWInst::FMax:
8985 case AtomicRMWInst::FMin:
8986 case AtomicRMWInst::FMaximum:
8987 case AtomicRMWInst::FMinimum:
8988 case AtomicRMWInst::UIncWrap:
8989 case AtomicRMWInst::UDecWrap:
8990 case AtomicRMWInst::USubCond:
8991 case AtomicRMWInst::USubSat:
8992 llvm_unreachable("Unsupported atomic update operation");
8993 }
8994 llvm_unreachable("Unsupported atomic update operation");
8995}
8996
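// Lowers the update either to a single atomicrmw, when RMWOp maps to a
// native integer operation, or to a compare-exchange loop that re-applies
// UpdateOp until the exchange succeeds. In illustrative pseudo-IR, the loop
// emitted below is roughly:
//
//   %old = load atomic %x
//   cont:
//     %phi  = phi [ %old, %entry ], [ %prev, %cont ]
//     %upd  = <UpdateOp>(%phi)
//     %res  = cmpxchg %x, %phi, %upd
//     %prev = extractvalue %res, 0
//     %ok   = extractvalue %res, 1
//     br i1 %ok, label %exit, label %cont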
8997Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
8998 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
8999 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
9000 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
9001 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
9002 // or a complex datatype.
9003 bool emitRMWOp = false;
9004 switch (RMWOp) {
9005 case AtomicRMWInst::Add:
9006 case AtomicRMWInst::And:
9007 case AtomicRMWInst::Nand:
9008 case AtomicRMWInst::Or:
9009 case AtomicRMWInst::Xor:
9010 case AtomicRMWInst::Xchg:
9011 emitRMWOp = XElemTy;
9012 break;
9013 case AtomicRMWInst::Sub:
9014 emitRMWOp = (IsXBinopExpr && XElemTy);
9015 break;
9016 default:
9017 emitRMWOp = false;
9018 }
9019 emitRMWOp &= XElemTy->isIntegerTy();
9020
9021 std::pair<Value *, Value *> Res;
9022 if (emitRMWOp) {
9023 Res.first = Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
// Not needed except in the case of postfix captures. Generate it anyway for
// consistency with the else part; it will be removed by any DCE pass.
// AtomicRMWInst::Xchg does not have a corresponding non-atomic instruction.
9027 if (RMWOp == AtomicRMWInst::Xchg)
9028 Res.second = Res.first;
9029 else
9030 Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
9031 } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
9032 XElemTy->isStructTy()) {
9033 LoadInst *OldVal =
9034 Builder.CreateLoad(Ty: XElemTy, Ptr: X, Name: X->getName() + ".atomic.load");
9035 OldVal->setAtomic(Ordering: AO);
9036 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
9037 unsigned LoadSize =
9038 LoadDL.getTypeStoreSize(Ty: OldVal->getPointerOperand()->getType());
9039
9040 OpenMPIRBuilder::AtomicInfo atomicInfo(
9041 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
9042 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
9043 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
9044 BasicBlock *CurBB = Builder.GetInsertBlock();
9045 Instruction *CurBBTI = CurBB->getTerminator();
9046 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9047 BasicBlock *ExitBB =
9048 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
9049 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
9050 BBName: X->getName() + ".atomic.cont");
9051 ContBB->getTerminator()->eraseFromParent();
9052 Builder.restoreIP(IP: AllocaIP);
9053 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
9054 NewAtomicAddr->setName(X->getName() + "x.new.val");
9055 Builder.SetInsertPoint(ContBB);
9056 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
9057 PHI->addIncoming(V: AtomicLoadRes.first, BB: CurBB);
9058 Value *OldExprVal = PHI;
9059 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
9060 if (!CBResult)
9061 return CBResult.takeError();
9062 Value *Upd = *CBResult;
9063 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
9064 AtomicOrdering Failure =
9065 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
9066 auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
9067 ExpectedVal: AtomicLoadRes.second, DesiredVal: NewAtomicAddr, Success: AO, Failure);
9068 LoadInst *PHILoad = Builder.CreateLoad(Ty: XElemTy, Ptr: Result.first);
9069 PHI->addIncoming(V: PHILoad, BB: Builder.GetInsertBlock());
9070 Builder.CreateCondBr(Cond: Result.second, True: ExitBB, False: ContBB);
9071 OldVal->eraseFromParent();
9072 Res.first = OldExprVal;
9073 Res.second = Upd;
9074
9075 if (UnreachableInst *ExitTI =
9076 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
9077 CurBBTI->eraseFromParent();
9078 Builder.SetInsertPoint(ExitBB);
9079 } else {
9080 Builder.SetInsertPoint(ExitTI);
9081 }
9082 } else {
9083 IntegerType *IntCastTy =
9084 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
9085 LoadInst *OldVal =
9086 Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
9087 OldVal->setAtomic(Ordering: AO);
9088 // CurBB
9089 // | /---\
9090 // ContBB |
9091 // | \---/
9092 // ExitBB
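    // As an illustration (assuming a 32-bit 'float' element type), the loop
    // built below looks roughly like:
    //   %old  = load atomic i32, ptr %x
    //   br label %cont
    // cont:
    //   %phi  = phi i32 [ %old, %entry ], [ %prev, %cont ]
    //   ...user update callback applied to bitcast(%phi)...
    //   %pair = cmpxchg ptr %x, i32 %phi, i32 %desired
    //   %prev = extractvalue { i32, i1 } %pair, 0
    //   %ok   = extractvalue { i32, i1 } %pair, 1
    //   br i1 %ok, label %exit, label %cont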
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminator();
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
                                                X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    Builder.restoreIP(AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
    PHI->addIncoming(OldVal, CurBB);
    bool IsIntTy = XElemTy->isIntegerTy();
    Value *OldExprVal = PHI;
    if (!IsIntTy) {
      if (XElemTy->isFloatingPointTy()) {
        OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
                                           X->getName() + ".atomic.fltCast");
      } else {
        OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
                                            X->getName() + ".atomic.ptrCast");
      }
    }

    Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
    if (!CBResult)
      return CBResult.takeError();
    Value *Upd = *CBResult;
    Builder.CreateStore(Upd, NewAtomicAddr);
    LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
        X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
    Result->setVolatile(VolatileX);
    Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
    Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
    PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
    Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);

    Res.first = OldExprVal;
    Res.second = Upd;

    // Set the insertion point in the exit block.
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      Builder.SetInsertPoint(ExitTI);
    }
  }

  return Res;
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If UpdateExpr is false, i.e. 'x' is simply overwritten with some 'expr'
  // that is not based on 'x', then 'x' is atomically exchanged with 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
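  // For reference, the two capture forms this lowers are roughly (sketch, not
  // normative):
  //   { v = x; x = x op expr; }  // IsPostfixUpdate: capture the old value
  //   { x = x op expr; v = x; }  // prefix form: capture the updated value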
  Expected<std::pair<Value *, Value *>> AtomicResult =
      emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
                       X.IsVolatile, IsXBinopExpr);
  if (!AtomicResult)
    return AtomicResult.takeError();
  Value *CapturedVal =
      (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
  Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly) {

  AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
  return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
                             IsPostfixUpdate, IsFailOnly, Failure);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly, AtomicOrdering Failure) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // Compare capture.
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      IntegerType *IntCastTy =
          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
      Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
      Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
                                           AO, Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
    }

    if (V.Var) {
      Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // CurBB----
          //   |     |
          //   v     |
          // ContBB  |
          //   |     |
          //   v     |
          // ExitBB <-
          //
          // where ContBB only contains the store of the old value to 'v'.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminator();
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              CurBBTI, X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(OldValue, V.Var);
          Builder.CreateBr(ExitBB);

          if (UnreachableInst *ExitTI =
                  dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            Builder.SetInsertPoint(ExitTI);
          }
        } else {
          Value *CapturedValue =
              Builder.CreateSelect(SuccessOrFail, E, OldValue);
          Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
        }
      }
    }
    // The comparison result has to be stored.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      Value *SuccessFailureVal =
          Builder.CreateExtractValue(Result, /*Idxs=*/1);
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
                              : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
      Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // Reverse the comparison op, as the OpenMP forms differ from the LLVM
    // forms. Take max as an example:
    // OpenMP form:
    //   x = x > expr ? expr : x;
    // LLVM form:
    //   *ptr = *ptr > val ? *ptr : val;
    // We therefore transform the OpenMP form into the LLVM form, i.e.
    //   x = x <= expr ? x : expr;
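    // For example (illustrative): for a signed integer 'x' and the OpenMP
    // form `x = x > expr ? expr : x` (Op == MAX, IsXBinopExpr, X.IsSigned),
    // the selection below picks AtomicRMWInst::Min, i.e. it emits roughly
    //   %old = atomicrmw min ptr %x, i32 %expr <ordering>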
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        CapturedValue = OldValue;
      } else {
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
        CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
      }
      Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
    }
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
                             Value *NumTeamsUpper, Value *ThreadLimit,
                             Value *IfExpr) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();

  // The outer alloca basic block is the entry block of the current function.
  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB =
        splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
  }

  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %teams.exit
  //   teams.exit:
  //     ; instructions after teams
  // }
  //
  // def outlined_fn() {
  //   teams.alloca:
  //     br label %teams.body
  //   teams.body:
  //     ; instructions within teams body
  // }
  // ```
  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

  bool SubClausesPresent =
      (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
  // Push num_teams.
  if (!Config.isTargetDevice() && SubClausesPresent) {
    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
           "if lowerbound is non-null, then upperbound must also be non-null "
           "for bounds on num_teams");

    if (NumTeamsUpper == nullptr)
      NumTeamsUpper = Builder.getInt32(0);

    if (NumTeamsLower == nullptr)
      NumTeamsLower = NumTeamsUpper;

    if (IfExpr) {
      assert(IfExpr->getType()->isIntegerTy() &&
             "argument to if clause must be an integer value");

      // upper = ifexpr ? upper : 1
      if (IfExpr->getType() != Int1)
        IfExpr = Builder.CreateICmpNE(IfExpr,
                                      ConstantInt::get(IfExpr->getType(), 0));
      NumTeamsUpper = Builder.CreateSelect(
          IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

      // lower = ifexpr ? lower : 1
      NumTeamsLower = Builder.CreateSelect(
          IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
    }

    if (ThreadLimit == nullptr)
      ThreadLimit = Builder.getInt32(0);

    Value *ThreadNum = getOrCreateThreadID(Ident);
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
        {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
  }
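  // Illustrative sketch of the host-side call emitted above, e.g. for
  // `num_teams(4) thread_limit(16)` (constants made up for illustration):
  //   call void @__kmpc_push_num_teams_51(ptr @ident, i32 %tid,
  //                                       i32 4, i32 4, i32 16)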
  // Generate the body of teams.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
    return Err;

  OutlineInfo OI;
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;
  OI.OuterAllocaBB = &OuterAllocaBB;

  // Insert fake values for the global tid and bound tid.
  SmallVector<Instruction *, 8> ToBeDeleted;
  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", /*AsPtr=*/true));
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", /*AsPtr=*/true));

  auto HostPostOutlineCB = [this, Ident,
                            ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call instruction
    // for the runtime call with the outlined function.

    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
    ToBeDeleted.push_back(StaleCI);

    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
           "Outlined function must have two or three arguments only");

    bool HasShared = OutlinedFn.arg_size() == 3;

    OutlinedFn.getArg(0)->setName("global.tid.ptr");
    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
    if (HasShared)
      OutlinedFn.getArg(2)->setName("data");

    // Call to the runtime function for teams in the current function.
    assert(StaleCI && "Error while outlining - no CallInst user found for the "
                      "outlined function.");
    Builder.SetInsertPoint(StaleCI);
    SmallVector<Value *> Args = {
        Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
    if (HasShared)
      Args.push_back(StaleCI->getArgOperand(2));
    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                       Args);

    for (Instruction *I : llvm::reverse(ToBeDeleted))
      I->eraseFromParent();
  };
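  // For illustration, HostPostOutlineCB rewrites the stale call left behind
  // by the outliner, e.g. (sketch)
  //   call void @outlined(ptr %gid, ptr %tid, ptr %data)
  // into a fork of the teams region:
  //   call void @__kmpc_fork_teams(ptr @ident, i32 1, ptr @outlined,
  //                                ptr %data)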

  if (!Config.isTargetDevice())
    OI.PostOutlineCB = HostPostOutlineCB;

  addOutlineInfo(std::move(OI));

  Builder.SetInsertPoint(ExitBB, ExitBB->begin());

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
                                  InsertPointTy OuterAllocaIP,
                                  BodyGenCallbackTy BodyGenCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();

  if (OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB =
        splitBB(Builder, /*CreateBranch=*/true, "distribute.entry");
    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
  }
  BasicBlock *ExitBB =
      splitBB(Builder, /*CreateBranch=*/true, "distribute.exit");
  BasicBlock *BodyBB =
      splitBB(Builder, /*CreateBranch=*/true, "distribute.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "distribute.alloca");

  // Generate the body of the distribute region.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
    return Err;

  OutlineInfo OI;
  OI.OuterAllocaBB = OuterAllocaIP.getBlock();
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;

  addOutlineInfo(std::move(OI));
  Builder.SetInsertPoint(ExitBB, ExitBB->begin());

  return Builder.saveIP();
}

GlobalVariable *
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                       std::string VarName) {
  llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
      llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
                           Names.size()),
      Names);
  auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
      M, MapNamesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
      MapNamesArrayInit, VarName);
  return MapNamesArrayGlobal;
}

// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                            \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                            \
  VarName##PtrTy = PointerType::getUnqual(Ctx);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                 \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);           \
  VarName##Ptr = PointerType::getUnqual(Ctx);
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                     \
  T = StructType::getTypeByName(Ctx, StructName);                             \
  if (!T)                                                                     \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);           \
  VarName = T;                                                                \
  VarName##Ptr = PointerType::getUnqual(Ctx);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
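
// As an illustrative example (not normative; see OMPKinds.def for the
// authoritative definitions), an OMP_STRUCT_TYPE entry for ident_t expands
// to roughly:
//   T = StructType::getTypeByName(Ctx, "struct.ident_t");
//   if (!T)
//     T = StructType::create(Ctx, {/*fields*/}, "struct.ident_t", false);
//   Ident = T;
//   IdentPtr = PointerType::getUnqual(Ctx);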

void OpenMPIRBuilder::OutlineInfo::collectBlocks(
    SmallPtrSetImpl<BasicBlock *> &BlockSet,
    SmallVectorImpl<BasicBlock *> &BlockVector) {
  SmallVector<BasicBlock *, 32> Worklist;
  BlockSet.insert(EntryBB);
  BlockSet.insert(ExitBB);

  Worklist.push_back(EntryBB);
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    BlockVector.push_back(BB);
    for (BasicBlock *SuccBB : successors(BB))
      if (BlockSet.insert(SuccBB).second)
        Worklist.push_back(SuccBB);
  }
}

void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                         uint64_t Size, int32_t Flags,
                                         GlobalValue::LinkageTypes,
                                         StringRef Name) {
  if (!Config.isGPU()) {
    llvm::offloading::emitOffloadingEntry(
        M, object::OffloadKind::OFK_OpenMP, ID,
        Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
    return;
  }
  // TODO: Add support for global variables on the device after declare target
  // support.
  Function *Fn = dyn_cast<Function>(Addr);
  if (!Fn)
    return;

  // Add a function attribute for the kernel.
  Fn->addFnAttr("kernel");
  if (T.isAMDGCN())
    Fn->addFnAttr("uniform-work-group-size", "true");
  Fn->addFnAttr(Attribute::MustProgress);
}

// We only generate metadata for functions that contain target regions.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        //   identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FileID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };
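  // An example of the resulting metadata, with illustrative values:
  //   !omp_offload.info = !{!0}
  //   !0 = !{i32 0, i32 42, i32 711, !"parent_fn", i32 12, i32 0, i32 0}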

  OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);

  // Create a function that emits metadata for each device global variable
  // entry.
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar
              &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        TargetRegionEntryInfo VarInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(&E, VarInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };
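  // Illustrative example of a resulting global-variable entry:
  //   !1 = !{i32 1, !"mangled_var_name", i32 0, i32 1}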

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      DeviceGlobalVarMetadataEmitter);

  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(CE->getID(), CE->getAddress(),
                         /*Size=*/0, CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declare target link address is set.");
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
          continue;

      // Indirect globals need to use a special name that doesn't match the
      // name of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
        createOffloadEntry(CE->getAddress(), CE->getAddress(),
                           CE->getVarSize(), Flags, CE->getLinkage(),
                           CE->getVarName());
      else
        createOffloadEntry(CE->getAddress(), CE->getAddress(),
                           CE->getVarSize(), Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }

  // Emit requires directive globals to a special entry so the runtime can
  // register them when the device image is loaded.
  // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
  // entries should be redesigned to better suit this use-case.
  if (Config.hasRequiresFlags() && !Config.isTargetDevice())
    offloading::emitOffloadingEntry(
        M, object::OffloadKind::OFK_OpenMP,
        Constant::getNullValue(PointerType::getUnqual(M.getContext())),
        /*Name=*/".requires", /*Size=*/0,
        OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
        Config.getRequiresFlags());
}

void TargetRegionEntryInfo::getTargetRegionEntryFnName(
    SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
    unsigned FileID, unsigned Line, unsigned Count) {
  raw_svector_ostream OS(Name);
  OS << KernelNamePrefix << llvm::format("%x", DeviceID)
     << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
  if (Count)
    OS << "_" << Count;
}
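
// For example (illustrative values, assuming the usual "__omp_offloading_"
// KernelNamePrefix on the host): DeviceID 0x10303, FileID 0xb17ea5, parent
// function 'foo' and line 7 produce
//   __omp_offloading_10303_b17ea5_foo_l7
// and a non-zero Count of 2 would append "_2".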

void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
    SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
  unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
  TargetRegionEntryInfo::getTargetRegionEntryFnName(
      Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
      EntryInfo.Line, NewCount);
}

TargetRegionEntryInfo
OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
                                          StringRef ParentName) {
  sys::fs::UniqueID ID(0xdeadf17e, 0);
  auto FileIDInfo = CallBack();
  uint64_t FileID = 0;
  std::error_code EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID);
  // If the inode ID could not be determined, create a hash value of the
  // current file name and use that as an ID.
  if (EC)
    FileID = hash_value(std::get<0>(FileIDInfo));
  else
    FileID = ID.getFile();

  return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
                               std::get<1>(FileIDInfo));
}

unsigned OpenMPIRBuilder::getFlagMemberOffset() {
  unsigned Offset = 0;
  for (uint64_t Remain =
           static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
               omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
       !(Remain & 1); Remain = Remain >> 1)
    Offset++;
  return Offset;
}
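
// Sketch: with OMP_MAP_MEMBER_OF currently defined as the 16-bit field in
// the topmost bits of the 64-bit flag word (mask 0xFFFF000000000000), the
// loop above counts the trailing zero bits of the mask and returns 48.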

omp::OpenMPOffloadMappingFlags
OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
  // Shift the value left by getFlagMemberOffset() bits.
  return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
                                                     << getFlagMemberOffset());
}

void OpenMPIRBuilder::setCorrectMemberOfFlag(
    omp::OpenMPOffloadMappingFlags &Flags,
    omp::OpenMPOffloadMappingFlags MemberOfFlag) {
  // If the entry is PTR_AND_OBJ but has not been marked with the special
  // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
  // marked as MEMBER_OF.
  if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
      static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
          omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
    return;

  // Reset the placeholder value to prepare the flag for the assignment of the
  // proper MEMBER_OF value.
  Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
  Flags |= MemberOfFlag;
}
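
// Illustrative example: given Flags = PTR_AND_OBJ | MEMBER_OF(placeholder)
// and MemberOfFlag = getMemberOfFlag(0), i.e. a MEMBER_OF field value of 1,
// the placeholder bits are cleared and the result is PTR_AND_OBJ |
// MEMBER_OF(1): "member of the first component of the parent entry".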

Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
  // TODO: convert this to utilise the IRBuilder Config rather than
  // a passed down argument.
  if (OpenMPSIMD)
    return nullptr;

  if (CaptureClause ==
          OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
      ((CaptureClause ==
            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
        CaptureClause ==
            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
       Config.hasRequiresUnifiedSharedMemory())) {
    SmallString<64> PtrName;
    {
      raw_svector_ostream OS(PtrName);
      OS << MangledName;
      if (!IsExternallyVisible)
        OS << format("_%x", EntryInfo.FileID);
      OS << "_decl_tgt_ref_ptr";
    }
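    // For example (illustrative): an internal variable 'var' in a file with
    // id 0xdead yields "var_dead_decl_tgt_ref_ptr"; an externally visible
    // one yields just "var_decl_tgt_ref_ptr".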

    Value *Ptr = M.getNamedValue(PtrName);

    if (!Ptr) {
      GlobalValue *GlobalValue = M.getNamedValue(MangledName);
      Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);

      auto *GV = cast<GlobalVariable>(Ptr);
      GV->setLinkage(GlobalValue::WeakAnyLinkage);

      if (!Config.isTargetDevice()) {
        if (GlobalInitializer)
          GV->setInitializer(GlobalInitializer());
        else
          GV->setInitializer(GlobalValue);
      }

      registerTargetGlobalVariable(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
    }

    return cast<Constant>(Ptr);
  }

  return nullptr;
}

void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage,
    Type *LlvmPtrTy, Constant *Addr) {
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    GlobalValue *LlvmVal = M.getNamedValue(VarName);

    if (!IsDeclaration)
      VarSize = divideCeil(
          M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName({VarName, "ref"});

      if (!M.getNamedValue(RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Addr->getType(), RefName);
        auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(GvAddrRef);
      }
    }
  } else {
    if (CaptureClause ==
        OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    if (Config.isTargetDevice()) {
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}

/// Loads all the offload entries information from the host IR
/// metadata.
void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
  // If we are in target mode, load the metadata from the host IR. This code
  // has to match the metadata creation in
  // createOffloadEntriesAndInfoMetadata().

  NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
  if (!MD)
    return;

  for (MDNode *MN : MD->operands()) {
    auto &&GetMDInt = [MN](unsigned Idx) {
      auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
      return cast<ConstantInt>(V->getValue())->getZExtValue();
    };

    auto &&GetMDString = [MN](unsigned Idx) {
      auto *V = cast<MDString>(MN->getOperand(Idx));
      return V->getString();
    };

    switch (GetMDInt(0)) {
    default:
      llvm_unreachable("Unexpected metadata!");
      break;
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoTargetRegion: {
      TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
                                      /*DeviceID=*/GetMDInt(1),
                                      /*FileID=*/GetMDInt(2),
                                      /*Line=*/GetMDInt(4),
                                      /*Count=*/GetMDInt(5));
      OffloadInfoManager.initializeTargetRegionEntryInfo(
          EntryInfo, /*Order=*/GetMDInt(6));
      break;
    }
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoDeviceGlobalVar:
      OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
          /*MangledName=*/GetMDString(1),
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              /*Flags=*/GetMDInt(2)),
          /*Order=*/GetMDInt(3));
      break;
    }
  }
}

void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
  if (HostFilePath.empty())
    return;

  auto Buf = MemoryBuffer::getFile(HostFilePath);
  if (std::error_code Err = Buf.getError()) {
    report_fatal_error(("error opening host file from host file path inside "
                        "of OpenMPIRBuilder: " +
                        Err.message())
                           .c_str());
  }

  LLVMContext Ctx;
  auto M = expectedToErrorOrAndEmitErrors(
      Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
  if (std::error_code Err = M.getError()) {
    report_fatal_error(
        ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
            .c_str());
  }

  loadOffloadInfoMetadata(*M.get());
}

//===----------------------------------------------------------------------===//
// OffloadEntriesInfoManager
//===----------------------------------------------------------------------===//

bool OffloadEntriesInfoManager::empty() const {
  return OffloadEntriesTargetRegion.empty() &&
         OffloadEntriesDeviceGlobalVar.empty();
}

unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
    const TargetRegionEntryInfo &EntryInfo) const {
  auto It = OffloadEntriesTargetRegionCount.find(
      getTargetRegionEntryCountKey(EntryInfo));
  if (It == OffloadEntriesTargetRegionCount.end())
    return 0;
  return It->second;
}

void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
    const TargetRegionEntryInfo &EntryInfo) {
  OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
      EntryInfo.Count + 1;
}

/// Initialize target region entry.
void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
    const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
  OffloadEntriesTargetRegion[EntryInfo] =
      OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
                                   OMPTargetRegionEntryTargetRegion);
  ++OffloadingEntriesNum;
}

void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
    OMPTargetRegionEntryKind Flags) {
  assert(EntryInfo.Count == 0 && "expected default EntryInfo");

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  // If we are emitting code for a target, the entry is already initialized
  // and only has to be registered.
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasTargetRegionEntryInfo(EntryInfo)) {
      return;
    }
    auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
    Entry.setAddress(Addr);
    Entry.setID(ID);
    Entry.setFlags(Flags);
  } else {
    if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
        hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId=*/true))
      return;
    assert(!hasTargetRegionEntryInfo(EntryInfo) &&
           "Target region entry already registered!");
    OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
    OffloadEntriesTargetRegion[EntryInfo] = Entry;
    ++OffloadingEntriesNum;
  }
  incrementTargetRegionEntryInfoCount(EntryInfo);
}

bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  auto It = OffloadEntriesTargetRegion.find(EntryInfo);
  if (It == OffloadEntriesTargetRegion.end()) {
    return false;
  }
  // Fail if this entry is already registered.
  if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
    return false;
  return true;
}

void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
    const OffloadTargetRegionEntryInfoActTy &Action) {
  // Scan all target region entries and perform the provided action.
  for (const auto &It : OffloadEntriesTargetRegion) {
    Action(It.first, It.second);
  }
}

void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
    StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
  OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
  ++OffloadingEntriesNum;
}

void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
    StringRef VarName, Constant *Addr, int64_t VarSize,
    OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasDeviceGlobalVarEntryInfo(VarName))
      return;
    auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
    if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
      if (Entry.getVarSize() == 0) {
        Entry.setVarSize(VarSize);
        Entry.setLinkage(Linkage);
      }
      return;
    }
    Entry.setVarSize(VarSize);
    Entry.setLinkage(Linkage);
    Entry.setAddress(Addr);
  } else {
    if (hasDeviceGlobalVarEntryInfo(VarName)) {
      auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
      assert(Entry.isValid() && Entry.getFlags() == Flags &&
             "Entry not initialized!");
      if (Entry.getVarSize() == 0) {
        Entry.setVarSize(VarSize);
        Entry.setLinkage(Linkage);
      }
      return;
    }
    if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
      OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
                                                Addr, VarSize, Flags, Linkage,
                                                VarName.str());
    else
      OffloadEntriesDeviceGlobalVar.try_emplace(
          VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
    ++OffloadingEntriesNum;
  }
}

void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
    const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
  // Scan all device global variable entries and perform the provided action.
  for (const auto &E : OffloadEntriesDeviceGlobalVar)
    Action(E.getKey(), E.getValue());
}

//===----------------------------------------------------------------------===//
// CanonicalLoopInfo
//===----------------------------------------------------------------------===//

void CanonicalLoopInfo::collectControlBlocks(
    SmallVectorImpl<BasicBlock *> &BBs) {
  // We only count those BBs as control blocks that can be enumerated without
  // traversing the CFG, i.e. not the loop body which can contain arbitrary
  // control flow. For consistency, this also means we do not add the Body
  // block, which is just the entry to the body code.
  BBs.reserve(BBs.size() + 6);
  BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
}

BasicBlock *CanonicalLoopInfo::getPreheader() const {
  assert(isValid() && "Requires a valid canonical loop");
  for (BasicBlock *Pred : predecessors(Header)) {
    if (Pred != Latch)
      return Pred;
  }
  llvm_unreachable("Missing preheader");
}

void CanonicalLoopInfo::setTripCount(Value *TripCount) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *CmpI = &getCond()->front();
  assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
  CmpI->setOperand(1, TripCount);

#ifndef NDEBUG
  assertOK();
#endif
}

void CanonicalLoopInfo::mapIndVar(
    llvm::function_ref<Value *(Instruction *)> Updater) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *OldIV = getIndVar();

  // Record all uses excluding those introduced by the updater. Uses by the
  // CanonicalLoopInfo itself to keep track of the number of iterations are
  // excluded.
  SmallVector<Use *> ReplacableUses;
  for (Use &U : OldIV->uses()) {
    auto *User = dyn_cast<Instruction>(U.getUser());
    if (!User)
      continue;
    if (User->getParent() == getCond())
      continue;
    if (User->getParent() == getLatch())
      continue;
    ReplacableUses.push_back(&U);
  }

  // Run the updater that may introduce new uses.
  Value *NewIV = Updater(OldIV);

  // Replace the old uses with the value returned by the updater.
  for (Use *U : ReplacableUses)
    U->set(NewIV);

#ifndef NDEBUG
  assertOK();
#endif
}

void CanonicalLoopInfo::assertOK() const {
#ifndef NDEBUG
  // No constraints if this object currently does not describe a loop.
  if (!isValid())
    return;

  BasicBlock *Preheader = getPreheader();
  BasicBlock *Body = getBody();
  BasicBlock *After = getAfter();

  // Verify the standard control flow we use for OpenMP loops.
  assert(Preheader);
  assert(isa<BranchInst>(Preheader->getTerminator()) &&
         "Preheader must terminate with unconditional branch");
  assert(Preheader->getSingleSuccessor() == Header &&
         "Preheader must jump to header");

  assert(Header);
  assert(isa<BranchInst>(Header->getTerminator()) &&
         "Header must terminate with unconditional branch");
  assert(Header->getSingleSuccessor() == Cond &&
         "Header must jump to exiting block");

  assert(Cond);
  assert(Cond->getSinglePredecessor() == Header &&
         "Exiting block only reachable from header");

  assert(isa<BranchInst>(Cond->getTerminator()) &&
         "Exiting block must terminate with conditional branch");
  assert(size(successors(Cond)) == 2 &&
         "Exiting block must have two successors");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
         "Exiting block's second successor must exit the loop");

  assert(Body);
  assert(Body->getSinglePredecessor() == Cond &&
         "Body only reachable from exiting block");
  assert(!isa<PHINode>(Body->front()));

  assert(Latch);
  assert(isa<BranchInst>(Latch->getTerminator()) &&
         "Latch must terminate with unconditional branch");
  assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of the body code that has
  // multiple exits; introduce another auxiliary basic block like preheader
  // and after.
  assert(Latch->getSinglePredecessor() != nullptr);
  assert(!isa<PHINode>(Latch->front()));

  assert(Exit);
  assert(isa<BranchInst>(Exit->getTerminator()) &&
         "Exit block must terminate with unconditional branch");
  assert(Exit->getSingleSuccessor() == After &&
         "Exit block must jump to after block");

  assert(After);
  assert(After->getSinglePredecessor() == Exit &&
         "After block only reachable from exit block");
  assert(After->empty() || !isa<PHINode>(After->front()));

  Instruction *IndVar = getIndVar();
  assert(IndVar && "Canonical induction variable not found?");
  assert(isa<IntegerType>(IndVar->getType()) &&
         "Induction variable must be an integer");
  assert(cast<PHINode>(IndVar)->getParent() == Header &&
         "Induction variable must be a PHI in the loop header");
  assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
  assert(
      cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
  assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);

  auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
  assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
  assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
  assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
  assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
             ->isOne());

  Value *TripCount = getTripCount();
  assert(TripCount && "Loop trip count not found?");
  assert(IndVar->getType() == TripCount->getType() &&
         "Trip count and induction variable must have the same type");

  auto *CmpI = cast<CmpInst>(&Cond->front());
  assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
  assert(CmpI->getOperand(0) == IndVar &&
         "Exit condition must compare the induction variable");
  assert(CmpI->getOperand(1) == TripCount &&
         "Exit condition must compare with the trip count");
#endif
}

void CanonicalLoopInfo::invalidate() {
  Header = nullptr;
  Cond = nullptr;
  Latch = nullptr;
  Exit = nullptr;
}