//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>
#include <stack>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));

#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. inserting instructions at
/// position IP1 may change the meaning of IP2 or vice versa. This is because
/// an InsertPoint stores the instruction before which new instructions are
/// inserted. For instance, if both point to the same instruction, two
/// IRBuilders alternately creating instructions will cause them to be
/// interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

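/// Returns true if \p SchedType is a valid combination of base scheduling
/// algorithm, ordering, and monotonicity flags.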
static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

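/// Returns the target-specific grid value constants for \p Kernel, selecting
/// the AMDGPU wavefront size from the kernel's target features.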
static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}

/// Determine which scheduling algorithm to use based on schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering nor monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // Monotonic is the default in the OpenMP runtime library, so there is
      // no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///   the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

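/// Move the instructions following the insert point \p IP to the beginning of
/// \p New, optionally terminating the old block with a branch to \p New.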
void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch)
    BranchInst::Create(New, Old);
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

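/// Split the basic block at the insert point \p IP, moving the remaining
/// instructions into a newly created block that is returned.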
BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true) {
  Builder.restoreIP(OuterAllocaIP);
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
  ToBeDeleted.push_back(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
    ToBeDeleted.push_back(FakeVal);
  }

  // Generate a fake use of this value
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
  } else {
    UseFakeVal =
        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
  }
  ToBeDeleted.push_back(UseFakeVal);
  return FakeVal;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

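/// Populate \p ArgsVector with the argument list expected by the
/// __tgt_target_kernel runtime call, built from \p KernelArgs.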
void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);

  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGGroupMem};
}

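/// Attach the attributes listed in OMPKinds.def for the runtime function
/// \p FnID to the declaration \p Fn.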
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

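/// Hoist allocas with a constant (ConstantData) array size out of non-entry
/// blocks of \p Function and into its entry block.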
static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  Instruction *MoveLocInst = EntryBlock.getFirstNonPHI();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
       Block++) {
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
        Inst++;
        if (!isa<ConstantData>(AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}

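/// Extract all recorded outline regions (e.g., parallel regions) into their
/// own functions and emit offload entries and metadata if required.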
void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate the
    // struct for aggregate params in the device default alloca address space.
    // The OpenMP runtime requires that the params of the extracted functions
    // are passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by the OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with clang's CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embed user-written code into the target
  // region, which may inject allocas that need to be moved to the entry
  // block of our target, or we risk malformed optimisations by later passes.
  // This is only relevant for the device pass, which appears to be a little
  // more delicate when it comes to optimisations (however, we do not block
  // on that here; it's up to the inserter to the list to do so). This
  // notably has to occur after the OutlinedInfo candidates have been
  // extracted so we have an end product that will not be implicitly
  // adversely affected by any raises unless intentionally appended to the
  // list.
  // NOTE: This only does so for ConstantData; it could be extended to
  // ConstantExprs with further effort, however, they should largely be
  // folded when they get here. Extending it to runtime defined/read+writeable
  // allocation sizes would be non-trivial (we would need to factor in
  // movement of any stores to variables the allocation size depends on, as
  // well as the usual loads, otherwise it'll yield the wrong result after
  // movement) and would likely be more suitable as an LLVM optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization\n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed("llvm.compiler.used", LLVMCompilerUsed);
  }
}

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for an existing encoding of the location + flags; not needed but
    // it minimizes the difference to the existing solution while we
    // transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for an existing encoding of the location; not needed but it
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

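/// Emit a __kmpc_barrier call (or __kmpc_cancel_barrier inside a cancellable
/// region) for the construct \p Kind at the given source location.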
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

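/// Allocate and fill the kernel argument structure, then emit the
/// __tgt_target_kernel runtime call whose result is returned in \p Return.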
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  Builder.restoreIP(Loc.IP);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

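/// Emit the control flow that launches a target kernel and falls back to the
/// host version of the target region if the offloading call fails.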
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
    EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(Loc.IP);
  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
                                     Args.NumTeams, Args.NumThreads,
                                     OutlinedFnID, ArgsVector));

  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Return);
  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(OffloadFailedBlock, CurFn);
  Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
  emitBranch(OffloadContBlock);
  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}

void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}

// Callback used to create OpenMP runtime calls to support
// the omp parallel clause for the device.
// We need to use this callback to replace the call to the OutlinedFn in
// OuterFn by the call to the OpenMP DeviceRTL runtime function
// (kmpc_parallel_51).
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addParamAttr(0, Attribute::NoUndef);
  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(PtrTy);

  // Add alloca for kernel args
  OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add an address space cast if the array for storing arguments is not
  // allocated in address space 0.
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
  Builder.restoreIP(CurrentIP);

  // Store captured vars which are used by kmpc_parallel_51
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
    Builder.CreateStore(V, StoreAddress);
  }

  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                  : Builder.getInt32(1);

  // Build kmpc_parallel_51 call
  Value *Parallel51CallArgs[] = {
      /* identifier*/ Ident,
      /* global thread num*/ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
      /* Proc bind */ Builder.getInt32(-1),
      /* outlined function */
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function */ Args,
      /* number of arguments */ Builder.getInt64(NumCapturedVars)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);

  Builder.CreateCall(RTLFn, Parallel51CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove the redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}

// Callback used to create OpenMP runtime calls to support
// the omp parallel clause for the host.
// We need to use this callback to replace the call to the OutlinedFn in
// OuterFn by the call to the OpenMP host runtime function
// (__kmpc_fork_call[_if]).
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
  }
  if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      // - The callback callee is argument number 2 (microtask).
      // - The first two arguments of the callback callee are unknown (-1).
      // - All variadic arguments to the __kmpc_fork_call are passed to the
      //   callback callee.
      F->addMetadata(LLVMContext::MD_callback,
                     *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                           2, {-1, -1},
                                           /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
  Value *ForkCallArgs[] = {
      Ident, Builder.getInt32(NumCapturedVars),
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
  if (IfCondition) {
    Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
    RealArgs.push_back(Cond);
  }
  RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument.
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(PtrTy);
    RealArgs.push_back(NullPtrValue);
  }
  if (IfCondition && RealArgs.back()->getType() != PtrTy)
    RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);

  Builder.CreateCall(RTLFn, RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove the redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}

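/// Create a parallel region: emits the runtime setup calls (num_threads,
/// proc_bind), materializes the region blocks, and records an outlining
/// request that finalize() later extracts into the outlined function.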
1345IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1346 const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1347 BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1348 FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1349 omp::ProcBindKind ProcBind, bool IsCancellable) {
1350 assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1351
1352 if (!updateToLocation(Loc))
1353 return Loc.IP;
1354
1355 uint32_t SrcLocStrSize;
1356 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1357 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1358 Value *ThreadID = getOrCreateThreadID(Ident);
1359 // If we generate code for the target device, we need to allocate
1360 // struct for aggregate params in the device default alloca address space.
1361 // OpenMP runtime requires that the params of the extracted functions are
1362 // passed as zero address space pointers. This flag ensures that extracted
1363 // function arguments are declared in zero address space
1364 bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1365
1366 // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1367 // only if we compile for host side.
1368 if (NumThreads && !Config.isTargetDevice()) {
1369 Value *Args[] = {
1370 Ident, ThreadID,
1371 Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
1372 Builder.CreateCall(
1373 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
1374 }
1375
1376 if (ProcBind != OMP_PROC_BIND_default) {
1377 // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1378 Value *Args[] = {
1379 Ident, ThreadID,
1380 ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
1381 Builder.CreateCall(
1382 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
1383 }
1384
1385 BasicBlock *InsertBB = Builder.GetInsertBlock();
1386 Function *OuterFn = InsertBB->getParent();
1387
1388 // Save the outer alloca block because the insertion iterator may get
1389 // invalidated and we still need this later.
1390 BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1391
1392 // Vector to remember instructions we used only during the modeling but which
1393 // we want to delete at the end.
1394 SmallVector<Instruction *, 4> ToBeDeleted;
1395
1396 // Change the location to the outer alloca insertion point to create and
1397 // initialize the allocas we pass into the parallel region.
1398 InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1399 Builder.restoreIP(IP: NewOuter);
1400 AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
1401 AllocaInst *ZeroAddrAlloca =
1402 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
1403 Instruction *TIDAddr = TIDAddrAlloca;
1404 Instruction *ZeroAddr = ZeroAddrAlloca;
1405 if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1406 // Add additional casts to enforce pointers in zero address space
    TIDAddr = new AddrSpaceCastInst(
        TIDAddrAlloca, PointerType::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
1409 TIDAddr->insertAfter(InsertPos: TIDAddrAlloca);
1410 ToBeDeleted.push_back(Elt: TIDAddr);
    ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
                                     PointerType::get(C&: M.getContext(), AddressSpace: 0),
                                     "zero.addr.ascast");
1414 ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca);
1415 ToBeDeleted.push_back(Elt: ZeroAddr);
1416 }
1417
1418 // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1419 // associated arguments in the outlined function, so we delete them later.
1420 ToBeDeleted.push_back(Elt: TIDAddrAlloca);
1421 ToBeDeleted.push_back(Elt: ZeroAddrAlloca);
1422
1423 // Create an artificial insertion point that will also ensure the blocks we
1424 // are about to split are not degenerated.
1425 auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1426
1427 BasicBlock *EntryBB = UI->getParent();
1428 BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
1429 BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
1430 BasicBlock *PRegPreFiniBB =
1431 PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
1432 BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");
1433
1434 auto FiniCBWrapper = [&](InsertPointTy IP) {
1435 // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1436 // target to the region exit block.
1437 if (IP.getBlock()->end() == IP.getPoint()) {
1438 IRBuilder<>::InsertPointGuard IPG(Builder);
1439 Builder.restoreIP(IP);
1440 Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
1441 IP = InsertPointTy(I->getParent(), I->getIterator());
1442 }
1443 assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1444 IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1445 "Unexpected insertion point for finalization call!");
1446 return FiniCB(IP);
1447 };
1448
1449 FinalizationStack.push_back(Elt: {.FiniCB: FiniCBWrapper, .DK: OMPD_parallel, .IsCancellable: IsCancellable});
1450
1451 // Generate the privatization allocas in the block that will become the entry
1452 // of the outlined function.
1453 Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1454 InsertPointTy InnerAllocaIP = Builder.saveIP();
1455
1456 AllocaInst *PrivTIDAddr =
1457 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
1458 Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");
1459
1460 // Add some fake uses for OpenMP provided arguments.
1461 ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
1462 Instruction *ZeroAddrUse =
1463 Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
1464 ToBeDeleted.push_back(Elt: ZeroAddrUse);
1465
1466 // EntryBB
1467 // |
1468 // V
  // PRegEntryBB   <- Privatization allocas are placed here.
  //       |
  //       V
  // PRegBodyBB    <- BodyGen is invoked here.
  //       |
  //       V
  // PRegPreFiniBB <- The block we will start finalization from.
  //       |
  //       V
  // PRegExitBB    <- A common exit to simplify block collection.
  //
1480
1481 LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1482
1483 // Let the caller create the body.
1484 assert(BodyGenCB && "Expected body generation callback!");
1485 InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1486 BodyGenCB(InnerAllocaIP, CodeGenIP);
1487
1488 LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
1489
1490 OutlineInfo OI;
1491 if (Config.isTargetDevice()) {
1492 // Generate OpenMP target specific runtime call
1493 OI.PostOutlineCB = [=, ToBeDeletedVec =
1494 std::move(ToBeDeleted)](Function &OutlinedFn) {
1495 targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
1496 IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1497 ThreadID, ToBeDeleted: ToBeDeletedVec);
1498 };
1499 } else {
1500 // Generate OpenMP host runtime call
1501 OI.PostOutlineCB = [=, ToBeDeletedVec =
1502 std::move(ToBeDeleted)](Function &OutlinedFn) {
1503 hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
1504 PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
1505 };
1506 }
1507
1508 OI.OuterAllocaBB = OuterAllocaBlock;
1509 OI.EntryBB = PRegEntryBB;
1510 OI.ExitBB = PRegExitBB;
1511
1512 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1513 SmallVector<BasicBlock *, 32> Blocks;
1514 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
1515
1516 // Ensure a single exit node for the outlined region by creating one.
1517 // We might have multiple incoming edges to the exit now due to finalizations,
1518 // e.g., cancel calls that cause the control flow to leave the region.
1519 BasicBlock *PRegOutlinedExitBB = PRegExitBB;
1520 PRegExitBB = SplitBlock(Old: PRegExitBB, SplitPt: &*PRegExitBB->getFirstInsertionPt());
1521 PRegOutlinedExitBB->setName("omp.par.outlined.exit");
1522 Blocks.push_back(Elt: PRegOutlinedExitBB);
1523
1524 CodeExtractorAnalysisCache CEAC(*OuterFn);
1525 CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1526 /* AggregateArgs */ false,
1527 /* BlockFrequencyInfo */ nullptr,
1528 /* BranchProbabilityInfo */ nullptr,
1529 /* AssumptionCache */ nullptr,
1530 /* AllowVarArgs */ true,
1531 /* AllowAlloca */ true,
1532 /* AllocationBlock */ OuterAllocaBlock,
1533 /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1534
1535 // Find inputs to, outputs from the code region.
1536 BasicBlock *CommonExit = nullptr;
1537 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1538 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
1539 Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands);
1540
1541 LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1542
1543 FunctionCallee TIDRTLFn =
1544 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);
1545
1546 auto PrivHelper = [&](Value &V) {
1547 if (&V == TIDAddr || &V == ZeroAddr) {
1548 OI.ExcludeArgsFromAggregate.push_back(Elt: &V);
1549 return;
1550 }
1551
1552 SetVector<Use *> Uses;
1553 for (Use &U : V.uses())
1554 if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
1555 if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
1556 Uses.insert(X: &U);
1557
    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto the stack and load it back inside the to-be-outlined region.
    // This ensures that only the pointer is passed to the function.
1562 // FIXME: if there are more than 15 trailing arguments, they must be
1563 // additionally packed in a struct.
1564 Value *Inner = &V;
1565 if (!V.getType()->isPointerTy()) {
1566 IRBuilder<>::InsertPointGuard Guard(Builder);
1567 LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1568
1569 Builder.restoreIP(IP: OuterAllocaIP);
1570 Value *Ptr =
1571 Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr, Name: V.getName() + ".reloaded");
1572
1573 // Store to stack at end of the block that currently branches to the entry
1574 // block of the to-be-outlined region.
1575 Builder.SetInsertPoint(TheBB: InsertBB,
1576 IP: InsertBB->getTerminator()->getIterator());
1577 Builder.CreateStore(Val: &V, Ptr);
1578
1579 // Load back next to allocations in the to-be-outlined region.
1580 Builder.restoreIP(IP: InnerAllocaIP);
1581 Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
1582 }
1583
1584 Value *ReplacementValue = nullptr;
1585 CallInst *CI = dyn_cast<CallInst>(Val: &V);
1586 if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1587 ReplacementValue = PrivTID;
1588 } else {
1589 Builder.restoreIP(
1590 IP: PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
1591 InnerAllocaIP = {
1592 InnerAllocaIP.getBlock(),
1593 InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1594
1595 assert(ReplacementValue &&
1596 "Expected copy/create callback to set replacement value!");
1597 if (ReplacementValue == &V)
1598 return;
1599 }
1600
1601 for (Use *UPtr : Uses)
1602 UPtr->set(ReplacementValue);
1603 };
1604
  // Reset the inner alloca insertion point as it will be used for loading the
  // values wrapped into pointers before passing them into the to-be-outlined
  // region. Configure it to insert immediately after the fake use of the zero
  // address so that the reloaded values are available in the generated body
  // and the OpenMP-related values (thread ID and zero address pointers)
  // remain leading in the argument list.
1611 InnerAllocaIP = IRBuilder<>::InsertPoint(
1612 ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1613
1614 // Reset the outer alloca insertion point to the entry of the relevant block
1615 // in case it was invalidated.
1616 OuterAllocaIP = IRBuilder<>::InsertPoint(
1617 OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1618
1619 for (Value *Input : Inputs) {
1620 LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1621 PrivHelper(*Input);
1622 }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      dbgs() << "Captured output: " << *Output << "\n";
  });
1627 assert(Outputs.empty() &&
1628 "OpenMP outlining should not produce live-out values!");
1629
1630 LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
1631 LLVM_DEBUG({
1632 for (auto *BB : Blocks)
1633 dbgs() << " PBR: " << BB->getName() << "\n";
1634 });
1635
  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function one last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel region "the normal way".
1639 auto FiniInfo = FinalizationStack.pop_back_val();
1640 (void)FiniInfo;
1641 assert(FiniInfo.DK == OMPD_parallel &&
1642 "Unexpected finalization stack state!");
1643
1644 Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1645
1646 InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1647 FiniCB(PreFiniIP);
1648
1649 // Register the outlined info.
1650 addOutlineInfo(OI: std::move(OI));
1651
1652 InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1653 UI->eraseFromParent();
1654
1655 return AfterIP;
1656}
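
// A minimal usage sketch of createParallel from a hypothetical frontend,
// assuming a configured OpenMPIRBuilder `OMPB` with its IRBuilder `Builder`,
// a location `Loc`, and an alloca insertion point `AllocaIP`:
//
// \code{cpp}
//   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
//     // Emit the parallel region body at CodeGenIP.
//   };
//   auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
//                     Value &Orig, Value &Inner, Value *&ReplVal) {
//     ReplVal = &Inner; // Treat all captured values as shared.
//     return CodeGenIP;
//   };
//   auto FiniCB = [&](InsertPointTy CodeGenIP) { /* nothing to finalize */ };
//   Builder.restoreIP(OMPB.createParallel(
//       Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, /*IfCondition=*/nullptr,
//       /*NumThreads=*/nullptr, OMP_PROC_BIND_default,
//       /*IsCancellable=*/false));
// \endcode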
1657
1658void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1659 // Build call void __kmpc_flush(ident_t *loc)
1660 uint32_t SrcLocStrSize;
1661 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1662 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1663
1664 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush), Args);
1665}
1666
1667void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1668 if (!updateToLocation(Loc))
1669 return;
1670 emitFlush(Loc);
1671}
1672
1673void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1674 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1675 // global_tid);
1676 uint32_t SrcLocStrSize;
1677 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1678 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1679 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1680
1681 // Ignore return result until untied tasks are supported.
1682 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait),
1683 Args);
1684}
1685
1686void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1687 if (!updateToLocation(Loc))
1688 return;
1689 emitTaskwaitImpl(Loc);
1690}
1691
1692void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1693 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1694 uint32_t SrcLocStrSize;
1695 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1696 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1697 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1698 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1699
1700 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield),
1701 Args);
1702}
1703
1704void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1705 if (!updateToLocation(Loc))
1706 return;
1707 emitTaskyieldImpl(Loc);
1708}
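
// Each of the three helpers above lowers to a single runtime call; roughly:
//
// \code{c}
//   __kmpc_flush(&loc);                  // createFlush
//   __kmpc_omp_taskwait(&loc, gtid);     // createTaskwait (result ignored)
//   __kmpc_omp_taskyield(&loc, gtid, 0); // createTaskyield
// \endcode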
1709
// Processes the dependencies in Dependencies and does the following:
// - Allocates space on the stack for an array of DependInfo objects.
// - Populates each DependInfo object with the relevant information of
//   the corresponding dependence.
// - All code is inserted in the entry block of the current function.
1715static Value *emitTaskDependencies(
1716 OpenMPIRBuilder &OMPBuilder,
1717 SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1718 // Early return if we have no dependencies to process
1719 if (Dependencies.empty())
1720 return nullptr;
1721
  // Given a vector of DependData objects, in this function we create an
  // array on the stack that holds kmp_depend_info objects corresponding
  // to each dependency. This is then passed to the OpenMP runtime.
  // For example, if there are 'n' dependencies then the following pseudo
  // code is generated. Assume the first dependence is on a variable 'a'.
  //
  // \code{c}
  // DepArray = alloca(n * sizeof(kmp_depend_info));
  // idx = 0;
  // DepArray[idx].base_addr = ptrtoint(&a);
  // DepArray[idx].len = 8;
  // DepArray[idx].flags = Dep.DepKind; /*(See OMPConstants.h for DepKind)*/
  // ++idx;
  // DepArray[idx].base_addr = ...;
  // \endcode
1737
1738 IRBuilderBase &Builder = OMPBuilder.Builder;
1739 Type *DependInfo = OMPBuilder.DependInfo;
1740 Module &M = OMPBuilder.M;
1741
1742 Value *DepArray = nullptr;
1743 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1744 Builder.SetInsertPoint(
1745 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1746
1747 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
1748 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
1749
1750 for (const auto &[DepIdx, Dep] : enumerate(First&: Dependencies)) {
1751 Value *Base =
1752 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
1753 // Store the pointer to the variable
1754 Value *Addr = Builder.CreateStructGEP(
1755 Ty: DependInfo, Ptr: Base,
1756 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1757 Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
1758 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
1759 // Store the size of the variable
1760 Value *Size = Builder.CreateStructGEP(
1761 Ty: DependInfo, Ptr: Base, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
1762 Builder.CreateStore(
1763 Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
1764 Ptr: Size);
1765 // Store the dependency kind
1766 Value *Flags = Builder.CreateStructGEP(
1767 Ty: DependInfo, Ptr: Base,
1768 Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
1769 Builder.CreateStore(
1770 Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
1771 V: static_cast<unsigned int>(Dep.DepKind)),
1772 Ptr: Flags);
1773 }
1774 Builder.restoreIP(IP: OldIP);
1775 return DepArray;
1776}
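
// For reference, each array element populated above mirrors the runtime's
// dependence descriptor; an illustrative C layout (field names follow
// RTLDependInfoFields, the actual definition lives in the OpenMP runtime):
//
// \code{c}
//   typedef struct kmp_depend_info {
//     intptr_t base_addr;  // RTLDependInfoFields::BaseAddr
//     size_t len;          // RTLDependInfoFields::Len
//     unsigned char flags; // RTLDependInfoFields::Flags (dependence kind)
//   } kmp_depend_info_t;
// \endcode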
1777
1778OpenMPIRBuilder::InsertPointTy
1779OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1780 InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1781 bool Tied, Value *Final, Value *IfCondition,
1782 SmallVector<DependData> Dependencies) {
1783
1784 if (!updateToLocation(Loc))
1785 return InsertPointTy();
1786
1787 uint32_t SrcLocStrSize;
1788 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1789 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1790 // The current basic block is split into four basic blocks. After outlining,
1791 // they will be mapped as follows:
1792 // ```
1793 // def current_fn() {
1794 // current_basic_block:
1795 // br label %task.exit
1796 // task.exit:
1797 // ; instructions after task
1798 // }
1799 // def outlined_fn() {
1800 // task.alloca:
1801 // br label %task.body
1802 // task.body:
1803 // ret void
1804 // }
1805 // ```
1806 BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
1807 BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
1808 BasicBlock *TaskAllocaBB =
1809 splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");
1810
1811 InsertPointTy TaskAllocaIP =
1812 InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1813 InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1814 BodyGenCB(TaskAllocaIP, TaskBodyIP);
1815
1816 OutlineInfo OI;
1817 OI.EntryBB = TaskAllocaBB;
1818 OI.OuterAllocaBB = AllocaIP.getBlock();
1819 OI.ExitBB = TaskExitBB;
1820
1821 // Add the thread ID argument.
1822 SmallVector<Instruction *, 4> ToBeDeleted;
1823 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
1824 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));
1825
1826 OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1827 TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1828 // Replace the Stale CI by appropriate RTL function call.
1829 assert(OutlinedFn.getNumUses() == 1 &&
1830 "there must be a single user for the outlined function");
1831 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
1832
1833 // HasShareds is true if any variables are captured in the outlined region,
1834 // false otherwise.
1835 bool HasShareds = StaleCI->arg_size() > 1;
1836 Builder.SetInsertPoint(StaleCI);
1837
1838 // Gather the arguments for emitting the runtime call for
1839 // @__kmpc_omp_task_alloc
1840 Function *TaskAllocFn =
1841 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
1842
    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
    // call.
1845 Value *ThreadID = getOrCreateThreadID(Ident);
1846
1847 // Argument - `flags`
1848 // Task is tied iff (Flags & 1) == 1.
1849 // Task is untied iff (Flags & 1) == 0.
1850 // Task is final iff (Flags & 2) == 2.
1851 // Task is not final iff (Flags & 2) == 0.
1852 // TODO: Handle the other flags.
1853 Value *Flags = Builder.getInt32(C: Tied);
1854 if (Final) {
1855 Value *FinalFlag =
1856 Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
1857 Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
1858 }
1859
    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // TaskSize refers to the size in bytes of the kmp_task_t data structure,
    // including private variables accessed in the task.
1863 // TODO: add kmp_task_t_with_privates (privates)
1864 Value *TaskSize = Builder.getInt64(
1865 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
1866
1867 // Argument - `sizeof_shareds` (SharedsSize)
1868 // SharedsSize refers to the shareds array size in the kmp_task_t data
1869 // structure.
1870 Value *SharedsSize = Builder.getInt64(C: 0);
1871 if (HasShareds) {
1872 AllocaInst *ArgStructAlloca =
1873 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
1874 assert(ArgStructAlloca &&
1875 "Unable to find the alloca instruction corresponding to arguments "
1876 "for extracted function");
1877 StructType *ArgStructType =
1878 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
1879 assert(ArgStructType && "Unable to find struct type corresponding to "
1880 "arguments for extracted function");
1881 SharedsSize =
1882 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
1883 }
    // Emit the @__kmpc_omp_task_alloc runtime call.
    // The runtime call returns a pointer to an area where the task's captured
    // variables must be copied before the task is run (TaskData).
1887 CallInst *TaskData = Builder.CreateCall(
1888 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1889 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1890 /*task_func=*/&OutlinedFn});
1891
1892 // Copy the arguments for outlined function
1893 if (HasShareds) {
1894 Value *Shareds = StaleCI->getArgOperand(i: 1);
1895 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
1896 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
1897 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
1898 Size: SharedsSize);
1899 }
1900
1901 Value *DepArray = nullptr;
1902 if (Dependencies.size()) {
1903 InsertPointTy OldIP = Builder.saveIP();
1904 Builder.SetInsertPoint(
1905 &OldIP.getBlock()->getParent()->getEntryBlock().back());
1906
1907 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
1908 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
1909
1910 unsigned P = 0;
1911 for (const DependData &Dep : Dependencies) {
1912 Value *Base =
1913 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: P);
1914 // Store the pointer to the variable
1915 Value *Addr = Builder.CreateStructGEP(
1916 Ty: DependInfo, Ptr: Base,
1917 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1918 Value *DepValPtr =
1919 Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
1920 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
1921 // Store the size of the variable
1922 Value *Size = Builder.CreateStructGEP(
1923 Ty: DependInfo, Ptr: Base,
1924 Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
1925 Builder.CreateStore(Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(
1926 Ty: Dep.DepValueType)),
1927 Ptr: Size);
1928 // Store the dependency kind
1929 Value *Flags = Builder.CreateStructGEP(
1930 Ty: DependInfo, Ptr: Base,
1931 Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
1932 Builder.CreateStore(
1933 Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
1934 V: static_cast<unsigned int>(Dep.DepKind)),
1935 Ptr: Flags);
1936 ++P;
1937 }
1938
1939 Builder.restoreIP(IP: OldIP);
1940 }
1941
1942 // In the presence of the `if` clause, the following IR is generated:
1943 // ...
1944 // %data = call @__kmpc_omp_task_alloc(...)
1945 // br i1 %if_condition, label %then, label %else
1946 // then:
1947 // call @__kmpc_omp_task(...)
1948 // br label %exit
1949 // else:
1950 // ;; Wait for resolution of dependencies, if any, before
1951 // ;; beginning the task
1952 // call @__kmpc_omp_wait_deps(...)
1953 // call @__kmpc_omp_task_begin_if0(...)
1954 // call @outlined_fn(...)
1955 // call @__kmpc_omp_task_complete_if0(...)
1956 // br label %exit
1957 // exit:
1958 // ...
1959 if (IfCondition) {
1960 // `SplitBlockAndInsertIfThenElse` requires the block to have a
1961 // terminator.
1962 splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
1963 Instruction *IfTerminator =
1964 Builder.GetInsertPoint()->getParent()->getTerminator();
1965 Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
1966 Builder.SetInsertPoint(IfTerminator);
1967 SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
1968 ElseTerm: &ElseTI);
1969 Builder.SetInsertPoint(ElseTI);
1970
1971 if (Dependencies.size()) {
1972 Function *TaskWaitFn =
1973 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
1974 Builder.CreateCall(
1975 Callee: TaskWaitFn,
1976 Args: {Ident, ThreadID, Builder.getInt32(C: Dependencies.size()), DepArray,
1977 ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
1978 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
1979 }
1980 Function *TaskBeginFn =
1981 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
1982 Function *TaskCompleteFn =
1983 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
1984 Builder.CreateCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
1985 CallInst *CI = nullptr;
1986 if (HasShareds)
1987 CI = Builder.CreateCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
1988 else
1989 CI = Builder.CreateCall(Callee: &OutlinedFn, Args: {ThreadID});
1990 CI->setDebugLoc(StaleCI->getDebugLoc());
1991 Builder.CreateCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
1992 Builder.SetInsertPoint(ThenTI);
1993 }
1994
1995 if (Dependencies.size()) {
1996 Function *TaskFn =
1997 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
1998 Builder.CreateCall(
1999 Callee: TaskFn,
2000 Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
2001 DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2002 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2003
2004 } else {
2005 // Emit the @__kmpc_omp_task runtime call to spawn the task
2006 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
2007 Builder.CreateCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
2008 }
2009
2010 StaleCI->eraseFromParent();
2011
2012 Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
2013 if (HasShareds) {
2014 LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2015 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2016 New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
2017 }
2018
2019 llvm::for_each(Range: llvm::reverse(C&: ToBeDeleted),
2020 F: [](Instruction *I) { I->eraseFromParent(); });
2021 };
2022
2023 addOutlineInfo(OI: std::move(OI));
2024 Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());
2025
2026 return Builder.saveIP();
2027}
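
// A minimal usage sketch of createTask, assuming a configured OpenMPIRBuilder
// `OMPB`, a location `Loc`, and an alloca insertion point `AllocaIP`:
//
// \code{cpp}
//   SmallVector<OpenMPIRBuilder::DependData> Deps; // optionally filled in
//   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
//     // Emit the task body at CodeGenIP.
//   };
//   Builder.restoreIP(OMPB.createTask(Loc, AllocaIP, BodyGenCB,
//                                     /*Tied=*/true, /*Final=*/nullptr,
//                                     /*IfCondition=*/nullptr, Deps));
// \endcode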
2028
2029OpenMPIRBuilder::InsertPointTy
2030OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2031 InsertPointTy AllocaIP,
2032 BodyGenCallbackTy BodyGenCB) {
2033 if (!updateToLocation(Loc))
2034 return InsertPointTy();
2035
2036 uint32_t SrcLocStrSize;
2037 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2038 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2039 Value *ThreadID = getOrCreateThreadID(Ident);
2040
2041 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2042 Function *TaskgroupFn =
2043 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2044 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2045
2046 BasicBlock *TaskgroupExitBB = splitBB(Builder, CreateBranch: true, Name: "taskgroup.exit");
2047 BodyGenCB(AllocaIP, Builder.saveIP());
2048
2049 Builder.SetInsertPoint(TaskgroupExitBB);
2050 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2051 Function *EndTaskgroupFn =
2052 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2053 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2054
2055 return Builder.saveIP();
2056}
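
// The emitted code is symmetric around the body; roughly:
//
// \code{c}
//   __kmpc_taskgroup(&loc, gtid);
//   // ... taskgroup body (BodyGenCB) ...
//   __kmpc_end_taskgroup(&loc, gtid);
// \endcode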
2057
2058OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
2059 const LocationDescription &Loc, InsertPointTy AllocaIP,
2060 ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2061 FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2062 assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2063
2064 if (!updateToLocation(Loc))
2065 return Loc.IP;
2066
2067 auto FiniCBWrapper = [&](InsertPointTy IP) {
2068 if (IP.getBlock()->end() != IP.getPoint())
2069 return FiniCB(IP);
    // This must be done otherwise any nested constructs using
    // FinalizeOMPRegion will fail because that function requires the
    // finalization basic block to have a terminator, which is already removed
    // by EmitOMPRegionBody.
    // IP is currently at the cancellation block. We need to backtrack to the
    // condition block to fetch the exit block and create a branch from the
    // cancellation block to the exit block.
2077 IRBuilder<>::InsertPointGuard IPG(Builder);
2078 Builder.restoreIP(IP);
2079 auto *CaseBB = IP.getBlock()->getSinglePredecessor();
2080 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2081 auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
2082 Instruction *I = Builder.CreateBr(Dest: ExitBB);
2083 IP = InsertPointTy(I->getParent(), I->getIterator());
2084 return FiniCB(IP);
2085 };
2086
2087 FinalizationStack.push_back(Elt: {.FiniCB: FiniCBWrapper, .DK: OMPD_sections, .IsCancellable: IsCancellable});
2088
  // Each section is emitted as a switch case.
  // Each finalization callback is handled from clang's
  // EmitOMPSectionDirective() -> OMP.createSection(), which generates the IR
  // for each section.
2092 // Iterate through all sections and emit a switch construct:
2093 // switch (IV) {
2094 // case 0:
2095 // <SectionStmt[0]>;
2096 // break;
2097 // ...
2098 // case <NumSection> - 1:
2099 // <SectionStmt[<NumSection> - 1]>;
2100 // break;
2101 // }
2102 // ...
2103 // section_loop.after:
2104 // <FiniCB>;
2105 auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
2106 Builder.restoreIP(IP: CodeGenIP);
2107 BasicBlock *Continue =
2108 splitBBWithSuffix(Builder, /*CreateBranch=*/false, Suffix: ".sections.after");
2109 Function *CurFn = Continue->getParent();
2110 SwitchInst *SwitchStmt = Builder.CreateSwitch(V: IndVar, Dest: Continue);
2111
2112 unsigned CaseNumber = 0;
2113 for (auto SectionCB : SectionCBs) {
2114 BasicBlock *CaseBB = BasicBlock::Create(
2115 Context&: M.getContext(), Name: "omp_section_loop.body.case", Parent: CurFn, InsertBefore: Continue);
2116 SwitchStmt->addCase(OnVal: Builder.getInt32(C: CaseNumber), Dest: CaseBB);
2117 Builder.SetInsertPoint(CaseBB);
2118 BranchInst *CaseEndBr = Builder.CreateBr(Dest: Continue);
2119 SectionCB(InsertPointTy(),
2120 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
2121 CaseNumber++;
2122 }
    // Remove the existing terminator from the body BB since there can be no
    // terminators after a switch/case.
2125 };
  // Loop body ends here.
  // LowerBound, UpperBound, and Stride for createCanonicalLoop.
2128 Type *I32Ty = Type::getInt32Ty(C&: M.getContext());
2129 Value *LB = ConstantInt::get(Ty: I32Ty, V: 0);
2130 Value *UB = ConstantInt::get(Ty: I32Ty, V: SectionCBs.size());
2131 Value *ST = ConstantInt::get(Ty: I32Ty, V: 1);
2132 llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
2133 Loc, BodyGenCB: LoopBodyGenCB, Start: LB, Stop: UB, Step: ST, IsSigned: true, InclusiveStop: false, ComputeIP: AllocaIP, Name: "section_loop");
2134 InsertPointTy AfterIP =
2135 applyStaticWorkshareLoop(DL: Loc.DL, CLI: LoopInfo, AllocaIP, NeedsBarrier: !IsNowait);
2136
2137 // Apply the finalization callback in LoopAfterBB
2138 auto FiniInfo = FinalizationStack.pop_back_val();
2139 assert(FiniInfo.DK == OMPD_sections &&
2140 "Unexpected finalization stack state!");
2141 if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2142 Builder.restoreIP(IP: AfterIP);
2143 BasicBlock *FiniBB =
2144 splitBBWithSuffix(Builder, /*CreateBranch=*/true, Suffix: "sections.fini");
2145 CB(Builder.saveIP());
2146 AfterIP = {FiniBB, FiniBB->begin()};
2147 }
2148
2149 return AfterIP;
2150}
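
// Putting the pieces together, the sections construct is lowered to a
// statically scheduled workshare loop over the section index; conceptually:
//
// \code{c}
//   for (int iv = lb; iv <= ub; ++iv) { // bounds assigned per thread by
//     switch (iv) {                     // the static workshare loop
//     case 0: /* SectionStmt[0] */ break;
//     /* ... */
//     }
//   }
//   // followed by an implicit barrier unless nowait was requested
// \endcode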
2151
2152OpenMPIRBuilder::InsertPointTy
2153OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2154 BodyGenCallbackTy BodyGenCB,
2155 FinalizeCallbackTy FiniCB) {
2156 if (!updateToLocation(Loc))
2157 return Loc.IP;
2158
2159 auto FiniCBWrapper = [&](InsertPointTy IP) {
2160 if (IP.getBlock()->end() != IP.getPoint())
2161 return FiniCB(IP);
    // This must be done otherwise any nested constructs using
    // FinalizeOMPRegion will fail because that function requires the
    // finalization basic block to have a terminator, which is already removed
    // by EmitOMPRegionBody.
    // IP is currently at the cancellation block. We need to backtrack to the
    // condition block to fetch the exit block and create a branch from the
    // cancellation block to the exit block.
2169 IRBuilder<>::InsertPointGuard IPG(Builder);
2170 Builder.restoreIP(IP);
2171 auto *CaseBB = Loc.IP.getBlock();
2172 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2173 auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
2174 Instruction *I = Builder.CreateBr(Dest: ExitBB);
2175 IP = InsertPointTy(I->getParent(), I->getIterator());
2176 return FiniCB(IP);
2177 };
2178
2179 Directive OMPD = Directive::OMPD_sections;
  // Since we are using the finalization callback here, HasFinalize
  // and IsCancellable have to be true.
2182 return EmitOMPInlinedRegion(OMPD, EntryCall: nullptr, ExitCall: nullptr, BodyGenCB, FiniCB: FiniCBWrapper,
2183 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true,
2184 /*IsCancellable*/ true);
2185}
2186
2187static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2188 BasicBlock::iterator IT(I);
2189 IT++;
2190 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2191}
2192
2193void OpenMPIRBuilder::emitUsed(StringRef Name,
2194 std::vector<WeakTrackingVH> &List) {
2195 if (List.empty())
2196 return;
2197
2198 // Convert List to what ConstantArray needs.
2199 SmallVector<Constant *, 8> UsedArray;
2200 UsedArray.resize(N: List.size());
2201 for (unsigned I = 0, E = List.size(); I != E; ++I)
2202 UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2203 C: cast<Constant>(Val: &*List[I]), Ty: Builder.getPtrTy());
2204
2205 if (UsedArray.empty())
2206 return;
2207 ArrayType *ATy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: UsedArray.size());
2208
2209 auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
2210 ConstantArray::get(T: ATy, V: UsedArray), Name);
2211
2212 GV->setSection("llvm.metadata");
2213}
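
// A usage sketch, assuming `GV0` and `GV1` are globals registered elsewhere:
//
// \code{cpp}
//   std::vector<WeakTrackingVH> List = {GV0, GV1};
//   OMPB.emitUsed("llvm.compiler.used", List);
//   // Produces: @llvm.compiler.used = appending global [2 x ptr] [...],
//   //           section "llvm.metadata"
// \endcode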
2214
2215Value *OpenMPIRBuilder::getGPUThreadID() {
2216 return Builder.CreateCall(
2217 Callee: getOrCreateRuntimeFunction(M,
2218 FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block),
2219 Args: {});
2220}
2221
2222Value *OpenMPIRBuilder::getGPUWarpSize() {
2223 return Builder.CreateCall(
2224 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___kmpc_get_warp_size), Args: {});
2225}
2226
2227Value *OpenMPIRBuilder::getNVPTXWarpID() {
2228 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2229 return Builder.CreateAShr(LHS: getGPUThreadID(), RHS: LaneIDBits, Name: "nvptx_warp_id");
2230}
2231
2232Value *OpenMPIRBuilder::getNVPTXLaneID() {
2233 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2234 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2235 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
2236 return Builder.CreateAnd(LHS: getGPUThreadID(), RHS: Builder.getInt32(C: LaneIDMask),
2237 Name: "nvptx_lane_id");
2238}
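
// Worked example for a warp size of 32: LaneIDBits == log2(32) == 5 and
// LaneIDMask == 0x1f, so for a hardware thread id `tid`:
//
// \code{c}
//   warp_id = tid >> 5;   // getNVPTXWarpID
//   lane_id = tid & 0x1f; // getNVPTXLaneID
// \endcode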
2239
2240Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2241 Type *ToType) {
2242 Type *FromType = From->getType();
2243 uint64_t FromSize = M.getDataLayout().getTypeStoreSize(Ty: FromType);
2244 uint64_t ToSize = M.getDataLayout().getTypeStoreSize(Ty: ToType);
2245 assert(FromSize > 0 && "From size must be greater than zero");
2246 assert(ToSize > 0 && "To size must be greater than zero");
2247 if (FromType == ToType)
2248 return From;
2249 if (FromSize == ToSize)
2250 return Builder.CreateBitCast(V: From, DestTy: ToType);
2251 if (ToType->isIntegerTy() && FromType->isIntegerTy())
2252 return Builder.CreateIntCast(V: From, DestTy: ToType, /*isSigned*/ true);
2253 InsertPointTy SaveIP = Builder.saveIP();
2254 Builder.restoreIP(IP: AllocaIP);
2255 Value *CastItem = Builder.CreateAlloca(Ty: ToType);
2256 Builder.restoreIP(IP: SaveIP);
2257
2258 Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2259 V: CastItem, DestTy: FromType->getPointerTo());
2260 Builder.CreateStore(Val: From, Ptr: ValCastItem);
2261 return Builder.CreateLoad(Ty: ToType, Ptr: CastItem);
2262}
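
// For mixed-kind casts (e.g. float -> i64) the value is spilled to an alloca
// of the destination type and reloaded, reinterpreting rather than converting
// the bits; conceptually:
//
// \code{c}
//   union { float as_from; int64_t as_to; } tmp;
//   tmp.as_from = from;
//   to = tmp.as_to; // bytes beyond sizeof(from) are unspecified
// \endcode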
2263
2264Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2265 Value *Element,
2266 Type *ElementType,
2267 Value *Offset) {
2268 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElementType);
2269 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2270
2271 // Cast all types to 32- or 64-bit values before calling shuffle routines.
2272 Type *CastTy = Builder.getIntNTy(N: Size <= 4 ? 32 : 64);
2273 Value *ElemCast = castValueToType(AllocaIP, From: Element, ToType: CastTy);
2274 Value *WarpSize =
2275 Builder.CreateIntCast(V: getGPUWarpSize(), DestTy: Builder.getInt16Ty(), isSigned: true);
2276 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2277 FnID: Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2278 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2279 Value *WarpSizeCast =
2280 Builder.CreateIntCast(V: WarpSize, DestTy: Builder.getInt16Ty(), /*isSigned=*/true);
2281 Value *ShuffleCall =
2282 Builder.CreateCall(Callee: ShuffleFunc, Args: {ElemCast, Offset, WarpSizeCast});
2283 return castValueToType(AllocaIP, From: ShuffleCall, ToType: CastTy);
2284}
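
// The emitted call for an element of at most 32 bits is roughly:
//
// \code{c}
//   res = __kmpc_shuffle_int32((int32_t)elem, offset, warp_size);
// \endcode
//
// __kmpc_shuffle_int64 is used for wider elements; callers truncate or store
// the widened result back into the original element type as needed.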
2285
2286void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
2287 Value *DstAddr, Type *ElemType,
2288 Value *Offset, Type *ReductionArrayTy) {
2289 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElemType);
  // Create the loop over the big-sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
2300 Type *IndexTy = Builder.getIndexTy(
2301 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
2302 Value *ElemPtr = DstAddr;
2303 Value *Ptr = SrcAddr;
2304 for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
2305 if (Size < IntSize)
2306 continue;
2307 Type *IntType = Builder.getIntNTy(N: IntSize * 8);
2308 Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2309 V: Ptr, DestTy: IntType->getPointerTo(), Name: Ptr->getName() + ".ascast");
2310 Value *SrcAddrGEP =
2311 Builder.CreateGEP(Ty: ElemType, Ptr: SrcAddr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2312 ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2313 V: ElemPtr, DestTy: IntType->getPointerTo(), Name: ElemPtr->getName() + ".ascast");
2314
2315 Function *CurFunc = Builder.GetInsertBlock()->getParent();
2316 if ((Size / IntSize) > 1) {
2317 Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
2318 V: SrcAddrGEP, DestTy: Builder.getPtrTy());
2319 BasicBlock *PreCondBB =
2320 BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.pre_cond");
2321 BasicBlock *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.then");
2322 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.exit");
2323 BasicBlock *CurrentBB = Builder.GetInsertBlock();
2324 emitBlock(BB: PreCondBB, CurFn: CurFunc);
2325 PHINode *PhiSrc =
2326 Builder.CreatePHI(Ty: Ptr->getType(), /*NumReservedValues=*/2);
2327 PhiSrc->addIncoming(V: Ptr, BB: CurrentBB);
2328 PHINode *PhiDest =
2329 Builder.CreatePHI(Ty: ElemPtr->getType(), /*NumReservedValues=*/2);
2330 PhiDest->addIncoming(V: ElemPtr, BB: CurrentBB);
2331 Ptr = PhiSrc;
2332 ElemPtr = PhiDest;
2333 Value *PtrDiff = Builder.CreatePtrDiff(
2334 ElemTy: Builder.getInt8Ty(), LHS: PtrEnd,
2335 RHS: Builder.CreatePointerBitCastOrAddrSpaceCast(V: Ptr, DestTy: Builder.getPtrTy()));
2336 Builder.CreateCondBr(
2337 Cond: Builder.CreateICmpSGT(LHS: PtrDiff, RHS: Builder.getInt64(C: IntSize - 1)), True: ThenBB,
2338 False: ExitBB);
2339 emitBlock(BB: ThenBB, CurFn: CurFunc);
2340 Value *Res = createRuntimeShuffleFunction(
2341 AllocaIP,
2342 Element: Builder.CreateAlignedLoad(
2343 Ty: IntType, Ptr, Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType)),
2344 ElementType: IntType, Offset);
2345 Builder.CreateAlignedStore(Val: Res, Ptr: ElemPtr,
2346 Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType));
2347 Value *LocalPtr =
2348 Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2349 Value *LocalElemPtr =
2350 Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2351 PhiSrc->addIncoming(V: LocalPtr, BB: ThenBB);
2352 PhiDest->addIncoming(V: LocalElemPtr, BB: ThenBB);
2353 emitBranch(Target: PreCondBB);
2354 emitBlock(BB: ExitBB, CurFn: CurFunc);
2355 } else {
2356 Value *Res = createRuntimeShuffleFunction(
2357 AllocaIP, Element: Builder.CreateLoad(Ty: IntType, Ptr), ElementType: IntType, Offset);
2358 if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
2359 Res->getType()->getScalarSizeInBits())
2360 Res = Builder.CreateTrunc(V: Res, DestTy: ElemType);
2361 Builder.CreateStore(Val: Res, Ptr: ElemPtr);
2362 Ptr = Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2363 ElemPtr =
2364 Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
2365 }
2366 Size = Size % IntSize;
2367 }
2368}
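
// Worked example: for a 12-byte element the loop above emits a single 8-byte
// shuffle (IntSize == 8, remaining Size == 4) followed by a single 4-byte
// shuffle (IntSize == 4, remaining Size == 0); conceptually:
//
// \code{c}
//   *(int64_t *)dst       = shuffle64(*(int64_t *)src, offset);
//   *(int32_t *)(dst + 8) = shuffle32(*(int32_t *)(src + 8), offset);
// \endcode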
2369
2370void OpenMPIRBuilder::emitReductionListCopy(
2371 InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
2372 ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
2373 CopyOptionsTy CopyOptions) {
2374 Type *IndexTy = Builder.getIndexTy(
2375 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
2376 Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
2377
  // Iterate, element by element, through the source Reduce list and
  // make a copy.
2380 for (auto En : enumerate(First&: ReductionInfos)) {
2381 const ReductionInfo &RI = En.value();
2382 Value *SrcElementAddr = nullptr;
2383 Value *DestElementAddr = nullptr;
2384 Value *DestElementPtrAddr = nullptr;
2385 // Should we shuffle in an element from a remote lane?
2386 bool ShuffleInElement = false;
2387 // Set to true to update the pointer in the dest Reduce list to a
2388 // newly created element.
2389 bool UpdateDestListPtr = false;
2390
2391 // Step 1.1: Get the address for the src element in the Reduce list.
2392 Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
2393 Ty: ReductionArrayTy, Ptr: SrcBase,
2394 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
2395 SrcElementAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrAddr);
2396
2397 // Step 1.2: Create a temporary to store the element in the destination
2398 // Reduce list.
2399 DestElementPtrAddr = Builder.CreateInBoundsGEP(
2400 Ty: ReductionArrayTy, Ptr: DestBase,
2401 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
2402 switch (Action) {
2403 case CopyAction::RemoteLaneToThread: {
2404 InsertPointTy CurIP = Builder.saveIP();
2405 Builder.restoreIP(IP: AllocaIP);
2406 AllocaInst *DestAlloca = Builder.CreateAlloca(Ty: RI.ElementType, ArraySize: nullptr,
2407 Name: ".omp.reduction.element");
2408 DestAlloca->setAlignment(
2409 M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType));
2410 DestElementAddr = DestAlloca;
2411 DestElementAddr =
2412 Builder.CreateAddrSpaceCast(V: DestElementAddr, DestTy: Builder.getPtrTy(),
2413 Name: DestElementAddr->getName() + ".ascast");
2414 Builder.restoreIP(IP: CurIP);
2415 ShuffleInElement = true;
2416 UpdateDestListPtr = true;
2417 break;
2418 }
2419 case CopyAction::ThreadCopy: {
2420 DestElementAddr =
2421 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DestElementPtrAddr);
2422 break;
2423 }
2424 }
2425
2426 // Now that all active lanes have read the element in the
2427 // Reduce list, shuffle over the value from the remote lane.
2428 if (ShuffleInElement) {
2429 shuffleAndStore(AllocaIP, SrcAddr: SrcElementAddr, DstAddr: DestElementAddr, ElemType: RI.ElementType,
2430 Offset: RemoteLaneOffset, ReductionArrayTy);
2431 } else {
2432 switch (RI.EvaluationKind) {
2433 case EvalKind::Scalar: {
2434 Value *Elem = Builder.CreateLoad(Ty: RI.ElementType, Ptr: SrcElementAddr);
2435 // Store the source element value to the dest element address.
2436 Builder.CreateStore(Val: Elem, Ptr: DestElementAddr);
2437 break;
2438 }
2439 case EvalKind::Complex: {
2440 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2441 Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
2442 Value *SrcReal = Builder.CreateLoad(
2443 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
2444 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2445 Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
2446 Value *SrcImg = Builder.CreateLoad(
2447 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
2448
2449 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2450 Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
2451 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2452 Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
2453 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
2454 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
2455 break;
2456 }
2457 case EvalKind::Aggregate: {
2458 Value *SizeVal = Builder.getInt64(
2459 C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
2460 Builder.CreateMemCpy(
2461 Dst: DestElementAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
2462 Src: SrcElementAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
2463 Size: SizeVal, isVolatile: false);
2464 break;
2465 }
2466 };
2467 }
2468
2469 // Step 3.1: Modify reference in dest Reduce list as needed.
2470 // Modifying the reference in Reduce list to point to the newly
2471 // created element. The element is live in the current function
2472 // scope and that of functions it invokes (i.e., reduce_function).
2473 // RemoteReduceData[i] = (void*)&RemoteElem
2474 if (UpdateDestListPtr) {
2475 Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2476 V: DestElementAddr, DestTy: Builder.getPtrTy(),
2477 Name: DestElementAddr->getName() + ".ascast");
2478 Builder.CreateStore(Val: CastDestAddr, Ptr: DestElementPtrAddr);
2479 }
2480 }
2481}
2482
2483Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
2484 const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
2485 AttributeList FuncAttrs) {
2486 InsertPointTy SavedIP = Builder.saveIP();
2487 LLVMContext &Ctx = M.getContext();
2488 FunctionType *FuncTy = FunctionType::get(
2489 Result: Builder.getVoidTy(), Params: {Builder.getPtrTy(), Builder.getInt32Ty()},
2490 /* IsVarArg */ isVarArg: false);
2491 Function *WcFunc =
2492 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
2493 N: "_omp_reduction_inter_warp_copy_func", M: &M);
2494 WcFunc->setAttributes(FuncAttrs);
2495 WcFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
2496 WcFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
2497 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: WcFunc);
2498 Builder.SetInsertPoint(EntryBB);
2499
2500 // ReduceList: thread local Reduce list.
2501 // At the stage of the computation when this function is called, partially
2502 // aggregated values reside in the first lane of every active warp.
2503 Argument *ReduceListArg = WcFunc->getArg(i: 0);
2504 // NumWarps: number of warps active in the parallel region. This could
2505 // be smaller than 32 (max warps in a CTA) for partial block reduction.
2506 Argument *NumWarpsArg = WcFunc->getArg(i: 1);
2507
2508 // This array is used as a medium to transfer, one reduce element at a time,
2509 // the data from the first lane of every warp to lanes in the first warp
2510 // in order to perform the final step of a reduction in a parallel region
2511 // (reduction across warps). The array is placed in NVPTX __shared__ memory
2512 // for reduced latency, as well as to have a distinct copy for concurrently
  // executing target regions. The array is declared with weak linkage so
  // as to be shared across compilation units.
2515 StringRef TransferMediumName =
2516 "__openmp_nvptx_data_transfer_temporary_storage";
2517 GlobalVariable *TransferMedium = M.getGlobalVariable(Name: TransferMediumName);
2518 unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
2519 ArrayType *ArrayTy = ArrayType::get(ElementType: Builder.getInt32Ty(), NumElements: WarpSize);
2520 if (!TransferMedium) {
2521 TransferMedium = new GlobalVariable(
2522 M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
2523 UndefValue::get(T: ArrayTy), TransferMediumName,
2524 /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
2525 /*AddressSpace=*/3);
2526 }
2527
2528 // Get the CUDA thread id of the current OpenMP thread on the GPU.
2529 Value *GPUThreadID = getGPUThreadID();
2530 // nvptx_lane_id = nvptx_id % warpsize
2531 Value *LaneID = getNVPTXLaneID();
2532 // nvptx_warp_id = nvptx_id / warpsize
2533 Value *WarpID = getNVPTXWarpID();
2534
2535 InsertPointTy AllocaIP =
2536 InsertPointTy(Builder.GetInsertBlock(),
2537 Builder.GetInsertBlock()->getFirstInsertionPt());
2538 Type *Arg0Type = ReduceListArg->getType();
2539 Type *Arg1Type = NumWarpsArg->getType();
2540 Builder.restoreIP(IP: AllocaIP);
2541 AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
2542 Ty: Arg0Type, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
2543 AllocaInst *NumWarpsAlloca =
2544 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: NumWarpsArg->getName() + ".addr");
2545 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2546 V: ReduceListAlloca, DestTy: Arg0Type, Name: ReduceListAlloca->getName() + ".ascast");
2547 Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2548 V: NumWarpsAlloca, DestTy: Arg1Type->getPointerTo(),
2549 Name: NumWarpsAlloca->getName() + ".ascast");
2550 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
2551 Builder.CreateStore(Val: NumWarpsArg, Ptr: NumWarpsAddrCast);
2552 AllocaIP = getInsertPointAfterInstr(I: NumWarpsAlloca);
2553 InsertPointTy CodeGenIP =
2554 getInsertPointAfterInstr(I: &Builder.GetInsertBlock()->back());
2555 Builder.restoreIP(IP: CodeGenIP);
2556
2557 Value *ReduceList =
2558 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListAddrCast);
2559
2560 for (auto En : enumerate(First&: ReductionInfos)) {
2561 //
2562 // Warp master copies reduce element to transfer medium in __shared__
2563 // memory.
2564 //
2565 const ReductionInfo &RI = En.value();
2566 unsigned RealTySize = M.getDataLayout().getTypeAllocSize(Ty: RI.ElementType);
2567 for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
2568 Type *CType = Builder.getIntNTy(N: TySize * 8);
2569
2570 unsigned NumIters = RealTySize / TySize;
2571 if (NumIters == 0)
2572 continue;
2573 Value *Cnt = nullptr;
2574 Value *CntAddr = nullptr;
2575 BasicBlock *PrecondBB = nullptr;
2576 BasicBlock *ExitBB = nullptr;
2577 if (NumIters > 1) {
2578 CodeGenIP = Builder.saveIP();
2579 Builder.restoreIP(IP: AllocaIP);
2580 CntAddr =
2581 Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr, Name: ".cnt.addr");
2582
2583 CntAddr = Builder.CreateAddrSpaceCast(V: CntAddr, DestTy: Builder.getPtrTy(),
2584 Name: CntAddr->getName() + ".ascast");
2585 Builder.restoreIP(IP: CodeGenIP);
2586 Builder.CreateStore(Val: Constant::getNullValue(Ty: Builder.getInt32Ty()),
2587 Ptr: CntAddr,
2588 /*Volatile=*/isVolatile: false);
2589 PrecondBB = BasicBlock::Create(Context&: Ctx, Name: "precond");
2590 ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit");
2591 BasicBlock *BodyBB = BasicBlock::Create(Context&: Ctx, Name: "body");
2592 emitBlock(BB: PrecondBB, CurFn: Builder.GetInsertBlock()->getParent());
2593 Cnt = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: CntAddr,
2594 /*Volatile=*/isVolatile: false);
2595 Value *Cmp = Builder.CreateICmpULT(
2596 LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), V: NumIters));
2597 Builder.CreateCondBr(Cond: Cmp, True: BodyBB, False: ExitBB);
2598 emitBlock(BB: BodyBB, CurFn: Builder.GetInsertBlock()->getParent());
2599 }
2600
2601 // kmpc_barrier.
2602 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
2603 Kind: omp::Directive::OMPD_unknown,
2604 /* ForceSimpleCall */ false,
2605 /* CheckCancelFlag */ true);
2606 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
2607 BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
2608 BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
2609
2610 // if (lane_id == 0)
2611 Value *IsWarpMaster = Builder.CreateIsNull(Arg: LaneID, Name: "warp_master");
2612 Builder.CreateCondBr(Cond: IsWarpMaster, True: ThenBB, False: ElseBB);
2613 emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
2614
2615 // Reduce element = LocalReduceList[i]
2616 auto *RedListArrayTy =
2617 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
2618 Type *IndexTy = Builder.getIndexTy(
2619 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
2620 Value *ElemPtrPtr =
2621 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
2622 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
2623 ConstantInt::get(Ty: IndexTy, V: En.index())});
2624 // elemptr = ((CopyType*)(elemptrptr)) + I
2625 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
2626 if (NumIters > 1)
2627 ElemPtr = Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: ElemPtr, IdxList: Cnt);
2628
2629 // Get pointer to location in transfer medium.
2630 // MediumPtr = &medium[warp_id]
2631 Value *MediumPtr = Builder.CreateInBoundsGEP(
2632 Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), WarpID});
2633 // elem = *elemptr
2634 //*MediumPtr = elem
2635 Value *Elem = Builder.CreateLoad(Ty: CType, Ptr: ElemPtr);
2636 // Store the source element value to the dest element address.
2637 Builder.CreateStore(Val: Elem, Ptr: MediumPtr,
2638 /*IsVolatile*/ isVolatile: true);
2639 Builder.CreateBr(Dest: MergeBB);
2640
2641 // else
2642 emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
2643 Builder.CreateBr(Dest: MergeBB);
2644
2645 // endif
2646 emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
2647 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
2648 Kind: omp::Directive::OMPD_unknown,
2649 /* ForceSimpleCall */ false,
2650 /* CheckCancelFlag */ true);
2651
2652 // Warp 0 copies reduce element from transfer medium
2653 BasicBlock *W0ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
2654 BasicBlock *W0ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
2655 BasicBlock *W0MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
2656
2657 Value *NumWarpsVal =
2658 Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: NumWarpsAddrCast);
2659 // Up to 32 threads in warp 0 are active.
2660 Value *IsActiveThread =
2661 Builder.CreateICmpULT(LHS: GPUThreadID, RHS: NumWarpsVal, Name: "is_active_thread");
2662 Builder.CreateCondBr(Cond: IsActiveThread, True: W0ThenBB, False: W0ElseBB);
2663
2664 emitBlock(BB: W0ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
2665
      // SrcMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
2668 Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
2669 Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), GPUThreadID});
2670 // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
2671 Value *TargetElemPtrPtr =
2672 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
2673 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
2674 ConstantInt::get(Ty: IndexTy, V: En.index())});
2675 Value *TargetElemPtrVal =
2676 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtrPtr);
2677 Value *TargetElemPtr = TargetElemPtrVal;
2678 if (NumIters > 1)
2679 TargetElemPtr =
2680 Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: TargetElemPtr, IdxList: Cnt);
2681
2682 // *TargetElemPtr = SrcMediumVal;
2683 Value *SrcMediumValue =
2684 Builder.CreateLoad(Ty: CType, Ptr: SrcMediumPtrVal, /*IsVolatile*/ isVolatile: true);
2685 Builder.CreateStore(Val: SrcMediumValue, Ptr: TargetElemPtr);
2686 Builder.CreateBr(Dest: W0MergeBB);
2687
2688 emitBlock(BB: W0ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
2689 Builder.CreateBr(Dest: W0MergeBB);
2690
2691 emitBlock(BB: W0MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
2692
2693 if (NumIters > 1) {
2694 Cnt = Builder.CreateNSWAdd(
2695 LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), /*V=*/1));
2696 Builder.CreateStore(Val: Cnt, Ptr: CntAddr, /*Volatile=*/isVolatile: false);
2697
2698 auto *CurFn = Builder.GetInsertBlock()->getParent();
2699 emitBranch(Target: PrecondBB);
2700 emitBlock(BB: ExitBB, CurFn);
2701 }
2702 RealTySize %= TySize;
2703 }
2704 }
2705
2706 Builder.CreateRetVoid();
2707 Builder.restoreIP(IP: SavedIP);
2708
2709 return WcFunc;
2710}
2711
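// Emit a helper that combines a thread's reduce list with the reduce list of
// a remote lane in the same warp (fetched via warp-level shuffles inside
// emitReductionListCopy). LaneID, RemoteLaneOffset and AlgoVer select which
// lanes reduce and which merely copy; see the predicate construction below.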
Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    AttributeList FuncAttrs) {
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy =
      FunctionType::get(Builder.getVoidTy(),
                        {Builder.getPtrTy(), Builder.getInt16Ty(),
                         Builder.getInt16Ty(), Builder.getInt16Ty()},
                        /* IsVarArg */ false);
  Function *SarFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_shuffle_and_reduce_func", &M);
  SarFunc->setAttributes(FuncAttrs);
  SarFunc->addParamAttr(0, Attribute::NoUndef);
  SarFunc->addParamAttr(1, Attribute::NoUndef);
  SarFunc->addParamAttr(2, Attribute::NoUndef);
  SarFunc->addParamAttr(3, Attribute::NoUndef);
  SarFunc->addParamAttr(1, Attribute::SExt);
  SarFunc->addParamAttr(2, Attribute::SExt);
  SarFunc->addParamAttr(3, Attribute::SExt);
  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
  Builder.SetInsertPoint(EntryBB);

  // Thread local Reduce list used to host the values of data to be reduced.
  Argument *ReduceListArg = SarFunc->getArg(0);
  // Current lane id; could be logical.
  Argument *LaneIDArg = SarFunc->getArg(1);
  // Offset of the remote source lane relative to the current lane.
  Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
  // Algorithm version. This is expected to be known at compile time.
  Argument *AlgoVerArg = SarFunc->getArg(3);

  Type *ReduceListArgType = ReduceListArg->getType();
  Type *LaneIDArgType = LaneIDArg->getType();
  Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
  Value *ReduceListAlloca = Builder.CreateAlloca(
      ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
  Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
                                             LaneIDArg->getName() + ".addr");
  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
      LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
  Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
                                              AlgoVerArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // Create a local thread-private variable to host the Reduce list
  // from a remote lane.
  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
      RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");

  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListAlloca, ReduceListArgType,
      ReduceListAlloca->getName() + ".ascast");
  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
  Value *RemoteLaneOffsetAddrCast =
      Builder.CreatePointerBitCastOrAddrSpaceCast(
          RemoteLaneOffsetAlloca, LaneIDArgPtrType,
          RemoteLaneOffsetAlloca->getName() + ".ascast");
  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteReductionListAlloca, Builder.getPtrTy(),
      RemoteReductionListAlloca->getName() + ".ascast");

  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
  Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
  Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
  Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);

  Value *ReduceList =
      Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
  Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
  Value *RemoteLaneOffset =
      Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
  Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);

  InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);

  // This loop iterates through the list of reduce elements and copies,
  // element by element, from a remote lane in the warp to RemoteReduceList,
  // hosted on the thread's stack.
  emitReductionListCopy(
      AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
      ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});

  // The actions to be performed on the Remote Reduce list depend on the
  // algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread local variable Reduce list is mutated in place to host the
  // reduced data, which is the aggregated value produced from local and
  // remote lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first conjunction evaluates to true, making
  // the entire predicate true at compile time.
  // When AlgoVer==1, only the second part of the second conjunction needs
  // to be evaluated at runtime; the other conjunctions fold to false at
  // compile time.
  // When AlgoVer==2, only the second part of the third conjunction needs
  // to be evaluated at runtime; the other conjunctions fold to false at
  // compile time.
  Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
  Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
  Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
  Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
  Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
  Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
  Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
  Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
  Value *RemoteOffsetComp =
      Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
  Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
  Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
  Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);

  BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
  BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
  BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");

  Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
  emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceList, Builder.getPtrTy());
  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteListAddrCast, Builder.getPtrTy());
  Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateBr(MergeBB);

  emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(MergeBB);

  emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());

  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
  // Reduce list.
  Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
  Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);

  BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
  BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
  BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
  Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);

  emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
  emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
                        ReductionInfos, RemoteListAddrCast, ReduceList);
  Builder.CreateBr(CpyMergeBB);

  emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(CpyMergeBB);

  emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());

  Builder.CreateRetVoid();

  return SarFunc;
}

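// Emit a helper that copies a thread's reduce list into slot Idx of the
// global reduction buffer, element by element, dispatching on the evaluation
// kind (scalar, complex, or aggregate) of each reduction variable.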
Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGCFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_list_to_global_copy_func", &M);
  LtGCFunc->setAttributes(FuncAttrs);
  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
  LtGCFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGCFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGCFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGCFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Value *BufferArgVal =
      Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);

    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
    Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
      Builder.CreateStore(TargetElement, GlobVal);
      break;
    }
    case EvalKind::Complex: {
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 0, ".realp");
      Value *SrcReal = Builder.CreateLoad(
          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 1, ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobVal, 0, 0, ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobVal, 0, 1, ".imagp");
      Builder.CreateStore(SrcReal, DestRealPtr);
      Builder.CreateStore(SrcImg, DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      Value *SizeVal = Builder.getInt64(
          M.getDataLayout().getTypeStoreSize(RI.ElementType));
      Builder.CreateMemCpy(
          GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
          M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGCFunc;
}

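// Emit a helper that builds a reduce list whose elements point directly into
// slot Idx of the global reduction buffer and then combines it with the
// thread's reduce list by calling ReduceFn(GlobalReduceList, ReduceList).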
Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGRFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_list_to_global_reduce_func", &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
  LtGRFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  auto *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LocalReduceList, Builder.getPtrTy(),
      LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceListAddrCast,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    // Global = Buffer.VD[Idx];
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());
    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
  }

  // Call reduce_function(GlobalReduceList, ReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGRFunc;
}

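// Emit a helper that copies slot Idx of the global reduction buffer back into
// a thread's reduce list; the inverse of the list-to-global copy above.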
Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGCFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_global_to_list_copy_func", &M);
  LtGCFunc->setAttributes(FuncAttrs);
  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
  LtGCFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGCFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGCFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGCFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
      Builder.CreateStore(TargetElement, ElemPtr);
      break;
    }
    case EvalKind::Complex: {
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobValPtr, 0, 0, ".realp");
      Value *SrcReal = Builder.CreateLoad(
          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobValPtr, 0, 1, ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 0, ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 1, ".imagp");
      Builder.CreateStore(SrcReal, DestRealPtr);
      Builder.CreateStore(SrcImg, DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      Value *SizeVal = Builder.getInt64(
          M.getDataLayout().getTypeStoreSize(RI.ElementType));
      Builder.CreateMemCpy(
          ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
          GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
          SizeVal, false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGCFunc;
}

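// Emit a helper that combines slot Idx of the global reduction buffer into a
// thread's reduce list by calling ReduceFn(ReduceList, GlobalReduceList).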
Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  auto *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGRFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_global_to_list_reduce_func", &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
  LtGRFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LocalReduceList, Builder.getPtrTy(),
      LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, ReductionList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());
    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
  }

  // Call reduce_function(ReduceList, GlobalReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGRFunc;
}

std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
  std::string Suffix =
      createPlatformSpecificName({"omp", "reduction", "reduction_func"});
  return (Name + Suffix).str();
}

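// Create the outlined reduction function that combines two type-erased arrays
// of reduction element pointers. For ReductionGenCBKind::Clang the
// per-element combiner is emitted by a callback and patched up afterwards;
// otherwise the elements are loaded, combined via RI.ReductionGen, and stored
// back to the LHS.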
Function *OpenMPIRBuilder::createReductionFunction(
    StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
    ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
  auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
                                   {Builder.getPtrTy(), Builder.getPtrTy()},
                                   /* IsVarArg */ false);
  std::string Name = getReductionFuncName(ReducerName);
  Function *ReductionFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
  ReductionFunc->setAttributes(FuncAttrs);
  ReductionFunc->addParamAttr(0, Attribute::NoUndef);
  ReductionFunc->addParamAttr(1, Attribute::NoUndef);
  BasicBlock *EntryBB =
      BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
  Builder.SetInsertPoint(EntryBB);

  // Need to alloca memory here and deal with the pointers before getting
  // LHS/RHS pointers out.
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  Argument *Arg0 = ReductionFunc->getArg(0);
  Argument *Arg1 = ReductionFunc->getArg(1);
  Type *Arg0Type = Arg0->getType();
  Type *Arg1Type = Arg1->getType();

  Value *LHSAlloca =
      Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
  Value *RHSAlloca =
      Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
  Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
  Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
  Builder.CreateStore(Arg0, LHSAddrCast);
  Builder.CreateStore(Arg1, RHSAddrCast);
  LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
  RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);

  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  SmallVector<Value *> LHSPtrs, RHSPtrs;
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
        RedArrayTy, RHSArrayPtr,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        RHSI8Ptr, RI.PrivateVariable->getType(),
        RHSI8Ptr->getName() + ".ascast");

    Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
        RedArrayTy, LHSArrayPtr,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      LHSPtrs.emplace_back(LHSPtr);
      RHSPtrs.emplace_back(RHSPtr);
    } else {
      Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
      Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
      Value *Reduced;
      RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
      if (!Builder.GetInsertBlock())
        return ReductionFunc;
      Builder.CreateStore(Reduced, LHSPtr);
    }
  }

  if (ReductionGenCBKind == ReductionGenCBKind::Clang)
    for (auto En : enumerate(ReductionInfos)) {
      unsigned Index = En.index();
      const ReductionInfo &RI = En.value();
      Value *LHSFixupPtr, *RHSFixupPtr;
      Builder.restoreIP(RI.ReductionGenClang(
          Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));

      // Fix the callback code generated to use the correct Values for the
      // LHS and RHS.
      LHSFixupPtr->replaceUsesWithIf(
          LHSPtrs[Index], [ReductionFunc](const Use &U) {
            return cast<Instruction>(U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
      RHSFixupPtr->replaceUsesWithIf(
          RHSPtrs[Index], [ReductionFunc](const Use &U) {
            return cast<Instruction>(U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
    }

  Builder.CreateRetVoid();
  return ReductionFunc;
}

static void
checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
                    bool IsGPU) {
  for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert((RI.ReductionGen || RI.ReductionGenClang) &&
           "expected non-null reduction generator callback");
    if (!IsGPU) {
      assert(
          RI.Variable->getType() == RI.PrivateVariable->getType() &&
          "expected variables and their private equivalents to have the same "
          "type");
    }
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
    bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
    ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
    unsigned ReductionBufNum, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();
  Builder.restoreIP(CodeGenIP);
  checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
  LLVMContext &Ctx = M.getContext();

  // Source location for the ident struct
  if (!SrcLocInfo) {
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
    SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  }

  if (ReductionInfos.empty())
    return Builder.saveIP();

  Function *CurFunc = Builder.GetInsertBlock()->getParent();
  AttributeList FuncAttrs;
  AttrBuilder AttrBldr(Ctx);
  for (auto Attr : CurFunc->getAttributes().getFnAttrs())
    AttrBldr.addAttribute(Attr);
  AttrBldr.removeAttribute(Attribute::OptimizeNone);
  FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);

  Function *ReductionFunc = nullptr;
  CodeGenIP = Builder.saveIP();
  ReductionFunc =
      createReductionFunction(Builder.GetInsertBlock()->getParent()->getName(),
                              ReductionInfos, ReductionGenCBKind, FuncAttrs);
  Builder.restoreIP(CodeGenIP);

  // Set the grid value in the config needed for lowering later on
  if (GridValue.has_value())
    Config.setGridValue(GridValue.value());
  else
    Config.setGridValue(getGridValue(T, ReductionFunc));

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  Value *RTLoc =
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, omp::IdentFlag(0), 0);

  // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
  // RedList, shuffle_reduce_func, interwarp_copy_func);
  // or
  // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
  Value *Res;

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  auto Size = ReductionInfos.size();
  Type *PtrTy = PointerType::getUnqual(Ctx);
  Type *RedArrayTy = ArrayType::get(PtrTy, Size);
  CodeGenIP = Builder.saveIP();
  Builder.restoreIP(AllocaIP);
  Value *ReductionListAlloca =
      Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
  Builder.restoreIP(CodeGenIP);
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ElemPtr = Builder.CreateInBoundsGEP(
        RedArrayTy, ReductionList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *CastElem =
        Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
    Builder.CreateStore(CastElem, ElemPtr);
  }
  CodeGenIP = Builder.saveIP();
  Function *SarFunc =
      emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
  Function *WcFunc = emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
  Builder.restoreIP(CodeGenIP);

  Value *RL =
      Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);

  unsigned MaxDataSize = 0;
  SmallVector<Type *> ReductionTypeArgs;
  for (auto En : enumerate(ReductionInfos)) {
    auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
    if (Size > MaxDataSize)
      MaxDataSize = Size;
    ReductionTypeArgs.emplace_back(En.value().ElementType);
  }
  Value *ReductionDataSize =
      Builder.getInt64(MaxDataSize * ReductionInfos.size());
  if (!IsTeamsReduction) {
    Value *SarFuncCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
    Value *WcFuncCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
    Value *Args[] = {RTLoc, ReductionDataSize, RL, SarFuncCast, WcFuncCast};
    Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
    Res = Builder.CreateCall(Pv2Ptr, Args);
  } else {
    CodeGenIP = Builder.saveIP();
    StructType *ReductionsBufferTy = StructType::create(
        Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
    Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
    Function *LtGCFunc = emitListToGlobalCopyFunction(
        ReductionInfos, ReductionsBufferTy, FuncAttrs);
    Function *LtGRFunc = emitListToGlobalReduceFunction(
        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
    Function *GtLCFunc = emitGlobalToListCopyFunction(
        ReductionInfos, ReductionsBufferTy, FuncAttrs);
    Function *GtLRFunc = emitGlobalToListReduceFunction(
        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
    Builder.restoreIP(CodeGenIP);

    Value *KernelTeamsReductionPtr = Builder.CreateCall(
        RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");

    Value *Args3[] = {RTLoc,
                      KernelTeamsReductionPtr,
                      Builder.getInt32(ReductionBufNum),
                      ReductionDataSize,
                      RL,
                      SarFunc,
                      WcFunc,
                      LtGCFunc,
                      LtGRFunc,
                      GtLCFunc,
                      GtLRFunc};

    Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
    Res = Builder.CreateCall(TeamsReduceFn, Args3);
  }

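  // Both runtime entry points are expected to return 1 only in the thread
  // that ends up holding the fully reduced values, which gates the final
  // copy-back emitted below.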
  // 5. Build if (res == 1)
  BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
  BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
  Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
  Builder.CreateCondBr(Cond, ThenBB, ExitBB);

  // 6. Build then branch: where we have reduced values in the master
  //    thread in each team.
  //    __kmpc_end_reduce{_nowait}(<gtid>);
  //    break;
  emitBlock(ThenBB, CurFunc);

  // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHS = RI.Variable;
    Value *RHS =
        Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      Value *LHSPtr, *RHSPtr;
      Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
                                             &LHSPtr, &RHSPtr, CurFunc));

      // Fix the callback code generated to use the correct Values for the
      // LHS and RHS.
      LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
        return cast<Instruction>(U.getUser())->getParent()->getParent() ==
               ReductionFunc;
      });
      RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
        return cast<Instruction>(U.getUser())->getParent()->getParent() ==
               ReductionFunc;
      });
    } else {
      assert(false && "Unhandled ReductionGenCBKind");
    }
  }
  emitBlock(ExitBB, CurFunc);

  Config.setEmitLLVMUsed();

  return Builder.saveIP();
}

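// Declare a fresh, internal-linkage function with the void(void*, void*)
// signature that __kmpc_reduce expects from a reduction combiner; its body is
// populated later by createReductions.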
static Function *getFreshReductionFunc(Module &M) {
  Type *VoidTy = Type::getVoidTy(M.getContext());
  Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
  auto *FuncTy =
      FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
  return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                          ".omp.reduction.func", &M);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
                                  InsertPointTy AllocaIP,
                                  ArrayRef<ReductionInfo> ReductionInfos,
                                  ArrayRef<bool> IsByRef, bool IsNoWait) {
  assert(ReductionInfos.size() == IsByRef.size());
  for (const ReductionInfo &RI : ReductionInfos) {
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert(RI.ReductionGen && "expected non-null reduction generator callback");
    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
           "expected variables and their private equivalents to have the same "
           "type");
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }

  if (!updateToLocation(Loc))
    return InsertPointTy();

  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
  Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
  Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");

  Builder.SetInsertPoint(InsertBlock, InsertBlock->end());

  for (auto En : enumerate(ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
    Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  bool CanGenerateAtomic = all_of(ReductionInfos, [](const ReductionInfo &RI) {
    return RI.AtomicReductionGen;
  });
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                  CanGenerateAtomic
                                      ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                      : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
  Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(*Module);
  Value *Lock = getOMPCriticalRegionLock(".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      Builder.CreateCall(ReduceFunc,
                         {Ident, ThreadId, NumVariables, RedArraySize,
                          RedArray, ReductionFunc, Lock},
                         "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
  SwitchInst *Switch =
      Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
  Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
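  // By convention, __kmpc_reduce{_nowait} returns 1 when the calling thread
  // should perform the reduction non-atomically, 2 when it should use the
  // atomic variant, and another value (typically 0) when it has nothing left
  // to do; unmatched results fall through to the continuation block.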

  // Populate the non-atomic reduction using the elementwise reduction
  // function. This loads the elements from the global and private variables
  // and reduces them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    // We have one less load for the by-ref case because that load is now
    // inside the reduction region.
    Value *RedValue = nullptr;
    if (!IsByRef[En.index()]) {
      RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                    "red.value." + Twine(En.index()));
    }
    Value *PrivateRedValue =
        Builder.CreateLoad(ValueType, RI.PrivateVariable,
                           "red.private.value." + Twine(En.index()));
    Value *Reduced;
    if (IsByRef[En.index()]) {
      Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RI.Variable,
                                        PrivateRedValue, Reduced));
    } else {
      Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RedValue,
                                        PrivateRedValue, Reduced));
    }
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // For the by-ref case, the store is inside the reduction region.
    if (!IsByRef[En.index()])
      Builder.CreateStore(Reduced, RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
  Builder.CreateBr(ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
    for (const ReductionInfo &RI : ReductionInfos) {
      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
                                              RI.Variable,
                                              RI.PrivateVariable));
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(ContinuationBlock);
  } else {
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = ReductionFunc->getArg(0);
  Value *RHSArrayPtr = ReductionFunc->getArg(1);

  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, LHSArrayPtr, 0, En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RHSArrayPtr, 0, En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
    Value *RHSPtr =
        Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
    Value *Reduced;
    Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // The store is inside the reduction region when using by-ref.
    if (!IsByRef[En.index()])
      Builder.CreateStore(Reduced, LHSPtr);
  }
  Builder.CreateRetVoid();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}

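// 'master' and 'masked' below are both emitted as inlined regions: the entry
// runtime call decides at runtime whether the executing thread runs the body
// (Conditional=true in EmitOMPInlinedRegion), and the exit call closes the
// region for the threads that did.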
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_master;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ true, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB, Value *Filter) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_masked;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, Filter};
  Value *ArgsEnd[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ true, /*hasFinalize*/ true);
}

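// Block layout produced for a canonical loop (edges flow downwards, with the
// latch branching back to the header):
//
//   preheader -> header -> cond -> body -> latch -> (header)
//                            \-> exit -> after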
3803CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
3804 DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
3805 BasicBlock *PostInsertBefore, const Twine &Name) {
3806 Module *M = F->getParent();
3807 LLVMContext &Ctx = M->getContext();
3808 Type *IndVarTy = TripCount->getType();
3809
3810 // Create the basic block structure.
3811 BasicBlock *Preheader =
3812 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".preheader", Parent: F, InsertBefore: PreInsertBefore);
3813 BasicBlock *Header =
3814 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".header", Parent: F, InsertBefore: PreInsertBefore);
3815 BasicBlock *Cond =
3816 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".cond", Parent: F, InsertBefore: PreInsertBefore);
3817 BasicBlock *Body =
3818 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".body", Parent: F, InsertBefore: PreInsertBefore);
3819 BasicBlock *Latch =
3820 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".inc", Parent: F, InsertBefore: PostInsertBefore);
3821 BasicBlock *Exit =
3822 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".exit", Parent: F, InsertBefore: PostInsertBefore);
3823 BasicBlock *After =
3824 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".after", Parent: F, InsertBefore: PostInsertBefore);
3825
3826 // Use specified DebugLoc for new instructions.
3827 Builder.SetCurrentDebugLocation(DL);
3828
3829 Builder.SetInsertPoint(Preheader);
3830 Builder.CreateBr(Dest: Header);
3831
3832 Builder.SetInsertPoint(Header);
3833 PHINode *IndVarPHI = Builder.CreatePHI(Ty: IndVarTy, NumReservedValues: 2, Name: "omp_" + Name + ".iv");
3834 IndVarPHI->addIncoming(V: ConstantInt::get(Ty: IndVarTy, V: 0), BB: Preheader);
3835 Builder.CreateBr(Dest: Cond);
3836
3837 Builder.SetInsertPoint(Cond);
3838 Value *Cmp =
3839 Builder.CreateICmpULT(LHS: IndVarPHI, RHS: TripCount, Name: "omp_" + Name + ".cmp");
3840 Builder.CreateCondBr(Cond: Cmp, True: Body, False: Exit);
3841
3842 Builder.SetInsertPoint(Body);
3843 Builder.CreateBr(Dest: Latch);
3844
3845 Builder.SetInsertPoint(Latch);
3846 Value *Next = Builder.CreateAdd(LHS: IndVarPHI, RHS: ConstantInt::get(Ty: IndVarTy, V: 1),
3847 Name: "omp_" + Name + ".next", /*HasNUW=*/true);
3848 Builder.CreateBr(Dest: Header);
3849 IndVarPHI->addIncoming(V: Next, BB: Latch);
3850
3851 Builder.SetInsertPoint(Exit);
3852 Builder.CreateBr(Dest: After);
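  // At this point the skeleton's control flow looks as follows (a sketch; the
  // condition either continues into the body or leaves through the exit):
  //
  //   Preheader -> Header -> Cond -> Body -> Latch -> Header (backedge)
  //                            \-> Exit -> After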
3853
3854 // Remember and return the canonical control flow.
3855 LoopInfos.emplace_front();
3856 CanonicalLoopInfo *CL = &LoopInfos.front();
3857
3858 CL->Header = Header;
3859 CL->Cond = Cond;
3860 CL->Latch = Latch;
3861 CL->Exit = Exit;
3862
3863#ifndef NDEBUG
3864 CL->assertOK();
3865#endif
3866 return CL;
3867}
3868
3869CanonicalLoopInfo *
3870OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
3871 LoopBodyGenCallbackTy BodyGenCB,
3872 Value *TripCount, const Twine &Name) {
3873 BasicBlock *BB = Loc.IP.getBlock();
3874 BasicBlock *NextBB = BB->getNextNode();
3875
3876 CanonicalLoopInfo *CL = createLoopSkeleton(DL: Loc.DL, TripCount, F: BB->getParent(),
3877 PreInsertBefore: NextBB, PostInsertBefore: NextBB, Name);
3878 BasicBlock *After = CL->getAfter();
3879
3880 // If location is not set, don't connect the loop.
3881 if (updateToLocation(Loc)) {
3882 // Split the loop at the insertion point: Branch to the preheader and move
3883 // every following instruction to after the loop (the After BB). Also, the
3884 // new successor is the loop's after block.
3885 spliceBB(Builder, New: After, /*CreateBranch=*/false);
3886 Builder.CreateBr(Dest: CL->getPreheader());
3887 }
3888
  // Emit the body content. We do it after connecting the loop to the CFG so
  // that the callback does not encounter degenerate BBs.
3891 BodyGenCB(CL->getBodyIP(), CL->getIndVar());
3892
3893#ifndef NDEBUG
3894 CL->assertOK();
3895#endif
3896 return CL;
3897}
3898
3899CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
3900 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
3901 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
3902 InsertPointTy ComputeIP, const Twine &Name) {
3903
3904 // Consider the following difficulties (assuming 8-bit signed integers):
3905 // * Adding \p Step to the loop counter which passes \p Stop may overflow:
3906 // DO I = 1, 100, 50
  // * A \p Step of INT_MIN cannot be normalized to a positive direction:
3908 // DO I = 100, 0, -128
3909
3910 // Start, Stop and Step must be of the same integer type.
3911 auto *IndVarTy = cast<IntegerType>(Val: Start->getType());
3912 assert(IndVarTy == Stop->getType() && "Stop type mismatch");
3913 assert(IndVarTy == Step->getType() && "Step type mismatch");
3914
3915 LocationDescription ComputeLoc =
3916 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
3917 updateToLocation(Loc: ComputeLoc);
3918
3919 ConstantInt *Zero = ConstantInt::get(Ty: IndVarTy, V: 0);
3920 ConstantInt *One = ConstantInt::get(Ty: IndVarTy, V: 1);
3921
3922 // Like Step, but always positive.
3923 Value *Incr = Step;
3924
3925 // Distance between Start and Stop; always positive.
3926 Value *Span;
3927
  // Condition checking whether no iterations are executed at all, e.g.
  // because UB < LB.
3930 Value *ZeroCmp;
3931
3932 if (IsSigned) {
    // Ensure that the increment is positive. If not, negate it and swap LB
    // and UB.
3934 Value *IsNeg = Builder.CreateICmpSLT(LHS: Step, RHS: Zero);
3935 Incr = Builder.CreateSelect(C: IsNeg, True: Builder.CreateNeg(V: Step), False: Step);
3936 Value *LB = Builder.CreateSelect(C: IsNeg, True: Stop, False: Start);
3937 Value *UB = Builder.CreateSelect(C: IsNeg, True: Start, False: Stop);
3938 Span = Builder.CreateSub(LHS: UB, RHS: LB, Name: "", HasNUW: false, HasNSW: true);
3939 ZeroCmp = Builder.CreateICmp(
3940 P: InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, LHS: UB, RHS: LB);
3941 } else {
3942 Span = Builder.CreateSub(LHS: Stop, RHS: Start, Name: "", HasNUW: true);
3943 ZeroCmp = Builder.CreateICmp(
3944 P: InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, LHS: Stop, RHS: Start);
3945 }
3946
3947 Value *CountIfLooping;
3948 if (InclusiveStop) {
3949 CountIfLooping = Builder.CreateAdd(LHS: Builder.CreateUDiv(LHS: Span, RHS: Incr), RHS: One);
3950 } else {
3951 // Avoid incrementing past stop since it could overflow.
3952 Value *CountIfTwo = Builder.CreateAdd(
3953 LHS: Builder.CreateUDiv(LHS: Builder.CreateSub(LHS: Span, RHS: One), RHS: Incr), RHS: One);
3954 Value *OneCmp = Builder.CreateICmp(P: CmpInst::ICMP_ULE, LHS: Span, RHS: Incr);
3955 CountIfLooping = Builder.CreateSelect(C: OneCmp, True: One, False: CountIfTwo);
3956 }
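  // For illustration (assuming Start=0, Stop=10, Step=3, exclusive stop):
  // Span = 10 and Incr = 3, and since Span > Incr the trip count becomes
  // (10 - 1) / 3 + 1 = 4, covering the iterations 0, 3, 6 and 9.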
3957 Value *TripCount = Builder.CreateSelect(C: ZeroCmp, True: Zero, False: CountIfLooping,
3958 Name: "omp_" + Name + ".tripcount");
3959
3960 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
3961 Builder.restoreIP(IP: CodeGenIP);
3962 Value *Span = Builder.CreateMul(LHS: IV, RHS: Step);
3963 Value *IndVar = Builder.CreateAdd(LHS: Span, RHS: Start);
3964 BodyGenCB(Builder.saveIP(), IndVar);
3965 };
3966 LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
3967 return createCanonicalLoop(Loc: LoopLoc, BodyGenCB: BodyGen, TripCount, Name);
3968}
3969
3970// Returns an LLVM function to call for initializing loop bounds using OpenMP
3971// static scheduling depending on `type`. Only i32 and i64 are supported by the
3972// runtime. Always interpret integers as unsigned similarly to
3973// CanonicalLoopInfo.
3974static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
3975 OpenMPIRBuilder &OMPBuilder) {
3976 unsigned Bitwidth = Ty->getIntegerBitWidth();
3977 if (Bitwidth == 32)
3978 return OMPBuilder.getOrCreateRuntimeFunction(
3979 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
3980 if (Bitwidth == 64)
3981 return OMPBuilder.getOrCreateRuntimeFunction(
3982 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
3983 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
3984}
3985
3986OpenMPIRBuilder::InsertPointTy
3987OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
3988 InsertPointTy AllocaIP,
3989 bool NeedsBarrier) {
3990 assert(CLI->isValid() && "Requires a valid canonical loop");
3991 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
3992 "Require dedicated allocate IP");
3993
3994 // Set up the source location value for OpenMP runtime.
3995 Builder.restoreIP(IP: CLI->getPreheaderIP());
3996 Builder.SetCurrentDebugLocation(DL);
3997
3998 uint32_t SrcLocStrSize;
3999 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4000 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4001
4002 // Declare useful OpenMP runtime functions.
4003 Value *IV = CLI->getIndVar();
4004 Type *IVTy = IV->getType();
4005 FunctionCallee StaticInit = getKmpcForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this);
4006 FunctionCallee StaticFini =
4007 getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);
4008
4009 // Allocate space for computed loop bounds as expected by the "init" function.
4010 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4011
4012 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
4013 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
4014 Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
4015 Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
4016 Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
4017
4018 // At the end of the preheader, prepare for calling the "init" function by
4019 // storing the current loop bounds into the allocated space. A canonical loop
4020 // always iterates from 0 to trip-count with step 1. Note that "init" expects
4021 // and produces an inclusive upper bound.
4022 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4023 Constant *Zero = ConstantInt::get(Ty: IVTy, V: 0);
4024 Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
4025 Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
4026 Value *UpperBound = Builder.CreateSub(LHS: CLI->getTripCount(), RHS: One);
4027 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
4028 Builder.CreateStore(Val: One, Ptr: PStride);
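  // For illustration, with a trip count of 8 the runtime is handed the
  // inclusive bounds [0, 7]; __kmpc_for_static_init may then shrink them to
  // this thread's share, e.g. [4, 7] for the second of two threads.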
4029
4030 Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
4031
4032 Constant *SchedulingType = ConstantInt::get(
4033 Ty: I32Type, V: static_cast<int>(OMPScheduleType::UnorderedStatic));
4034
4035 // Call the "init" function and update the trip count of the loop with the
4036 // value it produced.
4037 Builder.CreateCall(Callee: StaticInit,
4038 Args: {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
4039 PUpperBound, PStride, One, Zero});
4040 Value *LowerBound = Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound);
4041 Value *InclusiveUpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound);
4042 Value *TripCountMinusOne = Builder.CreateSub(LHS: InclusiveUpperBound, RHS: LowerBound);
4043 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One);
4044 CLI->setTripCount(TripCount);
4045
4046 // Update all uses of the induction variable except the one in the condition
4047 // block that compares it with the actual upper bound, and the increment in
4048 // the latch block.
4049
4050 CLI->mapIndVar(Updater: [&](Instruction *OldIV) -> Value * {
4051 Builder.SetInsertPoint(TheBB: CLI->getBody(),
4052 IP: CLI->getBody()->getFirstInsertionPt());
4053 Builder.SetCurrentDebugLocation(DL);
4054 return Builder.CreateAdd(LHS: OldIV, RHS: LowerBound);
4055 });
4056
4057 // In the "exit" block, call the "fini" function.
4058 Builder.SetInsertPoint(TheBB: CLI->getExit(),
4059 IP: CLI->getExit()->getTerminator()->getIterator());
4060 Builder.CreateCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});
4061
4062 // Add the barrier if requested.
4063 if (NeedsBarrier)
4064 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
4065 Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4066 /* CheckCancelFlag */ false);
4067
4068 InsertPointTy AfterIP = CLI->getAfterIP();
4069 CLI->invalidate();
4070
4071 return AfterIP;
4072}
4073
4074OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
4075 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4076 bool NeedsBarrier, Value *ChunkSize) {
4077 assert(CLI->isValid() && "Requires a valid canonical loop");
4078 assert(ChunkSize && "Chunk size is required");
4079
4080 LLVMContext &Ctx = CLI->getFunction()->getContext();
4081 Value *IV = CLI->getIndVar();
4082 Value *OrigTripCount = CLI->getTripCount();
4083 Type *IVTy = IV->getType();
4084 assert(IVTy->getIntegerBitWidth() <= 64 &&
4085 "Max supported tripcount bitwidth is 64 bits");
4086 Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(C&: Ctx)
4087 : Type::getInt64Ty(C&: Ctx);
4088 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
4089 Constant *Zero = ConstantInt::get(Ty: InternalIVTy, V: 0);
4090 Constant *One = ConstantInt::get(Ty: InternalIVTy, V: 1);
4091
4092 // Declare useful OpenMP runtime functions.
4093 FunctionCallee StaticInit =
4094 getKmpcForStaticInitForType(Ty: InternalIVTy, M, OMPBuilder&: *this);
4095 FunctionCallee StaticFini =
4096 getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);
4097
4098 // Allocate space for computed loop bounds as expected by the "init" function.
4099 Builder.restoreIP(IP: AllocaIP);
4100 Builder.SetCurrentDebugLocation(DL);
4101 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
4102 Value *PLowerBound =
4103 Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.lowerbound");
4104 Value *PUpperBound =
4105 Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.upperbound");
4106 Value *PStride = Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.stride");
4107
4108 // Set up the source location value for the OpenMP runtime.
4109 Builder.restoreIP(IP: CLI->getPreheaderIP());
4110 Builder.SetCurrentDebugLocation(DL);
4111
4112 // TODO: Detect overflow in ubsan or max-out with current tripcount.
4113 Value *CastedChunkSize =
4114 Builder.CreateZExtOrTrunc(V: ChunkSize, DestTy: InternalIVTy, Name: "chunksize");
4115 Value *CastedTripCount =
4116 Builder.CreateZExt(V: OrigTripCount, DestTy: InternalIVTy, Name: "tripcount");
4117
4118 Constant *SchedulingType = ConstantInt::get(
4119 Ty: I32Type, V: static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
4120 Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
4121 Value *OrigUpperBound = Builder.CreateSub(LHS: CastedTripCount, RHS: One);
4122 Builder.CreateStore(Val: OrigUpperBound, Ptr: PUpperBound);
4123 Builder.CreateStore(Val: One, Ptr: PStride);
4124
4125 // Call the "init" function and update the trip count of the loop with the
4126 // value it produced.
4127 uint32_t SrcLocStrSize;
4128 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4129 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4130 Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
4131 Builder.CreateCall(Callee: StaticInit,
4132 Args: {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
4133 /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
4134 /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
4135 /*pstride=*/PStride, /*incr=*/One,
4136 /*chunk=*/CastedChunkSize});
4137
4138 // Load values written by the "init" function.
4139 Value *FirstChunkStart =
4140 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PLowerBound, Name: "omp_firstchunk.lb");
4141 Value *FirstChunkStop =
4142 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PUpperBound, Name: "omp_firstchunk.ub");
4143 Value *FirstChunkEnd = Builder.CreateAdd(LHS: FirstChunkStop, RHS: One);
4144 Value *ChunkRange =
4145 Builder.CreateSub(LHS: FirstChunkEnd, RHS: FirstChunkStart, Name: "omp_chunk.range");
4146 Value *NextChunkStride =
4147 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PStride, Name: "omp_dispatch.stride");
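  // For illustration (assuming the usual static-chunked division by the
  // runtime): with a trip count of 9, a chunk size of 2 and two threads,
  // thread 0 may be handed lb=0, ub=1 and stride 4, i.e. a chunk range of 2
  // and chunks starting at 0, 4 and 8.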
4148
4149 // Create outer "dispatch" loop for enumerating the chunks.
4150 BasicBlock *DispatchEnter = splitBB(Builder, CreateBranch: true);
4151 Value *DispatchCounter;
4152 CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
4153 Loc: {Builder.saveIP(), DL},
4154 BodyGenCB: [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
4155 Start: FirstChunkStart, Stop: CastedTripCount, Step: NextChunkStride,
4156 /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
4157 Name: "dispatch");
4158
  // Remember the BasicBlocks of the dispatch loop we need, then invalidate it
  // so we do not have to preserve the canonical invariant.
4161 BasicBlock *DispatchBody = DispatchCLI->getBody();
4162 BasicBlock *DispatchLatch = DispatchCLI->getLatch();
4163 BasicBlock *DispatchExit = DispatchCLI->getExit();
4164 BasicBlock *DispatchAfter = DispatchCLI->getAfter();
4165 DispatchCLI->invalidate();
4166
4167 // Rewire the original loop to become the chunk loop inside the dispatch loop.
4168 redirectTo(Source: DispatchAfter, Target: CLI->getAfter(), DL);
4169 redirectTo(Source: CLI->getExit(), Target: DispatchLatch, DL);
4170 redirectTo(Source: DispatchBody, Target: DispatchEnter, DL);
4171
4172 // Prepare the prolog of the chunk loop.
4173 Builder.restoreIP(IP: CLI->getPreheaderIP());
4174 Builder.SetCurrentDebugLocation(DL);
4175
4176 // Compute the number of iterations of the chunk loop.
4177 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4178 Value *ChunkEnd = Builder.CreateAdd(LHS: DispatchCounter, RHS: ChunkRange);
4179 Value *IsLastChunk =
4180 Builder.CreateICmpUGE(LHS: ChunkEnd, RHS: CastedTripCount, Name: "omp_chunk.is_last");
4181 Value *CountUntilOrigTripCount =
4182 Builder.CreateSub(LHS: CastedTripCount, RHS: DispatchCounter);
4183 Value *ChunkTripCount = Builder.CreateSelect(
4184 C: IsLastChunk, True: CountUntilOrigTripCount, False: ChunkRange, Name: "omp_chunk.tripcount");
4185 Value *BackcastedChunkTC =
4186 Builder.CreateTrunc(V: ChunkTripCount, DestTy: IVTy, Name: "omp_chunk.tripcount.trunc");
4187 CLI->setTripCount(BackcastedChunkTC);
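  // Continuing the illustration: the chunk starting at 8 has
  // ChunkEnd = 10 >= tripcount 9, so it is the last chunk and executes
  // tripcount - counter = 1 iteration instead of the full range of 2.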
4188
4189 // Update all uses of the induction variable except the one in the condition
4190 // block that compares it with the actual upper bound, and the increment in
4191 // the latch block.
4192 Value *BackcastedDispatchCounter =
4193 Builder.CreateTrunc(V: DispatchCounter, DestTy: IVTy, Name: "omp_dispatch.iv.trunc");
4194 CLI->mapIndVar(Updater: [&](Instruction *) -> Value * {
4195 Builder.restoreIP(IP: CLI->getBodyIP());
4196 return Builder.CreateAdd(LHS: IV, RHS: BackcastedDispatchCounter);
4197 });
4198
4199 // In the "exit" block, call the "fini" function.
4200 Builder.SetInsertPoint(TheBB: DispatchExit, IP: DispatchExit->getFirstInsertionPt());
4201 Builder.CreateCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});
4202
4203 // Add the barrier if requested.
4204 if (NeedsBarrier)
4205 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL), Kind: OMPD_for,
4206 /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
4207
4208#ifndef NDEBUG
4209 // Even though we currently do not support applying additional methods to it,
4210 // the chunk loop should remain a canonical loop.
4211 CLI->assertOK();
4212#endif
4213
4214 return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
4215}
4216
4217// Returns an LLVM function to call for executing an OpenMP static worksharing
4218// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
4219// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
4220static FunctionCallee
4221getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
4222 WorksharingLoopType LoopType) {
4223 unsigned Bitwidth = Ty->getIntegerBitWidth();
4224 Module &M = OMPBuilder->M;
4225 switch (LoopType) {
4226 case WorksharingLoopType::ForStaticLoop:
4227 if (Bitwidth == 32)
4228 return OMPBuilder->getOrCreateRuntimeFunction(
4229 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
4230 if (Bitwidth == 64)
4231 return OMPBuilder->getOrCreateRuntimeFunction(
4232 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
4233 break;
4234 case WorksharingLoopType::DistributeStaticLoop:
4235 if (Bitwidth == 32)
4236 return OMPBuilder->getOrCreateRuntimeFunction(
4237 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
4238 if (Bitwidth == 64)
4239 return OMPBuilder->getOrCreateRuntimeFunction(
4240 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
4241 break;
4242 case WorksharingLoopType::DistributeForStaticLoop:
4243 if (Bitwidth == 32)
4244 return OMPBuilder->getOrCreateRuntimeFunction(
4245 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
4246 if (Bitwidth == 64)
4247 return OMPBuilder->getOrCreateRuntimeFunction(
4248 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
4249 break;
4250 }
4251 if (Bitwidth != 32 && Bitwidth != 64) {
4252 llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
4253 }
4254 llvm_unreachable("Unknown type of OpenMP worksharing loop");
4255}
4256
// Inserts a call to the proper OpenMP Device RTL function which handles
// loop worksharing.
4259static void createTargetLoopWorkshareCall(
4260 OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
4261 BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
4262 Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
4263 Type *TripCountTy = TripCount->getType();
4264 Module &M = OMPBuilder->M;
4265 IRBuilder<> &Builder = OMPBuilder->Builder;
4266 FunctionCallee RTLFn =
4267 getKmpcForStaticLoopForType(Ty: TripCountTy, OMPBuilder, LoopType);
4268 SmallVector<Value *, 8> RealArgs;
4269 RealArgs.push_back(Elt: Ident);
4270 RealArgs.push_back(Elt: Builder.CreateBitCast(V: &LoopBodyFn, DestTy: ParallelTaskPtr));
4271 RealArgs.push_back(Elt: LoopBodyArg);
4272 RealArgs.push_back(Elt: TripCount);
4273 if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
4274 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
4275 Builder.CreateCall(Callee: RTLFn, Args: RealArgs);
4276 return;
4277 }
4278 FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
4279 M, FnID: omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
4280 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
4281 Value *NumThreads = Builder.CreateCall(Callee: RTLNumThreads, Args: {});
4282
4283 RealArgs.push_back(
4284 Elt: Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: TripCountTy, Name: "num.threads.cast"));
4285 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
4286 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4287 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
4288 }
4289
4290 Builder.CreateCall(Callee: RTLFn, Args: RealArgs);
4291}
4292
4293static void
4294workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
4295 CanonicalLoopInfo *CLI, Value *Ident,
4296 Function &OutlinedFn, Type *ParallelTaskPtr,
4297 const SmallVector<Instruction *, 4> &ToBeDeleted,
4298 WorksharingLoopType LoopType) {
4299 IRBuilder<> &Builder = OMPIRBuilder->Builder;
4300 BasicBlock *Preheader = CLI->getPreheader();
4301 Value *TripCount = CLI->getTripCount();
4302
  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, we need to move the setup of the loop body args into the
  // loop preheader.
4307 Preheader->splice(ToIt: std::prev(x: Preheader->end()), FromBB: CLI->getBody(),
4308 FromBeginIt: CLI->getBody()->begin(), FromEndIt: std::prev(x: CLI->getBody()->end()));
4309
  // The next step is to remove the whole loop. We do not need it anymore.
  // That's why we make an unconditional branch from the loop preheader to the
  // loop exit block.
4313 Builder.restoreIP(IP: {Preheader, Preheader->end()});
4314 Preheader->getTerminator()->eraseFromParent();
4315 Builder.CreateBr(Dest: CLI->getExit());
4316
4317 // Delete dead loop blocks
4318 OpenMPIRBuilder::OutlineInfo CleanUpInfo;
4319 SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
4320 SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
4321 CleanUpInfo.EntryBB = CLI->getHeader();
4322 CleanUpInfo.ExitBB = CLI->getExit();
4323 CleanUpInfo.collectBlocks(BlockSet&: RegionBlockSet, BlockVector&: BlocksToBeRemoved);
4324 DeleteDeadBlocks(BBs: BlocksToBeRemoved);
4325
  // Find the instruction which corresponds to the loop body argument structure
  // and remove the call to the loop body function.
4328 Value *LoopBodyArg;
4329 User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
4330 assert(OutlinedFnUser &&
4331 "Expected unique undroppable user of outlined function");
4332 CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(Val: OutlinedFnUser);
4333 assert(OutlinedFnCallInstruction && "Expected outlined function call");
4334 assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
4335 "Expected outlined function call to be located in loop preheader");
4336 // Check in case no argument structure has been passed.
4337 if (OutlinedFnCallInstruction->arg_size() > 1)
4338 LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(i: 1);
4339 else
4340 LoopBodyArg = Constant::getNullValue(Ty: Builder.getPtrTy());
4341 OutlinedFnCallInstruction->eraseFromParent();
4342
4343 createTargetLoopWorkshareCall(OMPBuilder: OMPIRBuilder, LoopType, InsertBlock: Preheader, Ident,
4344 LoopBodyArg, ParallelTaskPtr, TripCount,
4345 LoopBodyFn&: OutlinedFn);
4346
4347 for (auto &ToBeDeletedItem : ToBeDeleted)
4348 ToBeDeletedItem->eraseFromParent();
4349 CLI->invalidate();
4350}
4351
4352OpenMPIRBuilder::InsertPointTy
4353OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
4354 InsertPointTy AllocaIP,
4355 WorksharingLoopType LoopType) {
4356 uint32_t SrcLocStrSize;
4357 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4358 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4359
4360 OutlineInfo OI;
4361 OI.OuterAllocaBB = CLI->getPreheader();
4362 Function *OuterFn = CLI->getPreheader()->getParent();
4363
4364 // Instructions which need to be deleted at the end of code generation
4365 SmallVector<Instruction *, 4> ToBeDeleted;
4366
4367 OI.OuterAllocaBB = AllocaIP.getBlock();
4368
  // Mark the loop body as the region which needs to be extracted.
4370 OI.EntryBB = CLI->getBody();
4371 OI.ExitBB = CLI->getLatch()->splitBasicBlock(I: CLI->getLatch()->begin(),
4372 BBName: "omp.prelatch", Before: true);
4373
4374 // Prepare loop body for extraction
4375 Builder.restoreIP(IP: {CLI->getPreheader(), CLI->getPreheader()->begin()});
4376
  // Insert a new loop counter variable which will be used only in the loop
  // body.
4379 AllocaInst *NewLoopCnt = Builder.CreateAlloca(Ty: CLI->getIndVarType(), ArraySize: 0, Name: "");
4380 Instruction *NewLoopCntLoad =
4381 Builder.CreateLoad(Ty: CLI->getIndVarType(), Ptr: NewLoopCnt);
  // The new loop counter instructions are redundant in the loop preheader once
  // code generation for the workshare loop is finished. That's why we mark
  // them as ready for deletion.
4385 ToBeDeleted.push_back(Elt: NewLoopCntLoad);
4386 ToBeDeleted.push_back(Elt: NewLoopCnt);
4387
  // Analyse the loop body region: find all input variables which are used
  // inside the loop body region.
4390 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
4391 SmallVector<BasicBlock *, 32> Blocks;
4392 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
4393 SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
4394 ParallelRegionBlockSet.end());
4395
4396 CodeExtractorAnalysisCache CEAC(*OuterFn);
4397 CodeExtractor Extractor(Blocks,
4398 /* DominatorTree */ nullptr,
4399 /* AggregateArgs */ true,
4400 /* BlockFrequencyInfo */ nullptr,
4401 /* BranchProbabilityInfo */ nullptr,
4402 /* AssumptionCache */ nullptr,
4403 /* AllowVarArgs */ true,
4404 /* AllowAlloca */ true,
4405 /* AllocationBlock */ CLI->getPreheader(),
4406 /* Suffix */ ".omp_wsloop",
4407 /* AggrArgsIn0AddrSpace */ true);
4408
4409 BasicBlock *CommonExit = nullptr;
4410 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
4411
  // Find allocas outside the loop body region which are used inside the loop
  // body.
4414 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
4415
  // We need to model the loop body region as the function f(cnt, loop_arg).
  // That's why we replace the loop induction variable with the new counter,
  // which will be one of the loop body function's arguments.
4419 SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
4420 CLI->getIndVar()->user_end());
4421 for (auto Use : Users) {
4422 if (Instruction *Inst = dyn_cast<Instruction>(Val: Use)) {
4423 if (ParallelRegionBlockSet.count(Ptr: Inst->getParent())) {
4424 Inst->replaceUsesOfWith(From: CLI->getIndVar(), To: NewLoopCntLoad);
4425 }
4426 }
4427 }
  // Make sure that the loop counter variable is not merged into the loop body
  // function's argument structure and is passed as a separate variable.
4430 OI.ExcludeArgsFromAggregate.push_back(Elt: NewLoopCntLoad);
4431
  // The PostOutline callback is invoked once the loop body function has been
  // outlined and the loop body has been replaced by a call to that function.
  // We then need to add a call to the OpenMP device RTL in the loop preheader;
  // the OpenMP device RTL function will handle the loop control logic.
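  // The net effect (a sketch): the body blocks become an outlined function
  // that takes the loop counter and an argument structure, and the preheader
  // ends up calling one of the __kmpc_*_loop_* entry points with that
  // function, its argument and the trip count, letting the device runtime
  // drive the iteration space.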
4437 OI.PostOutlineCB = [=, ToBeDeletedVec =
4438 std::move(ToBeDeleted)](Function &OutlinedFn) {
4439 workshareLoopTargetCallback(OMPIRBuilder: this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
4440 ToBeDeleted: ToBeDeletedVec, LoopType);
4441 };
4442 addOutlineInfo(OI: std::move(OI));
4443 return CLI->getAfterIP();
4444}
4445
4446OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
4447 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4448 bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
4449 bool HasSimdModifier, bool HasMonotonicModifier,
4450 bool HasNonmonotonicModifier, bool HasOrderedClause,
4451 WorksharingLoopType LoopType) {
4452 if (Config.isTargetDevice())
4453 return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
4454 OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
4455 ClauseKind: SchedKind, HasChunks: ChunkSize, HasSimdModifier, HasMonotonicModifier,
4456 HasNonmonotonicModifier, HasOrderedClause);
4457
4458 bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
4459 OMPScheduleType::ModifierOrdered;
4460 switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
4461 case OMPScheduleType::BaseStatic:
4462 assert(!ChunkSize && "No chunk size with static-chunked schedule");
4463 if (IsOrdered)
4464 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
4465 NeedsBarrier, Chunk: ChunkSize);
4466 // FIXME: Monotonicity ignored?
4467 return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
4468
4469 case OMPScheduleType::BaseStaticChunked:
4470 if (IsOrdered)
4471 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
4472 NeedsBarrier, Chunk: ChunkSize);
4473 // FIXME: Monotonicity ignored?
4474 return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
4475 ChunkSize);
4476
4477 case OMPScheduleType::BaseRuntime:
4478 case OMPScheduleType::BaseAuto:
4479 case OMPScheduleType::BaseGreedy:
4480 case OMPScheduleType::BaseBalanced:
4481 case OMPScheduleType::BaseSteal:
4482 case OMPScheduleType::BaseGuidedSimd:
4483 case OMPScheduleType::BaseRuntimeSimd:
4484 assert(!ChunkSize &&
4485 "schedule type does not support user-defined chunk sizes");
4486 [[fallthrough]];
4487 case OMPScheduleType::BaseDynamicChunked:
4488 case OMPScheduleType::BaseGuidedChunked:
4489 case OMPScheduleType::BaseGuidedIterativeChunked:
4490 case OMPScheduleType::BaseGuidedAnalyticalChunked:
4491 case OMPScheduleType::BaseStaticBalancedChunked:
4492 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
4493 NeedsBarrier, Chunk: ChunkSize);
4494
4495 default:
4496 llvm_unreachable("Unknown/unimplemented schedule kind");
4497 }
4498}
4499
4500/// Returns an LLVM function to call for initializing loop bounds using OpenMP
4501/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4502/// the runtime. Always interpret integers as unsigned similarly to
4503/// CanonicalLoopInfo.
4504static FunctionCallee
4505getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4506 unsigned Bitwidth = Ty->getIntegerBitWidth();
4507 if (Bitwidth == 32)
4508 return OMPBuilder.getOrCreateRuntimeFunction(
4509 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
4510 if (Bitwidth == 64)
4511 return OMPBuilder.getOrCreateRuntimeFunction(
4512 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
4513 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4514}
4515
4516/// Returns an LLVM function to call for updating the next loop using OpenMP
4517/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4518/// the runtime. Always interpret integers as unsigned similarly to
4519/// CanonicalLoopInfo.
4520static FunctionCallee
4521getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4522 unsigned Bitwidth = Ty->getIntegerBitWidth();
4523 if (Bitwidth == 32)
4524 return OMPBuilder.getOrCreateRuntimeFunction(
4525 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
4526 if (Bitwidth == 64)
4527 return OMPBuilder.getOrCreateRuntimeFunction(
4528 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
4529 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4530}
4531
/// Returns an LLVM function to call for finalizing the dynamic loop, depending
/// on `type`. Only i32 and i64 are supported by the runtime. Always interpret
/// integers as unsigned similarly to CanonicalLoopInfo.
4535static FunctionCallee
4536getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4537 unsigned Bitwidth = Ty->getIntegerBitWidth();
4538 if (Bitwidth == 32)
4539 return OMPBuilder.getOrCreateRuntimeFunction(
4540 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
4541 if (Bitwidth == 64)
4542 return OMPBuilder.getOrCreateRuntimeFunction(
4543 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
4544 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4545}
4546
4547OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
4548 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4549 OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
4550 assert(CLI->isValid() && "Requires a valid canonical loop");
4551 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
4552 "Require dedicated allocate IP");
4553 assert(isValidWorkshareLoopScheduleType(SchedType) &&
4554 "Require valid schedule type");
4555
4556 bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
4557 OMPScheduleType::ModifierOrdered;
4558
4559 // Set up the source location value for OpenMP runtime.
4560 Builder.SetCurrentDebugLocation(DL);
4561
4562 uint32_t SrcLocStrSize;
4563 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4564 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4565
4566 // Declare useful OpenMP runtime functions.
4567 Value *IV = CLI->getIndVar();
4568 Type *IVTy = IV->getType();
4569 FunctionCallee DynamicInit = getKmpcForDynamicInitForType(Ty: IVTy, M, OMPBuilder&: *this);
4570 FunctionCallee DynamicNext = getKmpcForDynamicNextForType(Ty: IVTy, M, OMPBuilder&: *this);
4571
4572 // Allocate space for computed loop bounds as expected by the "init" function.
4573 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4574 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
4575 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
4576 Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
4577 Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
4578 Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
4579
4580 // At the end of the preheader, prepare for calling the "init" function by
4581 // storing the current loop bounds into the allocated space. A canonical loop
4582 // always iterates from 0 to trip-count with step 1. Note that "init" expects
4583 // and produces an inclusive upper bound.
4584 BasicBlock *PreHeader = CLI->getPreheader();
4585 Builder.SetInsertPoint(PreHeader->getTerminator());
4586 Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
4587 Builder.CreateStore(Val: One, Ptr: PLowerBound);
4588 Value *UpperBound = CLI->getTripCount();
4589 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
4590 Builder.CreateStore(Val: One, Ptr: PStride);
4591
4592 BasicBlock *Header = CLI->getHeader();
4593 BasicBlock *Exit = CLI->getExit();
4594 BasicBlock *Cond = CLI->getCond();
4595 BasicBlock *Latch = CLI->getLatch();
4596 InsertPointTy AfterIP = CLI->getAfterIP();
4597
4598 // The CLI will be "broken" in the code below, as the loop is no longer
4599 // a valid canonical loop.
4600
4601 if (!Chunk)
4602 Chunk = One;
4603
4604 Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
4605
4606 Constant *SchedulingType =
4607 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
4608
4609 // Call the "init" function.
4610 Builder.CreateCall(Callee: DynamicInit,
4611 Args: {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
4612 UpperBound, /* step */ One, Chunk});
4613
4614 // An outer loop around the existing one.
4615 BasicBlock *OuterCond = BasicBlock::Create(
4616 Context&: PreHeader->getContext(), Name: Twine(PreHeader->getName()) + ".outer.cond",
4617 Parent: PreHeader->getParent());
  // The result of the dispatch_next call is always an i32, so a dedicated
  // 32-bit zero is needed here rather than an IVTy-typed constant.
4619 Builder.SetInsertPoint(TheBB: OuterCond, IP: OuterCond->getFirstInsertionPt());
4620 Value *Res =
4621 Builder.CreateCall(Callee: DynamicNext, Args: {SrcLoc, ThreadNum, PLastIter,
4622 PLowerBound, PUpperBound, PStride});
4623 Constant *Zero32 = ConstantInt::get(Ty: I32Type, V: 0);
4624 Value *MoreWork = Builder.CreateCmp(Pred: CmpInst::ICMP_NE, LHS: Res, RHS: Zero32);
4625 Value *LowerBound =
4626 Builder.CreateSub(LHS: Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound), RHS: One, Name: "lb");
4627 Builder.CreateCondBr(Cond: MoreWork, True: Header, False: Exit);
4628
4629 // Change PHI-node in loop header to use outer cond rather than preheader,
4630 // and set IV to the LowerBound.
4631 Instruction *Phi = &Header->front();
4632 auto *PI = cast<PHINode>(Val: Phi);
4633 PI->setIncomingBlock(i: 0, BB: OuterCond);
4634 PI->setIncomingValue(i: 0, V: LowerBound);
4635
4636 // Then set the pre-header to jump to the OuterCond
4637 Instruction *Term = PreHeader->getTerminator();
4638 auto *Br = cast<BranchInst>(Val: Term);
4639 Br->setSuccessor(idx: 0, NewSucc: OuterCond);
4640
  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * Jump to the outer loop when done with one of the inner loops.
4644 Builder.SetInsertPoint(TheBB: Cond, IP: Cond->getFirstInsertionPt());
4645 UpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound, Name: "ub");
4646 Instruction *Comp = &*Builder.GetInsertPoint();
4647 auto *CI = cast<CmpInst>(Val: Comp);
4648 CI->setOperand(i_nocapture: 1, Val_nocapture: UpperBound);
4649 // Redirect the inner exit to branch to outer condition.
4650 Instruction *Branch = &Cond->back();
4651 auto *BI = cast<BranchInst>(Val: Branch);
4652 assert(BI->getSuccessor(1) == Exit);
4653 BI->setSuccessor(idx: 1, NewSucc: OuterCond);
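  // The resulting structure (a sketch): the preheader now branches to
  // OuterCond, which calls dispatch_next and either re-enters the loop header
  // with freshly fetched bounds or exits; the inner condition compares against
  // the dynamically fetched upper bound, and its exit edge returns to
  // OuterCond.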
4654
4655 // Call the "fini" function if "ordered" is present in wsloop directive.
4656 if (Ordered) {
4657 Builder.SetInsertPoint(&Latch->back());
4658 FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(Ty: IVTy, M, OMPBuilder&: *this);
4659 Builder.CreateCall(Callee: DynamicFini, Args: {SrcLoc, ThreadNum});
4660 }
4661
4662 // Add the barrier if requested.
4663 if (NeedsBarrier) {
4664 Builder.SetInsertPoint(&Exit->back());
4665 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
4666 Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4667 /* CheckCancelFlag */ false);
4668 }
4669
4670 CLI->invalidate();
4671 return AfterIP;
4672}
4673
/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
/// after this, \p OldTarget will be orphaned.
4676static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
4677 BasicBlock *NewTarget, DebugLoc DL) {
4678 for (BasicBlock *Pred : make_early_inc_range(Range: predecessors(BB: OldTarget)))
4679 redirectTo(Source: Pred, Target: NewTarget, DL);
4680}
4681
/// Determine which blocks in \p BBs are still referenced from outside the set
/// and remove from the function the ones that are not.
4684static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
4685 SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
4686 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
4687 for (Use &U : BB->uses()) {
4688 auto *UseInst = dyn_cast<Instruction>(Val: U.getUser());
4689 if (!UseInst)
4690 continue;
4691 if (BBsToErase.count(Ptr: UseInst->getParent()))
4692 continue;
4693 return true;
4694 }
4695 return false;
4696 };
4697
4698 while (BBsToErase.remove_if(P: HasRemainingUses)) {
4699 // Try again if anything was removed.
4700 }
4701
4702 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
4703 DeleteDeadBlocks(BBs: BBVec);
4704}
4705
4706CanonicalLoopInfo *
4707OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4708 InsertPointTy ComputeIP) {
4709 assert(Loops.size() >= 1 && "At least one loop required");
4710 size_t NumLoops = Loops.size();
4711
4712 // Nothing to do if there is already just one loop.
4713 if (NumLoops == 1)
4714 return Loops.front();
4715
4716 CanonicalLoopInfo *Outermost = Loops.front();
4717 CanonicalLoopInfo *Innermost = Loops.back();
4718 BasicBlock *OrigPreheader = Outermost->getPreheader();
4719 BasicBlock *OrigAfter = Outermost->getAfter();
4720 Function *F = OrigPreheader->getParent();
4721
4722 // Loop control blocks that may become orphaned later.
4723 SmallVector<BasicBlock *, 12> OldControlBBs;
4724 OldControlBBs.reserve(N: 6 * Loops.size());
4725 for (CanonicalLoopInfo *Loop : Loops)
4726 Loop->collectControlBlocks(BBs&: OldControlBBs);
4727
4728 // Setup the IRBuilder for inserting the trip count computation.
4729 Builder.SetCurrentDebugLocation(DL);
4730 if (ComputeIP.isSet())
4731 Builder.restoreIP(IP: ComputeIP);
4732 else
4733 Builder.restoreIP(IP: Outermost->getPreheaderIP());
4734
  // Derive the collapsed loop's trip count.
4736 // TODO: Find common/largest indvar type.
4737 Value *CollapsedTripCount = nullptr;
4738 for (CanonicalLoopInfo *L : Loops) {
4739 assert(L->isValid() &&
4740 "All loops to collapse must be valid canonical loops");
4741 Value *OrigTripCount = L->getTripCount();
4742 if (!CollapsedTripCount) {
4743 CollapsedTripCount = OrigTripCount;
4744 continue;
4745 }
4746
4747 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4748 CollapsedTripCount = Builder.CreateMul(LHS: CollapsedTripCount, RHS: OrigTripCount,
4749 Name: {}, /*HasNUW=*/true);
4750 }
4751
4752 // Create the collapsed loop control flow.
4753 CanonicalLoopInfo *Result =
4754 createLoopSkeleton(DL, TripCount: CollapsedTripCount, F,
4755 PreInsertBefore: OrigPreheader->getNextNode(), PostInsertBefore: OrigAfter, Name: "collapsed");
4756
  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
4761 Builder.restoreIP(IP: Result->getBodyIP());
4762
4763 Value *Leftover = Result->getIndVar();
4764 SmallVector<Value *> NewIndVars;
4765 NewIndVars.resize(N: NumLoops);
4766 for (int i = NumLoops - 1; i >= 1; --i) {
4767 Value *OrigTripCount = Loops[i]->getTripCount();
4768
4769 Value *NewIndVar = Builder.CreateURem(LHS: Leftover, RHS: OrigTripCount);
4770 NewIndVars[i] = NewIndVar;
4771
4772 Leftover = Builder.CreateUDiv(LHS: Leftover, RHS: OrigTripCount);
4773 }
4774 // Outermost loop gets all the remaining bits.
4775 NewIndVars[0] = Leftover;
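  // For illustration, collapsing two loops with trip counts 3 and 4 gives a
  // collapsed trip count of 12; iteration IV expands to the pair
  // (IV / 4, IV % 4), so IV = 7 corresponds to i = 1 and j = 3.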
4776
  // Construct the loop body control flow.
  // We progressively construct the branch structure following the direction of
  // the control flow: from the leading in-between code, through the loop nest
  // body and the trailing in-between code, to rejoining the collapsed loop's
  // latch. ContinueBlock and ContinuePred keep track of the source(s) of the
  // next edge. If ContinueBlock is set, continue with that block. If
  // ContinuePred is set, use its predecessors as sources.
4784 BasicBlock *ContinueBlock = Result->getBody();
4785 BasicBlock *ContinuePred = nullptr;
4786 auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
4787 BasicBlock *NextSrc) {
4788 if (ContinueBlock)
4789 redirectTo(Source: ContinueBlock, Target: Dest, DL);
4790 else
4791 redirectAllPredecessorsTo(OldTarget: ContinuePred, NewTarget: Dest, DL);
4792
4793 ContinueBlock = nullptr;
4794 ContinuePred = NextSrc;
4795 };
4796
  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than the original loop. More sophisticated schemes could keep track of
  // what the in-between code is and instantiate it only once per thread.
4801 for (size_t i = 0; i < NumLoops - 1; ++i)
4802 ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
4803
4804 // Connect the loop nest body.
4805 ContinueWith(Innermost->getBody(), Innermost->getLatch());
4806
4807 // The code after the nested loop at each level.
4808 for (size_t i = NumLoops - 1; i > 0; --i)
4809 ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
4810
4811 // Connect the finished loop to the collapsed loop latch.
4812 ContinueWith(Result->getLatch(), nullptr);
4813
4814 // Replace the input loops with the new collapsed loop.
4815 redirectTo(Source: Outermost->getPreheader(), Target: Result->getPreheader(), DL);
4816 redirectTo(Source: Result->getAfter(), Target: Outermost->getAfter(), DL);
4817
4818 // Replace the input loop indvars with the derived ones.
4819 for (size_t i = 0; i < NumLoops; ++i)
4820 Loops[i]->getIndVar()->replaceAllUsesWith(V: NewIndVars[i]);
4821
4822 // Remove unused parts of the input loops.
4823 removeUnusedBlocksFromParent(BBs: OldControlBBs);
4824
4825 for (CanonicalLoopInfo *L : Loops)
4826 L->invalidate();
4827
4828#ifndef NDEBUG
4829 Result->assertOK();
4830#endif
4831 return Result;
4832}
4833
4834std::vector<CanonicalLoopInfo *>
4835OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4836 ArrayRef<Value *> TileSizes) {
4837 assert(TileSizes.size() == Loops.size() &&
4838 "Must pass as many tile sizes as there are loops");
4839 int NumLoops = Loops.size();
4840 assert(NumLoops >= 1 && "At least one loop to tile required");
4841
4842 CanonicalLoopInfo *OutermostLoop = Loops.front();
4843 CanonicalLoopInfo *InnermostLoop = Loops.back();
4844 Function *F = OutermostLoop->getBody()->getParent();
4845 BasicBlock *InnerEnter = InnermostLoop->getBody();
4846 BasicBlock *InnerLatch = InnermostLoop->getLatch();
4847
4848 // Loop control blocks that may become orphaned later.
4849 SmallVector<BasicBlock *, 12> OldControlBBs;
4850 OldControlBBs.reserve(N: 6 * Loops.size());
4851 for (CanonicalLoopInfo *Loop : Loops)
4852 Loop->collectControlBlocks(BBs&: OldControlBBs);
4853
  // Collect the original trip counts and induction variables to be accessible
  // by index. Also, the structure of the original loops is not preserved
  // during the construction of the tiled loops, so this must happen before we
  // scavenge the BBs of any original CanonicalLoopInfo.
4858 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
4859 for (CanonicalLoopInfo *L : Loops) {
4860 assert(L->isValid() && "All input loops must be valid canonical loops");
4861 OrigTripCounts.push_back(Elt: L->getTripCount());
4862 OrigIndVars.push_back(Elt: L->getIndVar());
4863 }
4864
  // Collect the code between loop headers. These may contain SSA definitions
  // that are used in the loop nest body. To be usable within the innermost
  // body, these BasicBlocks will be sunk into the loop nest body. That is,
  // these instructions may be executed more often than before the tiling.
  // TODO: It would be sufficient to only sink them into the body of the
  // corresponding tile loop.
4871 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
4872 for (int i = 0; i < NumLoops - 1; ++i) {
4873 CanonicalLoopInfo *Surrounding = Loops[i];
4874 CanonicalLoopInfo *Nested = Loops[i + 1];
4875
4876 BasicBlock *EnterBB = Surrounding->getBody();
4877 BasicBlock *ExitBB = Nested->getHeader();
4878 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
4879 }
4880
4881 // Compute the trip counts of the floor loops.
4882 Builder.SetCurrentDebugLocation(DL);
4883 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
4884 SmallVector<Value *, 4> FloorCount, FloorRems;
4885 for (int i = 0; i < NumLoops; ++i) {
4886 Value *TileSize = TileSizes[i];
4887 Value *OrigTripCount = OrigTripCounts[i];
4888 Type *IVType = OrigTripCount->getType();
4889
4890 Value *FloorTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
4891 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
4892
    // 0 if the tile size divides the trip count, 1 otherwise.
    // 1 means we need an additional iteration for a partial tile.
    //
    // Unfortunately we cannot just use the roundup-formula
    //   (tripcount + tilesize - 1) / tilesize
    // because the summation might overflow. We do not want to introduce
    // undefined behavior when the untiled loop nest did not have any.
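    // For illustration: a trip count of 14 with tile size 5 gives
    // FloorTripCount = 2 and FloorTripRem = 4, so the overflow bit bumps the
    // floor loop to 3 iterations, the last one covering the partial tile.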
4900 Value *FloorTripOverflow =
4901 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
4902
4903 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
4904 FloorTripCount =
4905 Builder.CreateAdd(LHS: FloorTripCount, RHS: FloorTripOverflow,
4906 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
4907
4908 // Remember some values for later use.
4909 FloorCount.push_back(Elt: FloorTripCount);
4910 FloorRems.push_back(Elt: FloorTripRem);
4911 }
4912
4913 // Generate the new loop nest, from the outermost to the innermost.
4914 std::vector<CanonicalLoopInfo *> Result;
4915 Result.reserve(n: NumLoops * 2);
4916
  // The basic block of the surrounding loop that enters the generated loop
  // nest.
4919 BasicBlock *Enter = OutermostLoop->getPreheader();
4920
4921 // The basic block of the surrounding loop where the inner code should
4922 // continue.
4923 BasicBlock *Continue = OutermostLoop->getAfter();
4924
4925 // Where the next loop basic block should be inserted.
4926 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
4927
4928 auto EmbeddNewLoop =
4929 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
4930 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
4931 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
4932 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
4933 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
4934 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
4935
    // Set up the position where the next embedded loop connects to this loop.
4937 Enter = EmbeddedLoop->getBody();
4938 Continue = EmbeddedLoop->getLatch();
4939 OutroInsertBefore = EmbeddedLoop->getLatch();
4940 return EmbeddedLoop;
4941 };
4942
4943 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
4944 const Twine &NameBase) {
4945 for (auto P : enumerate(First&: TripCounts)) {
4946 CanonicalLoopInfo *EmbeddedLoop =
4947 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
4948 Result.push_back(x: EmbeddedLoop);
4949 }
4950 };
4951
4952 EmbeddNewLoops(FloorCount, "floor");
4953
4954 // Within the innermost floor loop, emit the code that computes the tile
4955 // sizes.
4956 Builder.SetInsertPoint(Enter->getTerminator());
4957 SmallVector<Value *, 4> TileCounts;
4958 for (int i = 0; i < NumLoops; ++i) {
4959 CanonicalLoopInfo *FloorLoop = Result[i];
4960 Value *TileSize = TileSizes[i];
4961
4962 Value *FloorIsEpilogue =
4963 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorCount[i]);
4964 Value *TileTripCount =
4965 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
4966
4967 TileCounts.push_back(Elt: TileTripCount);
4968 }
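  // Continuing the illustration: tiles of size 5 over 14 iterations are meant
  // to execute 5, 5 and 4 iterations respectively, the epilogue floor
  // iteration selecting the remainder as its tile trip count.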
4969
4970 // Create the tile loops.
4971 EmbeddNewLoops(TileCounts, "tile");
4972
  // Insert the in-between code into the body.
4974 BasicBlock *BodyEnter = Enter;
4975 BasicBlock *BodyEntered = nullptr;
4976 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
4977 BasicBlock *EnterBB = P.first;
4978 BasicBlock *ExitBB = P.second;
4979
4980 if (BodyEnter)
4981 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
4982 else
4983 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
4984
4985 BodyEnter = nullptr;
4986 BodyEntered = ExitBB;
4987 }
4988
4989 // Append the original loop nest body into the generated loop nest body.
4990 if (BodyEnter)
4991 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
4992 else
4993 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
4994 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
4995
4996 // Replace the original induction variable with an induction variable computed
4997 // from the tile and floor induction variables.
4998 Builder.restoreIP(IP: Result.back()->getBodyIP());
4999 for (int i = 0; i < NumLoops; ++i) {
5000 CanonicalLoopInfo *FloorLoop = Result[i];
5001 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
5002 Value *OrigIndVar = OrigIndVars[i];
5003 Value *Size = TileSizes[i];
5004
5005 Value *Scale =
5006 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
5007 Value *Shift =
5008 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
5009 OrigIndVar->replaceAllUsesWith(V: Shift);
5010 }
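  // Continuing the illustration: with tile size 5, floor indvar 2 and tile
  // indvar 3, the original induction variable is reconstructed as
  // 5 * 2 + 3 = 13.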
5011
5012 // Remove unused parts of the original loops.
5013 removeUnusedBlocksFromParent(BBs: OldControlBBs);
5014
5015 for (CanonicalLoopInfo *L : Loops)
5016 L->invalidate();
5017
5018#ifndef NDEBUG
5019 for (CanonicalLoopInfo *GenL : Result)
5020 GenL->assertOK();
5021#endif
5022 return Result;
5023}
5024
5025/// Attach metadata \p Properties to the basic block described by \p BB. If the
5026/// basic block already has metadata, the basic block properties are appended.
5027static void addBasicBlockMetadata(BasicBlock *BB,
5028 ArrayRef<Metadata *> Properties) {
5029 // Nothing to do if no property to attach.
5030 if (Properties.empty())
5031 return;
5032
5033 LLVMContext &Ctx = BB->getContext();
5034 SmallVector<Metadata *> NewProperties;
5035 NewProperties.push_back(Elt: nullptr);
5036
5037 // If the basic block already has metadata, prepend it to the new metadata.
5038 MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
5039 if (Existing)
5040 append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));
5041
5042 append_range(C&: NewProperties, R&: Properties);
5043 MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
5044 BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);
5045
5046 BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
5047}
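
// For illustration (a sketch of the resulting IR, assuming the terminator had
// no prior metadata and a single unroll property is attached):
//
//   br label %header, !llvm.loop !0
//   ...
//   !0 = distinct !{!0, !1}
//   !1 = !{!"llvm.loop.unroll.enable"}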
5048
5049/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
5050/// loop already has metadata, the loop properties are appended.
5051static void addLoopMetadata(CanonicalLoopInfo *Loop,
5052 ArrayRef<Metadata *> Properties) {
5053 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
5054
5055 // Attach metadata to the loop's latch
5056 BasicBlock *Latch = Loop->getLatch();
5057 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
5058 addBasicBlockMetadata(BB: Latch, Properties);
5059}
5060
5061/// Attach llvm.access.group metadata to the memref instructions of \p Block
5062static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
5063 LoopInfo &LI) {
5064 for (Instruction &I : *Block) {
5065 if (I.mayReadOrWriteMemory()) {
      // TODO: This instruction may already have an access group from other
      // pragmas, e.g. #pragma clang loop vectorize. Append so that the
      // existing metadata is not overwritten.
5069 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
5070 }
5071 }
5072}
5073
5074void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
5075 LLVMContext &Ctx = Builder.getContext();
5076 addLoopMetadata(
5077 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
5078 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
5079}
5080
5081void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
5082 LLVMContext &Ctx = Builder.getContext();
5083 addLoopMetadata(
5084 Loop, Properties: {
5085 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
5086 });
5087}
5088
5089void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
5090 Value *IfCond, ValueToValueMapTy &VMap,
5091 const Twine &NamePrefix) {
5092 Function *F = CanonicalLoop->getFunction();
5093
  // Define where the if branch should be inserted.
5095 Instruction *SplitBefore;
5096 if (Instruction::classof(V: IfCond)) {
5097 SplitBefore = dyn_cast<Instruction>(Val: IfCond);
5098 } else {
5099 SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
5100 }
5101
5102 // TODO: We should not rely on pass manager. Currently we use pass manager
5103 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5104 // object. We should have a method which returns all blocks between
5105 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5106 FunctionAnalysisManager FAM;
5107 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5108 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
5109 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5110
5111 // Get the loop which needs to be cloned
5112 LoopAnalysis LIA;
5113 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5114 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
5115
5116 // Create additional blocks for the if statement
5117 BasicBlock *Head = SplitBefore->getParent();
5118 Instruction *HeadOldTerm = Head->getTerminator();
5119 llvm::LLVMContext &C = Head->getContext();
5120 llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
5121 Context&: C, Name: NamePrefix + ".if.then", Parent: Head->getParent(), InsertBefore: Head->getNextNode());
5122 llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
5123 Context&: C, Name: NamePrefix + ".if.else", Parent: Head->getParent(), InsertBefore: CanonicalLoop->getExit());
5124
5125 // Create if condition branch.
5126 Builder.SetInsertPoint(HeadOldTerm);
5127 Instruction *BrInstr =
5128 Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
5129 InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
  // The then-block contains the branch to the OMP loop that is to be
  // vectorized.
5131 spliceBB(IP, New: ThenBlock, CreateBranch: false);
5132 ThenBlock->replaceSuccessorsPhiUsesWith(Old: Head, New: ThenBlock);
5133
5134 Builder.SetInsertPoint(ElseBlock);
5135
5136 // Clone loop for the else branch
5137 SmallVector<BasicBlock *, 8> NewBlocks;
5138
5139 VMap[CanonicalLoop->getPreheader()] = ElseBlock;
5140 for (BasicBlock *Block : L->getBlocks()) {
5141 BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);
5142 NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
5143 VMap[Block] = NewBB;
5144 NewBlocks.push_back(Elt: NewBB);
5145 }
5146 remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
5147 Builder.CreateBr(Dest: NewBlocks.front());
5148}
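
// Schematically, createIfVersion produces the following control flow (block
// names assume the "simd" prefix that applySimd passes as NamePrefix):
//
//   head:
//     br i1 %ifcond, label %simd.if.then, label %simd.if.else
//   simd.if.then:   ; branches to the original loop, candidate for SIMD
//   simd.if.else:   ; branches to the cloned loop; the caller disables
//                   ; vectorization on the clone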
5149
5150unsigned
5151OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
5152 const StringMap<bool> &Features) {
5153 if (TargetTriple.isX86()) {
5154 if (Features.lookup(Key: "avx512f"))
5155 return 512;
5156 else if (Features.lookup(Key: "avx"))
5157 return 256;
5158 return 128;
5159 }
5160 if (TargetTriple.isPPC())
5161 return 128;
5162 if (TargetTriple.isWasm())
5163 return 128;
5164 return 0;
5165}
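
// The returned values are intended to match the widest vector register width
// in bits; e.g., an x86 target with the "avx" feature but without "avx512f"
// yields 256 (YMM-sized alignment).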
5166
5167void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5168 MapVector<Value *, Value *> AlignedVars,
5169 Value *IfCond, OrderKind Order,
5170 ConstantInt *Simdlen, ConstantInt *Safelen) {
5171 LLVMContext &Ctx = Builder.getContext();
5172
5173 Function *F = CanonicalLoop->getFunction();
5174
5175 // TODO: We should not rely on pass manager. Currently we use pass manager
5176 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5177 // object. We should have a method which returns all blocks between
5178 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5179 FunctionAnalysisManager FAM;
5180 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5181 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
5182 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5183
5184 LoopAnalysis LIA;
5185 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5186
5187 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
  if (!AlignedVars.empty()) {
5189 InsertPointTy IP = Builder.saveIP();
5190 Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
5191 for (auto &AlignedItem : AlignedVars) {
5192 Value *AlignedPtr = AlignedItem.first;
5193 Value *Alignment = AlignedItem.second;
5194 Builder.CreateAlignmentAssumption(DL: F->getDataLayout(),
5195 PtrValue: AlignedPtr, Alignment);
5196 }
5197 Builder.restoreIP(IP);
5198 }
5199
5200 if (IfCond) {
5201 ValueToValueMapTy VMap;
5202 createIfVersion(CanonicalLoop, IfCond, VMap, NamePrefix: "simd");
5203 // Add metadata to the cloned loop which disables vectorization
5204 Value *MappedLatch = VMap.lookup(Val: CanonicalLoop->getLatch());
5205 assert(MappedLatch &&
5206 "Cannot find value which corresponds to original loop latch");
5207 assert(isa<BasicBlock>(MappedLatch) &&
5208 "Cannot cast mapped latch block value to BasicBlock");
    BasicBlock *NewLatchBlock = cast<BasicBlock>(Val: MappedLatch);
5210 ConstantAsMetadata *BoolConst =
5211 ConstantAsMetadata::get(C: ConstantInt::getFalse(Ty: Type::getInt1Ty(C&: Ctx)));
5212 addBasicBlockMetadata(
5213 BB: NewLatchBlock,
5214 Properties: {MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"),
5215 BoolConst})});
5216 }
5217
5218 SmallSet<BasicBlock *, 8> Reachable;
5219
5220 // Get the basic blocks from the loop in which memref instructions
5221 // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
5223 // preferably without running any passes.
5224 for (BasicBlock *Block : L->getBlocks()) {
5225 if (Block == CanonicalLoop->getCond() ||
5226 Block == CanonicalLoop->getHeader())
5227 continue;
5228 Reachable.insert(Ptr: Block);
5229 }
5230
5231 SmallVector<Metadata *> LoopMDList;
5232
5233 // In presence of finite 'safelen', it may be unsafe to mark all
5234 // the memory instructions parallel, because loop-carried
5235 // dependences of 'safelen' iterations are possible.
5236 // If clause order(concurrent) is specified then the memory instructions
5237 // are marked parallel even if 'safelen' is finite.
5238 if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
5239 // Add access group metadata to memory-access instructions.
5240 MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
5241 for (BasicBlock *BB : Reachable)
5242 addSimdMetadata(Block: BB, AccessGroup, LI);
5243 // TODO: If the loop has existing parallel access metadata, have
5244 // to combine two lists.
5245 LoopMDList.push_back(Elt: MDNode::get(
5246 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
5247 }
5248
5249 // Use the above access group metadata to create loop level
5250 // metadata, which should be distinct for each loop.
5251 ConstantAsMetadata *BoolConst =
5252 ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
5253 LoopMDList.push_back(Elt: MDNode::get(
5254 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));
5255
5256 if (Simdlen || Safelen) {
5257 // If both simdlen and safelen clauses are specified, the value of the
5258 // simdlen parameter must be less than or equal to the value of the safelen
5259 // parameter. Therefore, use safelen only in the absence of simdlen.
5260 ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
5261 LoopMDList.push_back(
5262 Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
5263 ConstantAsMetadata::get(C: VectorizeWidth)}));
5264 }
5265
5266 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
5267}
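
// Illustrative only: for a 'simd simdlen(8)' clause without 'safelen', the
// latch metadata produced above looks roughly like
//
//   !0 = distinct !{!0, !1, !2, !3}
//   !1 = !{!"llvm.loop.parallel_accesses", !4}
//   !2 = !{!"llvm.loop.vectorize.enable", i1 true}
//   !3 = !{!"llvm.loop.vectorize.width", i32 8}
//   !4 = distinct !{}   ; the access group attached to memory instructions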
5268
5269/// Create the TargetMachine object to query the backend for optimization
5270/// preferences.
5271///
/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
/// needed for the LLVM pass pipeline. We use some default options to avoid
/// having to pass too many settings from the frontend that probably do not
/// matter.
///
/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
/// method. If we are going to use TargetMachine for more purposes, especially
/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
/// might become worth requiring front-ends to pass on their TargetMachine, or
/// at least cache it between methods. Note that while front-ends such as Clang
/// have just a single main TargetMachine per translation unit, "target-cpu"
/// and "target-features" that determine the TargetMachine are per-function and
/// can be overridden using __attribute__((target("OPTIONS"))).
5286static std::unique_ptr<TargetMachine>
5287createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
5288 Module *M = F->getParent();
5289
5290 StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
5291 StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
5292 const std::string &Triple = M->getTargetTriple();
5293
5294 std::string Error;
5295 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
5296 if (!TheTarget)
5297 return {};
5298
5299 llvm::TargetOptions Options;
5300 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
5301 TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
5302 /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
5303}
5304
/// Heuristically determine the best-performing unroll factor for \p CLI. This
5306/// depends on the target processor. We are re-using the same heuristics as the
5307/// LoopUnrollPass.
5308static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
5309 Function *F = CLI->getFunction();
5310
5311 // Assume the user requests the most aggressive unrolling, even if the rest of
5312 // the code is optimized using a lower setting.
5313 CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
5314 std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
5315
5316 FunctionAnalysisManager FAM;
5317 FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
5318 FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
5319 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5320 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
5321 FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
5322 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5323 TargetIRAnalysis TIRA;
5324 if (TM)
5325 TIRA = TargetIRAnalysis(
5326 [&](const Function &F) { return TM->getTargetTransformInfo(F); });
5327 FAM.registerPass(PassBuilder: [&]() { return TIRA; });
5328
5329 TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
5330 ScalarEvolutionAnalysis SEA;
5331 ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
5332 DominatorTreeAnalysis DTA;
5333 DominatorTree &&DT = DTA.run(F&: *F, FAM);
5334 LoopAnalysis LIA;
5335 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5336 AssumptionAnalysis ACT;
5337 AssumptionCache &&AC = ACT.run(F&: *F, FAM);
5338 OptimizationRemarkEmitter ORE{F};
5339
5340 Loop *L = LI.getLoopFor(BB: CLI->getHeader());
5341 assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
5342
5343 TargetTransformInfo::UnrollingPreferences UP =
5344 gatherUnrollingPreferences(L, SE, TTI,
5345 /*BlockFrequencyInfo=*/BFI: nullptr,
5346 /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
5347 /*UserThreshold=*/std::nullopt,
5348 /*UserCount=*/std::nullopt,
5349 /*UserAllowPartial=*/true,
5350 /*UserAllowRuntime=*/UserRuntime: true,
5351 /*UserUpperBound=*/std::nullopt,
5352 /*UserFullUnrollMaxCount=*/std::nullopt);
5353
5354 UP.Force = true;
5355
5356 // Account for additional optimizations taking place before the LoopUnrollPass
5357 // would unroll the loop.
5358 UP.Threshold *= UnrollThresholdFactor;
5359 UP.PartialThreshold *= UnrollThresholdFactor;
5360
5361 // Use normal unroll factors even if the rest of the code is optimized for
5362 // size.
5363 UP.OptSizeThreshold = UP.Threshold;
5364 UP.PartialOptSizeThreshold = UP.PartialThreshold;
5365
5366 LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
5367 << " Threshold=" << UP.Threshold << "\n"
5368 << " PartialThreshold=" << UP.PartialThreshold << "\n"
5369 << " OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
5370 << " PartialOptSizeThreshold="
5371 << UP.PartialOptSizeThreshold << "\n");
5372
5373 // Disable peeling.
5374 TargetTransformInfo::PeelingPreferences PP =
5375 gatherPeelingPreferences(L, SE, TTI,
5376 /*UserAllowPeeling=*/false,
5377 /*UserAllowProfileBasedPeeling=*/false,
5378 /*UnrollingSpecficValues=*/false);
5379
5380 SmallPtrSet<const Value *, 32> EphValues;
5381 CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);
5382
5383 // Assume that reads and writes to stack variables can be eliminated by
5384 // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
5385 // size.
5386 for (BasicBlock *BB : L->blocks()) {
5387 for (Instruction &I : *BB) {
5388 Value *Ptr;
5389 if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
5390 Ptr = Load->getPointerOperand();
5391 } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
5392 Ptr = Store->getPointerOperand();
5393 } else
5394 continue;
5395
5396 Ptr = Ptr->stripPointerCasts();
5397
5398 if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
5399 if (Alloca->getParent() == &F->getEntryBlock())
5400 EphValues.insert(Ptr: &I);
5401 }
5402 }
5403 }
5404
5405 UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
5406
5407 // Loop is not unrollable if the loop contains certain instructions.
5408 if (!UCE.canUnroll()) {
5409 LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
5410 return 1;
5411 }
5412
5413 LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
5414 << "\n");
5415
5416 // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
5417 // be able to use it.
5418 int TripCount = 0;
5419 int MaxTripCount = 0;
5420 bool MaxOrZero = false;
5421 unsigned TripMultiple = 0;
5422
5423 bool UseUpperBound = false;
5424 computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
5425 MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
5426 UseUpperBound);
5427 unsigned Factor = UP.Count;
5428 LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
5429
  // This function returns 1 to signal that the loop should not be unrolled.
5431 if (Factor == 0)
5432 return 1;
5433 return Factor;
5434}
5435
5436void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
5437 int32_t Factor,
5438 CanonicalLoopInfo **UnrolledCLI) {
5439 assert(Factor >= 0 && "Unroll factor must not be negative");
5440
5441 Function *F = Loop->getFunction();
5442 LLVMContext &Ctx = F->getContext();
5443
5444 // If the unrolled loop is not used for another loop-associated directive, it
5445 // is sufficient to add metadata for the LoopUnrollPass.
5446 if (!UnrolledCLI) {
5447 SmallVector<Metadata *, 2> LoopMetadata;
5448 LoopMetadata.push_back(
5449 Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));
5450
5451 if (Factor >= 1) {
5452 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5453 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
5454 LoopMetadata.push_back(Elt: MDNode::get(
5455 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
5456 }
5457
5458 addLoopMetadata(Loop, Properties: LoopMetadata);
5459 return;
5460 }
5461
5462 // Heuristically determine the unroll factor.
5463 if (Factor == 0)
5464 Factor = computeHeuristicUnrollFactor(CLI: Loop);
5465
5466 // No change required with unroll factor 1.
5467 if (Factor == 1) {
5468 *UnrolledCLI = Loop;
5469 return;
5470 }
5471
5472 assert(Factor >= 2 &&
5473 "unrolling only makes sense with a factor of 2 or larger");
5474
5475 Type *IndVarTy = Loop->getIndVarType();
5476
5477 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
5478 // unroll the inner loop.
5479 Value *FactorVal =
5480 ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
5481 /*isSigned=*/false));
5482 std::vector<CanonicalLoopInfo *> LoopNest =
5483 tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
5484 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
5485 *UnrolledCLI = LoopNest[0];
5486 CanonicalLoopInfo *InnerLoop = LoopNest[1];
5487
5488 // LoopUnrollPass can only fully unroll loops with constant trip count.
5489 // Unroll by the unroll factor with a fallback epilog for the remainder
5490 // iterations if necessary.
5491 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5492 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
5493 addLoopMetadata(
5494 Loop: InnerLoop,
5495 Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
5496 MDNode::get(
5497 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});
5498
5499#ifndef NDEBUG
5500 (*UnrolledCLI)->assertOK();
5501#endif
5502}
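
// Sketch of the tile-and-fully-unroll strategy used by unrollLoopPartial for
// a factor of 4 (ignoring how tileLoops actually handles the remainder
// iterations):
//
//   for (int i = 0; i < n; ++i)        for (int i = 0; i < n; i += 4)
//     body(i);                 =>        for (int j = i; j < min(i + 4, n); ++j)
//                                          body(j); // inner loop fully unrolled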
5503
5504OpenMPIRBuilder::InsertPointTy
5505OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
5506 llvm::Value *BufSize, llvm::Value *CpyBuf,
5507 llvm::Value *CpyFn, llvm::Value *DidIt) {
5508 if (!updateToLocation(Loc))
5509 return Loc.IP;
5510
5511 uint32_t SrcLocStrSize;
5512 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5513 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5514 Value *ThreadId = getOrCreateThreadID(Ident);
5515
5516 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
5517
5518 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
5519
5520 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
5521 Builder.CreateCall(Callee: Fn, Args);
5522
5523 return Builder.saveIP();
5524}
5525
5526OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
5527 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5528 FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
5529 ArrayRef<llvm::Function *> CPFuncs) {
5530
5531 if (!updateToLocation(Loc))
5532 return Loc.IP;
5533
  // If needed, allocate and initialize `DidIt` with 0.
  // `DidIt` is a flag: 1 = this thread executed the single region, 0 = it did
  // not.
5536 llvm::Value *DidIt = nullptr;
5537 if (!CPVars.empty()) {
5538 DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
5539 Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
5540 }
5541
5542 Directive OMPD = Directive::OMPD_single;
5543 uint32_t SrcLocStrSize;
5544 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5545 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5546 Value *ThreadId = getOrCreateThreadID(Ident);
5547 Value *Args[] = {Ident, ThreadId};
5548
5549 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
5550 Instruction *EntryCall = Builder.CreateCall(Callee: EntryRTLFn, Args);
5551
5552 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
5553 Instruction *ExitCall = Builder.CreateCall(Callee: ExitRTLFn, Args);
5554
5555 auto FiniCBWrapper = [&](InsertPointTy IP) {
5556 FiniCB(IP);
5557
5558 // The thread that executes the single region must set `DidIt` to 1.
5559 // This is used by __kmpc_copyprivate, to know if the caller is the
5560 // single thread or not.
5561 if (DidIt)
5562 Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);
5563 };
5564
5565 // generates the following:
5566 // if (__kmpc_single()) {
5567 // .... single region ...
5568 // __kmpc_end_single
5569 // }
5570 // __kmpc_copyprivate
5571 // __kmpc_barrier
5572
5573 EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
5574 /*Conditional*/ true,
5575 /*hasFinalize*/ HasFinalize: true);
5576
5577 if (DidIt) {
5578 for (size_t I = 0, E = CPVars.size(); I < E; ++I)
5579 // NOTE BufSize is currently unused, so just pass 0.
5580 createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
5581 /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
5582 CpyFn: CPFuncs[I], DidIt);
5583 // NOTE __kmpc_copyprivate already inserts a barrier
5584 } else if (!IsNowait)
5585 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
5586 Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
5587 /* CheckCancelFlag */ false);
5588 return Builder.saveIP();
5589}
5590
5591OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
5592 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5593 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
5594
5595 if (!updateToLocation(Loc))
5596 return Loc.IP;
5597
5598 Directive OMPD = Directive::OMPD_critical;
5599 uint32_t SrcLocStrSize;
5600 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5601 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5602 Value *ThreadId = getOrCreateThreadID(Ident);
5603 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
5604 Value *Args[] = {Ident, ThreadId, LockVar};
5605
5606 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
5607 Function *RTFn = nullptr;
5608 if (HintInst) {
5609 // Add Hint to entry Args and create call
5610 EnterArgs.push_back(Elt: HintInst);
5611 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
5612 } else {
5613 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
5614 }
5615 Instruction *EntryCall = Builder.CreateCall(Callee: RTFn, Args: EnterArgs);
5616
5617 Function *ExitRTLFn =
5618 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
5619 Instruction *ExitCall = Builder.CreateCall(Callee: ExitRTLFn, Args);
5620
5621 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5622 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
5623}
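
// Schematically, createCritical emits
//
//   call void @__kmpc_critical(ptr @loc, i32 %tid, ptr @lock) ; or _with_hint
//   ... critical region body ...
//   call void @__kmpc_end_critical(ptr @loc, i32 %tid, ptr @lock)
//
// where @lock is the internal lock variable created by
// getOMPCriticalRegionLock for the given critical name.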
5624
5625OpenMPIRBuilder::InsertPointTy
5626OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
5627 InsertPointTy AllocaIP, unsigned NumLoops,
5628 ArrayRef<llvm::Value *> StoreValues,
5629 const Twine &Name, bool IsDependSource) {
5630 assert(
5631 llvm::all_of(StoreValues,
5632 [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
5633 "OpenMP runtime requires depend vec with i64 type");
5634
5635 if (!updateToLocation(Loc))
5636 return Loc.IP;
5637
5638 // Allocate space for vector and generate alloc instruction.
5639 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
5640 Builder.restoreIP(IP: AllocaIP);
5641 AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
5642 ArgsBase->setAlignment(Align(8));
5643 Builder.restoreIP(IP: Loc.IP);
5644
5645 // Store the index value with offset in depend vector.
5646 for (unsigned I = 0; I < NumLoops; ++I) {
5647 Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
5648 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
5649 StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
5650 STInst->setAlignment(Align(8));
5651 }
5652
5653 Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
5654 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});
5655
5656 uint32_t SrcLocStrSize;
5657 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5658 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5659 Value *ThreadId = getOrCreateThreadID(Ident);
5660 Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
5661
5662 Function *RTLFn = nullptr;
5663 if (IsDependSource)
5664 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
5665 else
5666 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
5667 Builder.CreateCall(Callee: RTLFn, Args);
5668
5669 return Builder.saveIP();
5670}
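
// Illustrative only: for NumLoops == 2 the code above emits, roughly,
//
//   %vec = alloca [2 x i64], align 8   ; at AllocaIP
//   store i64 %iv0, ptr %vec.0, align 8
//   store i64 %iv1, ptr %vec.1, align 8
//   call void @__kmpc_doacross_post(ptr @loc, i32 %tid, ptr %vec.0)
//
// where %vec.N denote the in-bounds GEPs to elements 0 and 1, and
// __kmpc_doacross_wait is used instead when IsDependSource is false.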
5671
5672OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
5673 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5674 FinalizeCallbackTy FiniCB, bool IsThreads) {
5675 if (!updateToLocation(Loc))
5676 return Loc.IP;
5677
5678 Directive OMPD = Directive::OMPD_ordered;
5679 Instruction *EntryCall = nullptr;
5680 Instruction *ExitCall = nullptr;
5681
5682 if (IsThreads) {
5683 uint32_t SrcLocStrSize;
5684 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5685 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5686 Value *ThreadId = getOrCreateThreadID(Ident);
5687 Value *Args[] = {Ident, ThreadId};
5688
5689 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
5690 EntryCall = Builder.CreateCall(Callee: EntryRTLFn, Args);
5691
5692 Function *ExitRTLFn =
5693 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
5694 ExitCall = Builder.CreateCall(Callee: ExitRTLFn, Args);
5695 }
5696
5697 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5698 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
5699}
5700
5701OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
5702 Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
5703 BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
5704 bool HasFinalize, bool IsCancellable) {
5705
5706 if (HasFinalize)
5707 FinalizationStack.push_back(Elt: {.FiniCB: FiniCB, .DK: OMPD, .IsCancellable: IsCancellable});
5708
5709 // Create inlined region's entry and body blocks, in preparation
5710 // for conditional creation
5711 BasicBlock *EntryBB = Builder.GetInsertBlock();
5712 Instruction *SplitPos = EntryBB->getTerminator();
5713 if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
5714 SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
5715 BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
5716 BasicBlock *FiniBB =
5717 EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");
5718
5719 Builder.SetInsertPoint(EntryBB->getTerminator());
5720 emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
5721
5722 // generate body
5723 BodyGenCB(/* AllocaIP */ InsertPointTy(),
5724 /* CodeGenIP */ Builder.saveIP());
5725
5726 // emit exit call and do any needed finalization.
5727 auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
5728 assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
5729 FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
5730 "Unexpected control flow graph state!!");
5731 emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
5732 assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
5733 "Unexpected Control Flow State!");
5734 MergeBlockIntoPredecessor(BB: FiniBB);
5735
  // If we are skipping the region of a non-conditional directive, remove the
  // exit block, and clear the builder's insertion point.
5738 assert(SplitPos->getParent() == ExitBB &&
5739 "Unexpected Insertion point location!");
5740 auto merged = MergeBlockIntoPredecessor(BB: ExitBB);
5741 BasicBlock *ExitPredBB = SplitPos->getParent();
5742 auto InsertBB = merged ? ExitPredBB : ExitBB;
5743 if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
5744 SplitPos->eraseFromParent();
5745 Builder.SetInsertPoint(InsertBB);
5746
5747 return Builder.saveIP();
5748}
5749
5750OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
5751 Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // If there is nothing to do, return the current insertion point.
5753 if (!Conditional || !EntryCall)
5754 return Builder.saveIP();
5755
5756 BasicBlock *EntryBB = Builder.GetInsertBlock();
5757 Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
5758 auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
5759 auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
5760
5761 // Emit thenBB and set the Builder's insertion point there for
5762 // body generation next. Place the block after the current block.
5763 Function *CurFn = EntryBB->getParent();
5764 CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);
5765
5766 // Move Entry branch to end of ThenBB, and replace with conditional
5767 // branch (If-stmt)
5768 Instruction *EntryBBTI = EntryBB->getTerminator();
5769 Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
5770 EntryBBTI->removeFromParent();
5771 Builder.SetInsertPoint(UI);
5772 Builder.Insert(I: EntryBBTI);
5773 UI->eraseFromParent();
5774 Builder.SetInsertPoint(ThenBB->getTerminator());
5775
5776 // return an insertion point to ExitBB.
5777 return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
5778}
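
// For a conditional directive, emitCommonDirectiveEntry produces control flow
// of the form (schematic):
//
//   entry:
//     %ret  = call i32 @<entry runtime function>(...)
//     %cond = icmp ne i32 %ret, 0
//     br i1 %cond, label %omp_region.body, label %omp_region.end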
5779
5780OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
5781 omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
5782 bool HasFinalize) {
5783
5784 Builder.restoreIP(IP: FinIP);
5785
5786 // If there is finalization to do, emit it before the exit call
5787 if (HasFinalize) {
5788 assert(!FinalizationStack.empty() &&
5789 "Unexpected finalization stack state!");
5790
5791 FinalizationInfo Fi = FinalizationStack.pop_back_val();
5792 assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
5793
5794 Fi.FiniCB(FinIP);
5795
5796 BasicBlock *FiniBB = FinIP.getBlock();
5797 Instruction *FiniBBTI = FiniBB->getTerminator();
5798
5799 // set Builder IP for call creation
5800 Builder.SetInsertPoint(FiniBBTI);
5801 }
5802
5803 if (!ExitCall)
5804 return Builder.saveIP();
5805
  // Place the exit call as the last instruction before the finalization block
  // terminator.
5807 ExitCall->removeFromParent();
5808 Builder.Insert(I: ExitCall);
5809
5810 return IRBuilder<>::InsertPoint(ExitCall->getParent(),
5811 ExitCall->getIterator());
5812}
5813
5814OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
5815 InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
5816 llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
5817 if (!IP.isSet())
5818 return IP;
5819
5820 IRBuilder<>::InsertPointGuard IPG(Builder);
5821
5822 // creates the following CFG structure
5823 // OMP_Entry : (MasterAddr != PrivateAddr)?
5824 // F T
5825 // | \
  //       |     copyin.not.master
5827 // | /
5828 // v /
5829 // copyin.not.master.end
5830 // |
5831 // v
5832 // OMP.Entry.Next
5833
5834 BasicBlock *OMP_Entry = IP.getBlock();
5835 Function *CurFn = OMP_Entry->getParent();
5836 BasicBlock *CopyBegin =
5837 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
5838 BasicBlock *CopyEnd = nullptr;
5839
5840 // If entry block is terminated, split to preserve the branch to following
5841 // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
5842 if (isa_and_nonnull<BranchInst>(Val: OMP_Entry->getTerminator())) {
5843 CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
5844 BBName: "copyin.not.master.end");
5845 OMP_Entry->getTerminator()->eraseFromParent();
5846 } else {
5847 CopyEnd =
5848 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
5849 }
5850
5851 Builder.SetInsertPoint(OMP_Entry);
5852 Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
5853 Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
5854 Value *cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
5855 Builder.CreateCondBr(Cond: cmp, True: CopyBegin, False: CopyEnd);
5856
5857 Builder.SetInsertPoint(CopyBegin);
5858 if (BranchtoEnd)
5859 Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));
5860
5861 return Builder.saveIP();
5862}
5863
5864CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
5865 Value *Size, Value *Allocator,
5866 std::string Name) {
5867 IRBuilder<>::InsertPointGuard IPG(Builder);
5868 updateToLocation(Loc);
5869
5870 uint32_t SrcLocStrSize;
5871 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5872 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5873 Value *ThreadId = getOrCreateThreadID(Ident);
5874 Value *Args[] = {ThreadId, Size, Allocator};
5875
5876 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
5877
5878 return Builder.CreateCall(Callee: Fn, Args, Name);
5879}
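
// Illustrative call sequence emitted by createOMPAlloc (value names are
// placeholders):
//
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @loc)
//   %ptr  = call ptr @__kmpc_alloc(i32 %gtid, i64 %size, ptr %allocator)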
5880
5881CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
5882 Value *Addr, Value *Allocator,
5883 std::string Name) {
5884 IRBuilder<>::InsertPointGuard IPG(Builder);
5885 updateToLocation(Loc);
5886
5887 uint32_t SrcLocStrSize;
5888 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5889 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5890 Value *ThreadId = getOrCreateThreadID(Ident);
5891 Value *Args[] = {ThreadId, Addr, Allocator};
5892 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
5893 return Builder.CreateCall(Callee: Fn, Args, Name);
5894}
5895
5896CallInst *OpenMPIRBuilder::createOMPInteropInit(
5897 const LocationDescription &Loc, Value *InteropVar,
5898 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
5899 Value *DependenceAddress, bool HaveNowaitClause) {
5900 IRBuilder<>::InsertPointGuard IPG(Builder);
5901 updateToLocation(Loc);
5902
5903 uint32_t SrcLocStrSize;
5904 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5905 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5906 Value *ThreadId = getOrCreateThreadID(Ident);
5907 if (Device == nullptr)
5908 Device = ConstantInt::get(Ty: Int32, V: -1);
5909 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
5910 if (NumDependences == nullptr) {
5911 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
5912 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
5913 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
5914 }
5915 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
5916 Value *Args[] = {
5917 Ident, ThreadId, InteropVar, InteropTypeVal,
5918 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
5919
5920 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
5921
5922 return Builder.CreateCall(Callee: Fn, Args);
5923}
5924
5925CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
5926 const LocationDescription &Loc, Value *InteropVar, Value *Device,
5927 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
5928 IRBuilder<>::InsertPointGuard IPG(Builder);
5929 updateToLocation(Loc);
5930
5931 uint32_t SrcLocStrSize;
5932 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5933 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5934 Value *ThreadId = getOrCreateThreadID(Ident);
5935 if (Device == nullptr)
5936 Device = ConstantInt::get(Ty: Int32, V: -1);
5937 if (NumDependences == nullptr) {
5938 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
5939 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
5940 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
5941 }
5942 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
5943 Value *Args[] = {
5944 Ident, ThreadId, InteropVar, Device,
5945 NumDependences, DependenceAddress, HaveNowaitClauseVal};
5946
5947 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
5948
5949 return Builder.CreateCall(Callee: Fn, Args);
5950}
5951
5952CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
5953 Value *InteropVar, Value *Device,
5954 Value *NumDependences,
5955 Value *DependenceAddress,
5956 bool HaveNowaitClause) {
5957 IRBuilder<>::InsertPointGuard IPG(Builder);
5958 updateToLocation(Loc);
5959 uint32_t SrcLocStrSize;
5960 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5961 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5962 Value *ThreadId = getOrCreateThreadID(Ident);
5963 if (Device == nullptr)
5964 Device = ConstantInt::get(Ty: Int32, V: -1);
5965 if (NumDependences == nullptr) {
5966 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
5967 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
5968 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
5969 }
5970 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
5971 Value *Args[] = {
5972 Ident, ThreadId, InteropVar, Device,
5973 NumDependences, DependenceAddress, HaveNowaitClauseVal};
5974
5975 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
5976
5977 return Builder.CreateCall(Callee: Fn, Args);
5978}
5979
5980CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
5981 const LocationDescription &Loc, llvm::Value *Pointer,
5982 llvm::ConstantInt *Size, const llvm::Twine &Name) {
5983 IRBuilder<>::InsertPointGuard IPG(Builder);
5984 updateToLocation(Loc);
5985
5986 uint32_t SrcLocStrSize;
5987 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5988 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5989 Value *ThreadId = getOrCreateThreadID(Ident);
5990 Constant *ThreadPrivateCache =
5991 getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
5992 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
5993
5994 Function *Fn =
5995 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);
5996
5997 return Builder.CreateCall(Callee: Fn, Args);
5998}
5999
6000OpenMPIRBuilder::InsertPointTy
6001OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6002 int32_t MinThreadsVal, int32_t MaxThreadsVal,
6003 int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6004 if (!updateToLocation(Loc))
6005 return Loc.IP;
6006
6007 uint32_t SrcLocStrSize;
6008 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6009 Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6010 Constant *IsSPMDVal = ConstantInt::getSigned(
6011 Ty: Int8, V: IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6012 Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Ty: Int8, V: !IsSPMD);
6013 Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
6014 Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);
6015
6016 Function *Kernel = Builder.GetInsertBlock()->getParent();
6017
6018 // Manifest the launch configuration in the metadata matching the kernel
6019 // environment.
6020 if (MinTeamsVal > 1 || MaxTeamsVal > 0)
6021 writeTeamsForKernel(T, Kernel&: *Kernel, LB: MinTeamsVal, UB: MaxTeamsVal);
6022
6023 // For max values, < 0 means unset, == 0 means set but unknown.
6024 if (MaxThreadsVal < 0)
6025 MaxThreadsVal = std::max(
6026 a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), b: MinThreadsVal);
6027
6028 if (MaxThreadsVal > 0)
6029 writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: MinThreadsVal, UB: MaxThreadsVal);
6030
6031 Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: MinThreadsVal);
6032 Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
6033 Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: MinTeamsVal);
6034 Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: MaxTeamsVal);
6035 Constant *ReductionDataSize = ConstantInt::getSigned(Ty: Int32, V: 0);
6036 Constant *ReductionBufferLength = ConstantInt::getSigned(Ty: Int32, V: 0);
6037
6038 // We need to strip the debug prefix to get the correct kernel name.
6039 StringRef KernelName = Kernel->getName();
6040 const std::string DebugPrefix = "_debug__";
6041 if (KernelName.ends_with(Suffix: DebugPrefix))
6042 KernelName = KernelName.drop_back(N: DebugPrefix.length());
6043
6044 Function *Fn = getOrCreateRuntimeFunctionPtr(
6045 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
6046 const DataLayout &DL = Fn->getDataLayout();
6047
6048 Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
6049 Constant *DynamicEnvironmentInitializer =
6050 ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
6051 GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
6052 M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
6053 DynamicEnvironmentInitializer, DynamicEnvironmentName,
6054 /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6055 DL.getDefaultGlobalsAddressSpace());
6056 DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6057
6058 Constant *DynamicEnvironment =
6059 DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
6060 ? DynamicEnvironmentGV
6061 : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
6062 Ty: DynamicEnvironmentPtr);
6063
6064 Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
6065 T: ConfigurationEnvironment, V: {
6066 UseGenericStateMachineVal,
6067 MayUseNestedParallelismVal,
6068 IsSPMDVal,
6069 MinThreads,
6070 MaxThreads,
6071 MinTeams,
6072 MaxTeams,
6073 ReductionDataSize,
6074 ReductionBufferLength,
6075 });
6076 Constant *KernelEnvironmentInitializer = ConstantStruct::get(
6077 T: KernelEnvironment, V: {
6078 ConfigurationEnvironmentInitializer,
6079 Ident,
6080 DynamicEnvironment,
6081 });
6082 std::string KernelEnvironmentName =
6083 (KernelName + "_kernel_environment").str();
6084 GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
6085 M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
6086 KernelEnvironmentInitializer, KernelEnvironmentName,
6087 /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6088 DL.getDefaultGlobalsAddressSpace());
6089 KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6090
6091 Constant *KernelEnvironment =
6092 KernelEnvironmentGV->getType() == KernelEnvironmentPtr
6093 ? KernelEnvironmentGV
6094 : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
6095 Ty: KernelEnvironmentPtr);
6096 Value *KernelLaunchEnvironment = Kernel->getArg(i: 0);
6097 CallInst *ThreadKind =
6098 Builder.CreateCall(Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});
6099
6100 Value *ExecUserCode = Builder.CreateICmpEQ(
6101 LHS: ThreadKind, RHS: ConstantInt::get(Ty: ThreadKind->getType(), V: -1),
6102 Name: "exec_user_code");
6103
6104 // ThreadKind = __kmpc_target_init(...)
6105 // if (ThreadKind == -1)
6106 // user_code
6107 // else
6108 // return;
6109
6110 auto *UI = Builder.CreateUnreachable();
6111 BasicBlock *CheckBB = UI->getParent();
6112 BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");
6113
6114 BasicBlock *WorkerExitBB = BasicBlock::Create(
6115 Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
6116 Builder.SetInsertPoint(WorkerExitBB);
6117 Builder.CreateRetVoid();
6118
6119 auto *CheckBBTI = CheckBB->getTerminator();
6120 Builder.SetInsertPoint(CheckBBTI);
6121 Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);
6122
6123 CheckBBTI->eraseFromParent();
6124 UI->eraseFromParent();
6125
6126 // Continue in the "user_code" block, see diagram above and in
6127 // openmp/libomptarget/deviceRTLs/common/include/target.h .
6128 return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
6129}
6130
6131void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
6132 int32_t TeamsReductionDataSize,
6133 int32_t TeamsReductionBufferLength) {
6134 if (!updateToLocation(Loc))
6135 return;
6136
6137 Function *Fn = getOrCreateRuntimeFunctionPtr(
6138 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
6139
6140 Builder.CreateCall(Callee: Fn, Args: {});
6141
6142 if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
6143 return;
6144
6145 Function *Kernel = Builder.GetInsertBlock()->getParent();
6146 // We need to strip the debug prefix to get the correct kernel name.
6147 StringRef KernelName = Kernel->getName();
6148 const std::string DebugPrefix = "_debug__";
6149 if (KernelName.ends_with(Suffix: DebugPrefix))
6150 KernelName = KernelName.drop_back(N: DebugPrefix.length());
6151 auto *KernelEnvironmentGV =
6152 M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
6153 assert(KernelEnvironmentGV && "Expected kernel environment global\n");
6154 auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
6155 auto *NewInitializer = ConstantFoldInsertValueInstruction(
6156 Agg: KernelEnvironmentInitializer,
6157 Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
6158 NewInitializer = ConstantFoldInsertValueInstruction(
6159 Agg: NewInitializer, Val: ConstantInt::get(Ty: Int32, V: TeamsReductionBufferLength),
6160 Idxs: {0, 8});
6161 KernelEnvironmentGV->setInitializer(NewInitializer);
6162}
6163
6164static MDNode *getNVPTXMDNode(Function &Kernel, StringRef Name) {
6165 Module &M = *Kernel.getParent();
6166 NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "nvvm.annotations");
6167 for (auto *Op : MD->operands()) {
6168 if (Op->getNumOperands() != 3)
6169 continue;
6170 auto *KernelOp = dyn_cast<ConstantAsMetadata>(Val: Op->getOperand(I: 0));
6171 if (!KernelOp || KernelOp->getValue() != &Kernel)
6172 continue;
6173 auto *Prop = dyn_cast<MDString>(Val: Op->getOperand(I: 1));
6174 if (!Prop || Prop->getString() != Name)
6175 continue;
6176 return Op;
6177 }
6178 return nullptr;
6179}
6180
6181static void updateNVPTXMetadata(Function &Kernel, StringRef Name, int32_t Value,
6182 bool Min) {
  // Update the given NVVM annotation (e.g. "maxntidx") for NVIDIA, or add it.
6184 MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name);
6185 if (ExistingOp) {
6186 auto *OldVal = cast<ConstantAsMetadata>(Val: ExistingOp->getOperand(I: 2));
6187 int32_t OldLimit = cast<ConstantInt>(Val: OldVal->getValue())->getZExtValue();
6188 ExistingOp->replaceOperandWith(
6189 I: 2, New: ConstantAsMetadata::get(C: ConstantInt::get(
6190 Ty: OldVal->getValue()->getType(),
6191 V: Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value))));
6192 } else {
6193 LLVMContext &Ctx = Kernel.getContext();
6194 Metadata *MDVals[] = {ConstantAsMetadata::get(C: &Kernel),
6195 MDString::get(Context&: Ctx, Str: Name),
6196 ConstantAsMetadata::get(
6197 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: Value))};
6198 // Append metadata to nvvm.annotations
6199 Module &M = *Kernel.getParent();
6200 NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "nvvm.annotations");
6201 MD->addOperand(M: MDNode::get(Context&: Ctx, MDs: MDVals));
6202 }
6203}
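
// The resulting module-level annotation looks, schematically, like
//
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @kernel, !"maxntidx", i32 128}
//
// where the i32 operand is the min/max-combined limit computed above.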
6204
6205std::pair<int32_t, int32_t>
6206OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
6207 int32_t ThreadLimit =
6208 Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");
6209
6210 if (T.isAMDGPU()) {
6211 const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
6212 if (!Attr.isValid() || !Attr.isStringAttribute())
6213 return {0, ThreadLimit};
6214 auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
6215 int32_t LB, UB;
6216 if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
6217 return {0, ThreadLimit};
6218 UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
6219 if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
6220 return {0, UB};
6221 return {LB, UB};
6222 }
6223
6224 if (MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name: "maxntidx")) {
6225 auto *OldVal = cast<ConstantAsMetadata>(Val: ExistingOp->getOperand(I: 2));
6226 int32_t UB = cast<ConstantInt>(Val: OldVal->getValue())->getZExtValue();
6227 return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
6228 }
6229 return {0, ThreadLimit};
6230}
6231
6232void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
6233 Function &Kernel, int32_t LB,
6234 int32_t UB) {
6235 Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));
6236
6237 if (T.isAMDGPU()) {
6238 Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
6239 Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
6240 return;
6241 }
6242
6243 updateNVPTXMetadata(Kernel, Name: "maxntidx", Value: UB, Min: true);
6244}
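
// E.g., writeThreadBoundsForKernel(T, K, 1, 256) results, schematically, in
//
//   "omp_target_thread_limit"="256"               ; all targets
//   "amdgpu-flat-work-group-size"="1,256"         ; AMDGPU only
//
// and on NVPTX in a "maxntidx" annotation capped at 256.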
6245
6246std::pair<int32_t, int32_t>
6247OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
6248 // TODO: Read from backend annotations if available.
6249 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
6250}
6251
6252void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
6253 int32_t LB, int32_t UB) {
6254 if (T.isNVPTX())
6255 if (UB > 0)
6256 updateNVPTXMetadata(Kernel, Name: "maxclusterrank", Value: UB, Min: true);
6257 if (T.isAMDGPU())
6258 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: LB) + ",1,1");
6259
6260 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
6261}
6262
6263void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
6264 Function *OutlinedFn) {
6265 if (Config.isTargetDevice()) {
6266 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
6267 // TODO: Determine if DSO local can be set to true.
6268 OutlinedFn->setDSOLocal(false);
6269 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
6270 if (T.isAMDGCN())
6271 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
6272 }
6273}
6274
6275Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
6276 StringRef EntryFnIDName) {
6277 if (Config.isTargetDevice()) {
6278 assert(OutlinedFn && "The outlined function must exist if embedded");
6279 return OutlinedFn;
6280 }
6281
6282 return new GlobalVariable(
6283 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
6284 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
6285}
6286
6287Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
6288 StringRef EntryFnName) {
6289 if (OutlinedFn)
6290 return OutlinedFn;
6291
6292 assert(!M.getGlobalVariable(EntryFnName, true) &&
6293 "Named kernel already exists?");
6294 return new GlobalVariable(
6295 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
6296 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
6297}
6298
6299void OpenMPIRBuilder::emitTargetRegionFunction(
6300 TargetRegionEntryInfo &EntryInfo,
6301 FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
6302 Function *&OutlinedFn, Constant *&OutlinedFnID) {
6303
6304 SmallString<64> EntryFnName;
6305 OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);
6306
6307 OutlinedFn = Config.isTargetDevice() || !Config.openMPOffloadMandatory()
6308 ? GenerateFunctionCallback(EntryFnName)
6309 : nullptr;
6310
  // If this target outline function is not an offload entry, we don't need to
  // register it. This may be the case for a false if clause, or if there are
  // no OpenMP targets.
6314 if (!IsOffloadEntry)
6315 return;
6316
6317 std::string EntryFnIDName =
6318 Config.isTargetDevice()
6319 ? std::string(EntryFnName)
6320 : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});
6321
6322 OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
6323 EntryFnName, EntryFnIDName);
6324}
6325
6326Constant *OpenMPIRBuilder::registerTargetRegionFunction(
6327 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
6328 StringRef EntryFnName, StringRef EntryFnIDName) {
6329 if (OutlinedFn)
6330 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
6331 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
6332 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
6333 OffloadInfoManager.registerTargetRegionEntryInfo(
6334 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
6335 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
6336 return OutlinedFnID;
6337}
6338
6339OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
6340 const LocationDescription &Loc, InsertPointTy AllocaIP,
6341 InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
6342 TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6343 omp::RuntimeFunction *MapperFunc,
6344 function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
6345 BodyGenCB,
6346 function_ref<void(unsigned int, Value *)> DeviceAddrCB,
6347 function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
6348 if (!updateToLocation(Loc))
6349 return InsertPointTy();
6350
6351 // Disable TargetData CodeGen on Device pass.
6352 if (Config.IsTargetDevice.value_or(u: false)) {
6353 if (BodyGenCB)
6354 Builder.restoreIP(IP: BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6355 return Builder.saveIP();
6356 }
6357
6358 Builder.restoreIP(IP: CodeGenIP);
6359 bool IsStandAlone = !BodyGenCB;
6360 MapInfosTy *MapInfo;
6361 // Generate the code for the opening of the data environment. Capture all the
6362 // arguments of the runtime call by reference because they are used in the
6363 // closing of the region.
6364 auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6365 MapInfo = &GenMapInfoCB(Builder.saveIP());
6366 emitOffloadingArrays(AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info,
6367 /*IsNonContiguous=*/true, DeviceAddrCB,
6368 CustomMapperCB);
6369
6370 TargetDataRTArgs RTArgs;
6371 emitOffloadingArraysArgument(Builder, RTArgs, Info,
6372 EmitDebug: !MapInfo->Names.empty());
6373
6374 // Emit the number of elements in the offloading arrays.
6375 Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);
6376
6377 // Source location for the ident struct
6378 if (!SrcLocInfo) {
6379 uint32_t SrcLocStrSize;
6380 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6381 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6382 }
6383
6384 Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
6385 PointerNum, RTArgs.BasePointersArray,
6386 RTArgs.PointersArray, RTArgs.SizesArray,
6387 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6388 RTArgs.MappersArray};
6389
6390 if (IsStandAlone) {
6391 assert(MapperFunc && "MapperFunc missing for standalone target data");
6392 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
6393 Args: OffloadingArgs);
6394 } else {
6395 Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6396 FnID: omp::OMPRTL___tgt_target_data_begin_mapper);
6397
6398 Builder.CreateCall(Callee: BeginMapperFunc, Args: OffloadingArgs);
6399
6400 for (auto DeviceMap : Info.DevicePtrInfoMap) {
6401 if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
6402 auto *LI =
6403 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
6404 Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
6405 }
6406 }
6407
6408 // If device pointer privatization is required, emit the body of the
6409 // region here. It will have to be duplicated: with and without
6410 // privatization.
6411 Builder.restoreIP(IP: BodyGenCB(Builder.saveIP(), BodyGenTy::Priv));
6412 }
6413 };
6414
6415 // If we need device pointer privatization, we need to emit the body of the
6416 // region with no privatization in the 'else' branch of the conditional.
6417 // Otherwise, we don't have to do anything.
6418 auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6419 Builder.restoreIP(IP: BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv));
6420 };
6421
6422 // Generate code for the closing of the data region.
6423 auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6424 TargetDataRTArgs RTArgs;
6425 emitOffloadingArraysArgument(Builder, RTArgs, Info, EmitDebug: !MapInfo->Names.empty(),
6426 /*ForEndCall=*/true);
6427
6428 // Emit the number of elements in the offloading arrays.
6429 Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);
6430
6431 // Source location for the ident struct
6432 if (!SrcLocInfo) {
6433 uint32_t SrcLocStrSize;
6434 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6435 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6436 }
6437
6438 Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
6439 PointerNum, RTArgs.BasePointersArray,
6440 RTArgs.PointersArray, RTArgs.SizesArray,
6441 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6442 RTArgs.MappersArray};
6443 Function *EndMapperFunc =
6444 getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);
6445
6446 Builder.CreateCall(Callee: EndMapperFunc, Args: OffloadingArgs);
6447 };
6448
6449 // We don't have to do anything to close the region if the if clause evaluates
6450 // to false.
6451 auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
6452
6453 if (BodyGenCB) {
6454 if (IfCond) {
6455 emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
6456 } else {
6457 BeginThenGen(AllocaIP, Builder.saveIP());
6458 }
6459
6460 // If we don't require privatization of device pointers, we emit the body in
6461 // between the runtime calls. This avoids duplicating the body code.
6462 Builder.restoreIP(IP: BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6463
6464 if (IfCond) {
6465 emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
6466 } else {
6467 EndThenGen(AllocaIP, Builder.saveIP());
6468 }
6469 } else {
6470 if (IfCond) {
6471 emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
6472 } else {
6473 BeginThenGen(AllocaIP, Builder.saveIP());
6474 }
6475 }
6476
6477 return Builder.saveIP();
6478}
6479
6480FunctionCallee
6481OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
6482 bool IsGPUDistribute) {
6483 assert((IVSize == 32 || IVSize == 64) &&
6484 "IV size is not compatible with the omp runtime");
6485 RuntimeFunction Name;
6486 if (IsGPUDistribute)
6487 Name = IVSize == 32
6488 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
6489 : omp::OMPRTL___kmpc_distribute_static_init_4u)
6490 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
6491 : omp::OMPRTL___kmpc_distribute_static_init_8u);
6492 else
6493 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
6494 : omp::OMPRTL___kmpc_for_static_init_4u)
6495 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
6496 : omp::OMPRTL___kmpc_for_static_init_8u);
6497
6498 return getOrCreateRuntimeFunction(M, FnID: Name);
6499}
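// For example, IVSize == 32 with IVSigned == false and IsGPUDistribute ==
// false selects __kmpc_for_static_init_4u, while IVSize == 64 with
// IVSigned == true and IsGPUDistribute == true selects
// __kmpc_distribute_static_init_8.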
6500
6501FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
6502 bool IVSigned) {
6503 assert((IVSize == 32 || IVSize == 64) &&
6504 "IV size is not compatible with the omp runtime");
6505 RuntimeFunction Name = IVSize == 32
6506 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
6507 : omp::OMPRTL___kmpc_dispatch_init_4u)
6508 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
6509 : omp::OMPRTL___kmpc_dispatch_init_8u);
6510
6511 return getOrCreateRuntimeFunction(M, FnID: Name);
6512}
6513
6514FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
6515 bool IVSigned) {
6516 assert((IVSize == 32 || IVSize == 64) &&
6517 "IV size is not compatible with the omp runtime");
6518 RuntimeFunction Name = IVSize == 32
6519 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
6520 : omp::OMPRTL___kmpc_dispatch_next_4u)
6521 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
6522 : omp::OMPRTL___kmpc_dispatch_next_8u);
6523
6524 return getOrCreateRuntimeFunction(M, FnID: Name);
6525}
6526
6527FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
6528 bool IVSigned) {
6529 assert((IVSize == 32 || IVSize == 64) &&
6530 "IV size is not compatible with the omp runtime");
6531 RuntimeFunction Name = IVSize == 32
6532 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
6533 : omp::OMPRTL___kmpc_dispatch_fini_4u)
6534 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
6535 : omp::OMPRTL___kmpc_dispatch_fini_8u);
6536
6537 return getOrCreateRuntimeFunction(M, FnID: Name);
6538}
6539
6540FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6541 return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
6542}
6543
6544static Function *createOutlinedFunction(
6545 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6546 SmallVectorImpl<Value *> &Inputs,
6547 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6548 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6549 SmallVector<Type *> ParameterTypes;
6550 if (OMPBuilder.Config.isTargetDevice()) {
6551 // Add the "implicit" runtime argument we use to provide launch specific
6552 // information for target devices.
6553 auto *Int8PtrTy = PointerType::getUnqual(C&: Builder.getContext());
6554 ParameterTypes.push_back(Elt: Int8PtrTy);
6555
6556 // All parameters to target devices are passed as pointers
6557 // or i64. This assumes 64-bit address spaces/pointers.
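// For instance, a 'float' input would occupy an i64 slot here, with the
// ArgAccessorFuncCB below expected to recover the original value inside
// the kernel; a 'ptr' input keeps its pointer type unchanged.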
6558 for (auto &Arg : Inputs)
6559 ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
6560 ? Arg->getType()
6561 : Type::getInt64Ty(C&: Builder.getContext()));
6562 } else {
6563 for (auto &Arg : Inputs)
6564 ParameterTypes.push_back(Elt: Arg->getType());
6565 }
6566
6567 auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
6568 /*isVarArg*/ false);
6569 auto Func = Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName,
6570 M: Builder.GetInsertBlock()->getModule());
6571
6572 // Save insert point.
6573 auto OldInsertPoint = Builder.saveIP();
6574
6575 // Generate the region into the function.
6576 BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
6577 Builder.SetInsertPoint(EntryBB);
6578
6579 // Insert target init call in the device compilation pass.
6580 if (OMPBuilder.Config.isTargetDevice())
6581 Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, /*IsSPMD*/ false));
6582
6583 BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
6584
6585 // As we embed the user code in the middle of our target region after we
6586 // generate entry code, we must move what allocas we can into the entry
6587 // block to avoid possibly breaking optimizations for the device.
6588 if (OMPBuilder.Config.isTargetDevice())
6589 OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);
6590
6591 // Generate the region body; then insert the target deinit call in the device compilation pass.
6592 Builder.restoreIP(IP: CBFunc(Builder.saveIP(), Builder.saveIP()));
6593 if (OMPBuilder.Config.isTargetDevice())
6594 OMPBuilder.createTargetDeinit(Loc: Builder);
6595
6596 // Insert return instruction.
6597 Builder.CreateRetVoid();
6598
6599 // New Alloca IP at entry point of created device function.
6600 Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
6601 auto AllocaIP = Builder.saveIP();
6602
6603 Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
6604
6605 // Skip the artificial dyn_ptr on the device.
6606 const auto &ArgRange =
6607 OMPBuilder.Config.isTargetDevice()
6608 ? make_range(x: Func->arg_begin() + 1, y: Func->arg_end())
6609 : Func->args();
6610
6611 auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
6612 // Things like GEPs can come in the form of Constants. Constants and
6613 // ConstantExprs do not know which function they are contained in, so we
6614 // must dig a little to find an instruction before we can tell whether
6615 // they are used inside the function we are outlining. We also replace
6616 // the original constant expression with an equivalent new instruction,
6617 // because an instruction allows easy modification in the following
6618 // loop: we then know the (former) constant is owned by our target
6619 // function, and replaceUsesOfWith can be invoked on it (which does not
6620 // seem possible with constants). A brand-new instruction also lets us
6621 // be cautious, as the old expression could conceivably have been used
6622 // inside the function while also existing and being used externally
6623 // (unlikely by the nature of a Constant, but still possible).
6624 // NOTE: We cannot remove dead constants that have been rewritten to
6625 // instructions at this stage; doing so risks breaking later lowering,
6626 // as we could still be in the process of lowering the module from MLIR
6627 // to LLVM-IR and the MLIR lowering may still require the original
6628 // constants for which we have created rewritten versions.
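// As a rough illustration (hypothetical IR, not emitted verbatim), a use
// such as
//   store i32 0, ptr getelementptr inbounds ([4 x i32], ptr @g, i64 0, i64 1)
// inside Func is rewritten so the folded GEP becomes a standalone
// instruction,
//   %gep = getelementptr inbounds [4 x i32], ptr @g, i64 0, i64 1
//   store i32 0, ptr %gep
// after which replaceUsesOfWith can swap @g for the kernel argument copy.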
6629 if (auto *Const = dyn_cast<Constant>(Val: Input))
6630 convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);
6631
6632 // Collect all the instructions
6633 for (User *User : make_early_inc_range(Range: Input->users()))
6634 if (auto *Instr = dyn_cast<Instruction>(Val: User))
6635 if (Instr->getFunction() == Func)
6636 Instr->replaceUsesOfWith(From: Input, To: InputCopy);
6637 };
6638
6639 SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
6640
6641 // Rewrite uses of input values to parameters.
6642 for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
6643 Value *Input = std::get<0>(t&: InArg);
6644 Argument &Arg = std::get<1>(t&: InArg);
6645 Value *InputCopy = nullptr;
6646
6647 Builder.restoreIP(
6648 IP: ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
6649
6650 // In certain cases a Global may be set up for replacement; however, this
6651 // Global may be used in multiple arguments to the kernel, just segmented
6652 // apart. For example, if we have a global array that is sectioned into
6653 // multiple mappings (technically not legal in OpenMP, but necessary for
6654 // Fortran Common Blocks, where there is a case requiring it), we will
6655 // end up with GEPs into this array inside the kernel that refer to the
6656 // Global but are, for all intents and purposes, separate arguments to
6657 // the kernel. If we have mapped a segment that requires a GEP into the
6658 // 0-th index, that GEP will fold into a direct reference to the Global.
6659 // If we then encounter this folded GEP during replacement, every
6660 // reference to the Global in the kernel will be replaced with the
6661 // argument we have generated for it, including any other GEPs that
6662 // refer to the Global and belong to other arguments. That would
6663 // invalidate all of the preceding mapped arguments that refer to
6664 // separate segments of the same Global. To prevent this, we defer
6665 // processing of Globals until all other processing has been performed.
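// Hypothetical sketch: if a common-block global @cb is mapped as two
// segments, one kernel input may be @cb itself (a zero-index GEP folded
// away) and another a GEP at some offset into @cb. Replacing @cb eagerly
// while handling the first input would also rewrite the GEP feeding the
// second, so both replacements are deferred until the end.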
6666 // GlobalValue subsumes GlobalObject and GlobalVariable, so a single
6667 // check suffices.
6668 if (llvm::isa<llvm::GlobalValue>(Val: std::get<0>(t&: InArg))) {
6669 DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
6670 continue;
6671 }
6672
6673 ReplaceValue(Input, InputCopy, Func);
6674 }
6675
6676 // Replace all of our deferred Input values, currently just Globals.
6677 for (auto Deferred : DeferredReplacement)
6678 ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);
6679
6680 // Restore insert point.
6681 Builder.restoreIP(IP: OldInsertPoint);
6682
6683 return Func;
6684}
6685
6686/// Create an entry point for a target task. It'll have the following
6687/// signature:
6688///   void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
6689/// This function is called from emitTargetTask once the code to launch
6690/// the target kernel has been outlined already.
6691static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6692 IRBuilderBase &Builder,
6693 CallInst *StaleCI) {
6694 Module &M = OMPBuilder.M;
6695 // KernelLaunchFunction is the target launch function, i.e.
6696 // the function that sets up kernel arguments and calls
6697 // __tgt_target_kernel to launch the kernel on the device.
6698 //
6699 Function *KernelLaunchFunction = StaleCI->getCalledFunction();
6700
6701 // StaleCI is the CallInst which is the call to the outlined
6702 // target kernel launch function. If there are values that the
6703 // outlined function uses then these are aggregated into a structure
6704 // which is passed as the second argument. If not, then there's
6705 // only one argument, the threadID. So, StaleCI can be
6706 //
6707 // %structArg = alloca { ptr, ptr }, align 8
6708 // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
6709 // store ptr %20, ptr %gep_, align 8
6710 // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
6711 // store ptr %21, ptr %gep_8, align 8
6712 // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
6713 //
6714 // OR
6715 //
6716 // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
6717 OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
6718 StaleCI->getIterator());
6719 LLVMContext &Ctx = StaleCI->getParent()->getContext();
6720 Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
6721 Type *TaskPtrTy = OMPBuilder.TaskPtr;
6722 Type *TaskTy = OMPBuilder.Task;
6723 auto ProxyFnTy =
6724 FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
6725 /* isVarArg */ false);
6726 auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
6727 N: ".omp_target_task_proxy_func",
6728 M: Builder.GetInsertBlock()->getModule());
6729 ProxyFn->getArg(i: 0)->setName("thread.id");
6730 ProxyFn->getArg(i: 1)->setName("task");
6731
6732 BasicBlock *EntryBB =
6733 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
6734 Builder.SetInsertPoint(EntryBB);
6735
6736 bool HasShareds = StaleCI->arg_size() > 1;
6737 // TODO: This is a temporary assert to prove to ourselves that
6738 // the outlined target launch function is always going to have
6739 // at most two arguments if there is any data shared between
6740 // host and device.
6741 assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
6742 "StaleCI with shareds should have exactly two arguments.");
6743 if (HasShareds) {
6744 auto *ArgStructAlloca = dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
6745 assert(ArgStructAlloca &&
6746 "Unable to find the alloca instruction corresponding to arguments "
6747 "for extracted function");
6748 auto *ArgStructType =
6749 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
6750
6751 AllocaInst *NewArgStructAlloca =
6752 Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");
6753 Value *TaskT = ProxyFn->getArg(i: 1);
6754 Value *ThreadId = ProxyFn->getArg(i: 0);
6755 Value *SharedsSize =
6756 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
6757
6758 Value *Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
6759 LoadInst *LoadShared =
6760 Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
6761
6762 Builder.CreateMemCpy(
6763 Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
6764 SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
6765
6766 Builder.CreateCall(Callee: KernelLaunchFunction, Args: {ThreadId, NewArgStructAlloca});
6767 }
6768 Builder.CreateRetVoid();
6769 return ProxyFn;
6770}
6771static void emitTargetOutlinedFunction(
6772 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6773 TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
6774 Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
6775 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6776 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6777
6778 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
6779 [&OMPBuilder, &Builder, &Inputs, &CBFunc,
6780 &ArgAccessorFuncCB](StringRef EntryFnName) {
6781 return createOutlinedFunction(OMPBuilder, Builder, FuncName: EntryFnName, Inputs,
6782 CBFunc, ArgAccessorFuncCB);
6783 };
6784
6785 OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry: true,
6786 OutlinedFn, OutlinedFnID);
6787}
6788OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6789 Function *OutlinedFn, Value *OutlinedFnID,
6790 EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
6791 Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
6792 SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6793 bool HasNoWait) {
6794
6795 // When we arrive at this function, the target region itself has been
6796 // outlined into the function OutlinedFn.
6797 // So at this point, for
6798 // --------------------------------------------------
6799 // void user_code_that_offloads(...) {
6800 // omp target depend(..) map(from:a) map(to:b, c)
6801 // a = b + c
6802 // }
6803 //
6804 // --------------------------------------------------
6805 //
6806 // we have
6807 //
6808 // --------------------------------------------------
6809 //
6810 // void user_code_that_offloads(...) {
6811 // %.offload_baseptrs = alloca [3 x ptr], align 8
6812 // %.offload_ptrs = alloca [3 x ptr], align 8
6813 // %.offload_mappers = alloca [3 x ptr], align 8
6814 // ;; target region has been outlined and now we need to
6815 // ;; offload to it via a target task.
6816 // }
6817 // void outlined_device_function(ptr a, ptr b, ptr c) {
6818 // *a = *b + *c
6819 // }
6820 //
6821 // We now have to do the following:
6822 // (i) Make an offloading call to outlined_device_function using the OpenMP
6823 // RTL. See 'kernel_launch_function' in the pseudo code below. This is
6824 // emitted by emitKernelLaunch
6825 // (ii) Create a task entry point function that calls kernel_launch_function
6826 // and is the entry point for the target task. See
6827 // '@.omp_target_task_proxy_func in the pseudocode below.
6828 // (iii) Create a task with the task entry point created in (ii)
6829 //
6830 // That is we create the following
6831 //
6832 // void user_code_that_offloads(...) {
6833 // %.offload_baseptrs = alloca [3 x ptr], align 8
6834 // %.offload_ptrs = alloca [3 x ptr], align 8
6835 // %.offload_mappers = alloca [3 x ptr], align 8
6836 //
6837 // %structArg = alloca { ptr, ptr, ptr }, align 8
6838 // %structArg[0] = %.offload_baseptrs
6839 // %structArg[1] = %.offload_ptrs
6840 // %structArg[2] = %.offload_mappers
6841 // proxy_target_task = @__kmpc_omp_task_alloc(...,
6842 // @.omp_target_task_proxy_func)
6843 // memcpy(proxy_target_task->shareds, %structArg, sizeof(structArg))
6844 // dependencies_array = ...
6845 // ;; if nowait not present
6846 // call @__kmpc_omp_wait_deps(..., dependencies_array)
6847 // call @__kmpc_omp_task_begin_if0(...)
6848 // call @.omp_target_task_proxy_func(i32 thread_id, ptr %proxy_target_task)
6849 // call @__kmpc_omp_task_complete_if0(...)
6850 // }
6851 //
6852 // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
6853 // ptr %task) {
6854 // %structArg = alloca {ptr, ptr, ptr}
6855 // %shared_data = load (getelementptr %task, 0, 0)
6856 // memcpy(%structArg, %shared_data, sizeof(structArg))
6857 // kernel_launch_function(%thread.id, %structArg)
6858 // }
6859 //
6860 // We need the proxy function because the signature of the task entry point
6861 // expected by kmpc_omp_task is always the same and will be different from
6862 // that of the kernel_launch function.
6863 //
6864 // kernel_launch_function is generated by emitKernelLaunch and has the
6865 // always_inline attribute.
6866 // void kernel_launch_function(thread_id,
6867 // structArg) alwaysinline {
6868 // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
6869 // offload_baseptrs = load(getelementptr structArg, 0, 0)
6870 // offload_ptrs = load(getelementptr structArg, 0, 1)
6871 // offload_mappers = load(getelementptr structArg, 0, 2)
6872 // ; setup kernel_args using offload_baseptrs, offload_ptrs and
6873 // ; offload_mappers
6874 // call i32 @__tgt_target_kernel(...,
6875 // outlined_device_function,
6876 // ptr %kernel_args)
6877 // }
6878 // void outlined_device_function(ptr a, ptr b, ptr c) {
6879 // *a = *b + *c
6880 // }
6881 //
6882 BasicBlock *TargetTaskBodyBB =
6883 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
6884 BasicBlock *TargetTaskAllocaBB =
6885 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");
6886
6887 InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
6888 TargetTaskAllocaBB->begin());
6889 InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
6890
6891 OutlineInfo OI;
6892 OI.EntryBB = TargetTaskAllocaBB;
6893 OI.OuterAllocaBB = AllocaIP.getBlock();
6894
6895 // Add the thread ID argument.
6896 SmallVector<Instruction *, 4> ToBeDeleted;
6897 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
6898 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));
6899
6900 Builder.restoreIP(IP: TargetTaskBodyIP);
6901
6902 // emitKernelLaunch makes the necessary runtime call to offload the kernel.
6903 // We then outline all that code into a separate function
6904 // ('kernel_launch_function' in the pseudo code above). This function is then
6905 // called by the target task proxy function (see
6906 // '@.omp_target_task_proxy_func' in the pseudo code above).
6907 // '@.omp_target_task_proxy_func' is generated by emitTargetTaskProxyFunction.
6908 Builder.restoreIP(IP: emitKernelLaunch(Loc: Builder, OutlinedFn, OutlinedFnID,
6909 emitTargetCallFallbackCB: EmitTargetCallFallbackCB, Args, DeviceID,
6910 RTLoc, AllocaIP: TargetTaskAllocaIP));
6911
6912 OI.ExitBB = Builder.saveIP().getBlock();
6913 OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
6914 HasNoWait](Function &OutlinedFn) mutable {
6915 assert(OutlinedFn.getNumUses() == 1 &&
6916 "there must be a single user for the outlined function");
6917
6918 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
6919 bool HasShareds = StaleCI->arg_size() > 1;
6920
6921 Function *ProxyFn = emitTargetTaskProxyFunction(OMPBuilder&: *this, Builder, StaleCI);
6922
6923 LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
6924 << "\n");
6925
6926 Builder.SetInsertPoint(StaleCI);
6927
6928 // Gather the arguments for emitting the runtime call.
6929 uint32_t SrcLocStrSize;
6930 Constant *SrcLocStr =
6931 getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
6932 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6933
6934 // @__kmpc_omp_task_alloc
6935 Function *TaskAllocFn =
6936 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
6937
6938 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
6939 // call.
6940 Value *ThreadID = getOrCreateThreadID(Ident);
6941
6942 // Argument - `sizeof_kmp_task_t` (TaskSize)
6943 // TaskSize refers to the size in bytes of the kmp_task_t data structure,
6944 // including any private variables accessed in the task.
6945 // TODO: add kmp_task_t_with_privates (privates)
6946 Value *TaskSize =
6947 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Task));
6948
6949 // Argument - `sizeof_shareds` (SharedsSize)
6950 // SharedsSize refers to the shareds array size in the kmp_task_t data
6951 // structure.
6952 Value *SharedsSize = Builder.getInt64(C: 0);
6953 if (HasShareds) {
6954 auto *ArgStructAlloca = dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
6955 assert(ArgStructAlloca &&
6956 "Unable to find the alloca instruction corresponding to arguments "
6957 "for extracted function");
6958 auto *ArgStructType =
6959 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
6960 assert(ArgStructType && "Unable to find struct type corresponding to "
6961 "arguments for extracted function");
6962 SharedsSize =
6963 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
6964 }
6965
6966 // Argument - `flags`
6967 // Task is tied iff (Flags & 1) == 1.
6968 // Task is untied iff (Flags & 1) == 0.
6969 // Task is final iff (Flags & 2) == 2.
6970 // Task is not final iff (Flags & 2) == 0.
6971 // A target task is not final and is untied.
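// (A hypothetical tied task would pass 1 here; a target task is untied and
// not final, so the flags are 0.)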
6972 Value *Flags = Builder.getInt32(C: 0);
6973
6974 // Emit the @__kmpc_omp_task_alloc runtime call
6975 // The runtime call returns a pointer to an area where the task captured
6976 // variables must be copied before the task is run (TaskData)
6977 CallInst *TaskData = Builder.CreateCall(
6978 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
6979 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
6980 /*task_func=*/ProxyFn});
6981
6982 if (HasShareds) {
6983 Value *Shareds = StaleCI->getArgOperand(i: 1);
6984 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
6985 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
6986 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
6987 Size: SharedsSize);
6988 }
6989
6990 Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);
6991
6992 // ---------------------------------------------------------------
6993 // V5.2 13.8 target construct
6994 // If the nowait clause is present, execution of the target task
6995 // may be deferred. If the nowait clause is not present, the target task is
6996 // an included task.
6997 // ---------------------------------------------------------------
6998 // The above means that the lack of a nowait on the target construct
6999 // translates to '#pragma omp task if(0)'
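// Sketch of the sequence emitted below for this included-task case
// (illustrative, not verbatim):
//   call void @__kmpc_omp_wait_deps(...)          ; only with depend()
//   call void @__kmpc_omp_task_begin_if0(...)
//   call void @.omp_target_task_proxy_func(i32 %tid, ptr %task)
//   call void @__kmpc_omp_task_complete_if0(...)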
7000 if (!HasNoWait) {
7001 if (DepArray) {
7002 Function *TaskWaitFn =
7003 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
7004 Builder.CreateCall(
7005 Callee: TaskWaitFn,
7006 Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
7007 /*ndeps=*/Builder.getInt32(C: Dependencies.size()),
7008 /*dep_list=*/DepArray,
7009 /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
7010 /*noalias_dep_list=*/
7011 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
7012 }
7013 // Included task.
7014 Function *TaskBeginFn =
7015 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
7016 Function *TaskCompleteFn =
7017 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
7018 Builder.CreateCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
7019 CallInst *CI = nullptr;
7020 if (HasShareds)
7021 CI = Builder.CreateCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
7022 else
7023 CI = Builder.CreateCall(Callee: ProxyFn, Args: {ThreadID});
7024 CI->setDebugLoc(StaleCI->getDebugLoc());
7025 Builder.CreateCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
7026 } else if (DepArray) {
7027 // HasNoWait is set, meaning the task may be deferred. Call
7028 // __kmpc_omp_task_with_deps since there are dependencies;
7029 // the dependency-free case calls __kmpc_omp_task below.
7030 Function *TaskFn =
7031 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
7032 Builder.CreateCall(
7033 Callee: TaskFn,
7034 Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
7035 DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
7036 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
7037 } else {
7038 // Emit the @__kmpc_omp_task runtime call to spawn the task
7039 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
7040 Builder.CreateCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
7041 }
7042
7043 StaleCI->eraseFromParent();
7044 llvm::for_each(Range: llvm::reverse(C&: ToBeDeleted),
7045 F: [](Instruction *I) { I->eraseFromParent(); });
7046 };
7047 addOutlineInfo(OI: std::move(OI));
7048
7049 LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
7050 << *(Builder.GetInsertBlock()) << "\n");
7051 LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
7052 << *(Builder.GetInsertBlock()->getParent()->getParent())
7053 << "\n");
7054 return Builder.saveIP();
7055}
7056static void emitTargetCall(
7057 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7058 OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7059 Constant *OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
7060 SmallVectorImpl<Value *> &Args,
7061 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7062 SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7063
7064 OpenMPIRBuilder::TargetDataInfo Info(
7065 /*RequiresDevicePointerInfo=*/false,
7066 /*SeparateBeginEndCalls=*/true);
7067
7068 OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7069 OMPBuilder.emitOffloadingArrays(AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: MapInfo, Info,
7070 /*IsNonContiguous=*/true);
7071
7072 OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7073 OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info,
7074 EmitDebug: !MapInfo.Names.empty());
7075
7076 // Host fallback used by emitKernelLaunch if offloading fails.
7077 auto &&EmitTargetCallFallbackCB =
7078 [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
7079 Builder.restoreIP(IP);
7080 Builder.CreateCall(Callee: OutlinedFn, Args);
7081 return Builder.saveIP();
7082 };
7083
7084 unsigned NumTargetItems = MapInfo.BasePointers.size();
7085 // TODO: Use correct device ID
7086 Value *DeviceID = Builder.getInt64(C: OMP_DEVICEID_UNDEF);
7087 Value *NumTeamsVal = Builder.getInt32(C: NumTeams);
7088 Value *NumThreadsVal = Builder.getInt32(C: NumThreads);
7089 uint32_t SrcLocStrSize;
7090 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7091 Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7092 LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);
7093 // TODO: Use correct NumIterations
7094 Value *NumIterations = Builder.getInt64(C: 0);
7095 // TODO: Use correct DynCGGroupMem
7096 Value *DynCGGroupMem = Builder.getInt32(C: 0);
7097
7098 bool HasNoWait = false;
7099 bool HasDependencies = Dependencies.size() > 0;
7100 bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7101
7102 OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
7103 NumTeamsVal, NumThreadsVal,
7104 DynCGGroupMem, HasNoWait);
7105
7106 // The presence of certain clauses on the target directive requires the
7107 // explicit generation of the target task.
7108 if (RequiresOuterTargetTask) {
7109 Builder.restoreIP(IP: OMPBuilder.emitTargetTask(
7110 OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, Args&: KArgs, DeviceID,
7111 RTLoc, AllocaIP, Dependencies, HasNoWait));
7112 } else {
7113 Builder.restoreIP(IP: OMPBuilder.emitKernelLaunch(
7114 Loc: Builder, OutlinedFn, OutlinedFnID, emitTargetCallFallbackCB: EmitTargetCallFallbackCB, Args&: KArgs,
7115 DeviceID, RTLoc, AllocaIP));
7116 }
7117}
7118OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7119 const LocationDescription &Loc, InsertPointTy AllocaIP,
7120 InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
7121 int32_t NumThreads, SmallVectorImpl<Value *> &Args,
7122 GenMapInfoCallbackTy GenMapInfoCB,
7123 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7124 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7125 SmallVector<DependData> Dependencies) {
7126
7127 if (!updateToLocation(Loc))
7128 return InsertPointTy();
7129
7130 Builder.restoreIP(IP: CodeGenIP);
7131
7132 Function *OutlinedFn;
7133 Constant *OutlinedFnID;
7134 // The target region is outlined into its own function. The LLVM IR for
7135 // the target region itself is generated using the callbacks CBFunc
7136 // and ArgAccessorFuncCB
7137 emitTargetOutlinedFunction(OMPBuilder&: *this, Builder, EntryInfo, OutlinedFn,
7138 OutlinedFnID, Inputs&: Args, CBFunc, ArgAccessorFuncCB);
7139
7140 // If we are not on the target device, then we need to generate code
7141 // to make a remote call (offload) to the previously outlined function
7142 // that represents the target region. Do that now.
7143 if (!Config.isTargetDevice())
7144 emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7145 NumThreads, Args, GenMapInfoCB, Dependencies);
7146 return Builder.saveIP();
7147}
7148
7149std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
7150 StringRef FirstSeparator,
7151 StringRef Separator) {
7152 SmallString<128> Buffer;
7153 llvm::raw_svector_ostream OS(Buffer);
7154 StringRef Sep = FirstSeparator;
7155 for (StringRef Part : Parts) {
7156 OS << Sep << Part;
7157 Sep = Separator;
7158 }
7159 return OS.str().str();
7160}
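// For example (hypothetical inputs), Parts = {"x", "y"} with FirstSeparator
// "." and Separator "$" yields ".x$y".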
7161
7162std::string
7163OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
7164 return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
7165 Separator: Config.separator());
7166}
7167
7168GlobalVariable *
7169OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
7170 unsigned AddressSpace) {
7171 auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
7172 if (Elem.second) {
7173 assert(Elem.second->getValueType() == Ty &&
7174 "OMP internal variable has different type than requested");
7175 } else {
7176 // TODO: investigate the appropriate linkage type used for the global
7177 // variable for possibly changing that to internal or private, or maybe
7178 // create different versions of the function for different OMP internal
7179 // variables.
7180 auto Linkage = this->M.getTargetTriple().rfind(s: "wasm32") == 0
7181 ? GlobalValue::ExternalLinkage
7182 : GlobalValue::CommonLinkage;
7183 auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
7184 Constant::getNullValue(Ty), Elem.first(),
7185 /*InsertBefore=*/nullptr,
7186 GlobalValue::NotThreadLocal, AddressSpace);
7187 const DataLayout &DL = M.getDataLayout();
7188 const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
7189 const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpace);
7190 GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
7191 Elem.second = GV;
7192 }
7193
7194 return Elem.second;
7195}
7196
7197Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
7198 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
7199 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
7200 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
7201}
7202
7203Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
7204 LLVMContext &Ctx = Builder.getContext();
7205 Value *Null =
7206 Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
7207 Value *SizeGep =
7208 Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
7209 Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
7210 return SizePtrToInt;
7211}
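// This is the null-GEP sizeof idiom, the IR analogue of the C expression
// '(char *)((T *)0 + 1) - (char *)0': indexing one element past a null
// pointer and converting the result to an integer yields the allocation
// size of the GEP's source element type (here, the pointer type of
// BasePtr).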
7212
7213GlobalVariable *
7214OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
7215 std::string VarName) {
7216 llvm::Constant *MaptypesArrayInit =
7217 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
7218 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
7219 M, MaptypesArrayInit->getType(),
7220 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
7221 VarName);
7222 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
7223 return MaptypesArrayGlobal;
7224}
7225
7226void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
7227 InsertPointTy AllocaIP,
7228 unsigned NumOperands,
7229 struct MapperAllocas &MapperAllocas) {
7230 if (!updateToLocation(Loc))
7231 return;
7232
7233 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
7234 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
7235 Builder.restoreIP(IP: AllocaIP);
7236 AllocaInst *ArgsBase = Builder.CreateAlloca(
7237 Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
7238 AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
7239 Name: ".offload_ptrs");
7240 AllocaInst *ArgSizes = Builder.CreateAlloca(
7241 Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
7242 Builder.restoreIP(IP: Loc.IP);
7243 MapperAllocas.ArgsBase = ArgsBase;
7244 MapperAllocas.Args = Args;
7245 MapperAllocas.ArgSizes = ArgSizes;
7246}
7247
7248void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
7249 Function *MapperFunc, Value *SrcLocInfo,
7250 Value *MaptypesArg, Value *MapnamesArg,
7251 struct MapperAllocas &MapperAllocas,
7252 int64_t DeviceID, unsigned NumOperands) {
7253 if (!updateToLocation(Loc))
7254 return;
7255
7256 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
7257 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
7258 Value *ArgsBaseGEP =
7259 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
7260 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
7261 Value *ArgsGEP =
7262 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
7263 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
7264 Value *ArgSizesGEP =
7265 Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
7266 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
7267 Value *NullPtr =
7268 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
7269 Builder.CreateCall(Callee: MapperFunc,
7270 Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
7271 Builder.getInt32(C: NumOperands), ArgsBaseGEP, ArgsGEP,
7272 ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
7273}
7274
7275void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
7276 TargetDataRTArgs &RTArgs,
7277 TargetDataInfo &Info,
7278 bool EmitDebug,
7279 bool ForEndCall) {
7280 assert((!ForEndCall || Info.separateBeginEndCalls()) &&
7281 "expected region end call to runtime only when end call is separate");
7282 auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
7283 auto VoidPtrTy = UnqualPtrTy;
7284 auto VoidPtrPtrTy = UnqualPtrTy;
7285 auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
7286 auto Int64PtrTy = UnqualPtrTy;
7287
7288 if (!Info.NumberOfPtrs) {
7289 RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
7290 RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
7291 RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
7292 RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
7293 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
7294 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
7295 return;
7296 }
7297
7298 RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
7299 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
7300 Ptr: Info.RTArgs.BasePointersArray,
7301 /*Idx0=*/0, /*Idx1=*/0);
7302 RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
7303 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
7304 /*Idx0=*/0,
7305 /*Idx1=*/0);
7306 RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
7307 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
7308 /*Idx0=*/0, /*Idx1=*/0);
7309 RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
7310 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
7311 Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
7312 : Info.RTArgs.MapTypesArray,
7313 /*Idx0=*/0,
7314 /*Idx1=*/0);
7315
7316 // Only emit the map names array if debug information is
7317 // requested.
7318 if (!EmitDebug)
7319 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
7320 else
7321 RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
7322 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
7323 /*Idx0=*/0,
7324 /*Idx1=*/0);
7325 // If there is no user-defined mapper, set the mapper array to nullptr to
7326 // avoid an unnecessary data privatization
7327 if (!Info.HasMapper)
7328 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
7329 else
7330 RTArgs.MappersArray =
7331 Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
7332}
7333
7334void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
7335 InsertPointTy CodeGenIP,
7336 MapInfosTy &CombinedInfo,
7337 TargetDataInfo &Info) {
7338 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
7339 CombinedInfo.NonContigInfo;
7340
7341 // Build an array of struct descriptor_dim and then assign it to
7342 // offload_args.
7343 //
7344 // struct descriptor_dim {
7345 // uint64_t offset;
7346 // uint64_t count;
7347 // uint64_t stride
7348 // };
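// For a rank-2 non-contiguous section, for instance, this emits an alloca
// of [2 x %struct.descriptor_dim] and fills one offset/count/stride triple
// per dimension (stored in reverse dimension order; see RevIdx below).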
7349 Type *Int64Ty = Builder.getInt64Ty();
7350 StructType *DimTy = StructType::create(
7351 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
7352 Name: "struct.descriptor_dim");
7353
7354 enum { OffsetFD = 0, CountFD, StrideFD };
7355 // We need two index variables here since the size of "Dims" is the same
7356 // as the size of Components; however, offset, count, and stride entries
7357 // exist only for the base declarations that are non-contiguous.
7358 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
7359 // Skip emitting IR if the dimension size is 1, since such a dimension
7360 // cannot be non-contiguous.
7361 if (NonContigInfo.Dims[I] == 1)
7362 continue;
7363 Builder.restoreIP(IP: AllocaIP);
7364 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
7365 AllocaInst *DimsAddr =
7366 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
7367 Builder.restoreIP(IP: CodeGenIP);
7368 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
7369 unsigned RevIdx = EE - II - 1;
7370 Value *DimsLVal = Builder.CreateInBoundsGEP(
7371 Ty: DimsAddr->getAllocatedType(), Ptr: DimsAddr,
7372 IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
7373 // Offset
7374 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
7375 Builder.CreateAlignedStore(
7376 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
7377 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
7378 // Count
7379 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
7380 Builder.CreateAlignedStore(
7381 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
7382 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
7383 // Stride
7384 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
7385 Builder.CreateAlignedStore(
7386 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
7387 Align: M.getDataLayout().getPrefTypeAlign(Ty: StrideLVal->getType()));
7388 }
7389 // args[I] = &dims
7390 Builder.restoreIP(IP: CodeGenIP);
7391 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
7392 V: DimsAddr, DestTy: Builder.getPtrTy());
7393 Value *P = Builder.CreateConstInBoundsGEP2_32(
7394 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
7395 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
7396 Builder.CreateAlignedStore(
7397 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
7398 ++L;
7399 }
7400}
7401
7402void OpenMPIRBuilder::emitOffloadingArrays(
7403 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
7404 TargetDataInfo &Info, bool IsNonContiguous,
7405 function_ref<void(unsigned int, Value *)> DeviceAddrCB,
7406 function_ref<Value *(unsigned int)> CustomMapperCB) {
7407
7408 // Reset the array information.
7409 Info.clearArrayInfo();
7410 Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
7411
7412 if (Info.NumberOfPtrs == 0)
7413 return;
7414
7415 Builder.restoreIP(IP: AllocaIP);
7416 // Detect whether any captured size requires runtime evaluation; if none
7417 // does, a constant array can eventually be used for the map sizes.
7418 ArrayType *PointerArrayType =
7419 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);
7420
7421 Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
7422 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
7423
7424 Info.RTArgs.PointersArray = Builder.CreateAlloca(
7425 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
7426 AllocaInst *MappersArray = Builder.CreateAlloca(
7427 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
7428 Info.RTArgs.MappersArray = MappersArray;
7429
7430 // If we don't have any VLA types or other types that require runtime
7431 // evaluation, we can use a constant array for the map sizes, otherwise we
7432 // need to fill up the arrays as we do for the pointers.
7433 Type *Int64Ty = Builder.getInt64Ty();
7434 SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
7435 ConstantInt::get(Ty: Int64Ty, V: 0));
7436 SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
7437 for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
7438 if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
7439 if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
7440 if (IsNonContiguous &&
7441 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7442 CombinedInfo.Types[I] &
7443 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
7444 ConstSizes[I] =
7445 ConstantInt::get(Ty: Int64Ty, V: CombinedInfo.NonContigInfo.Dims[I]);
7446 else
7447 ConstSizes[I] = CI;
7448 continue;
7449 }
7450 }
7451 RuntimeSizes.set(I);
7452 }
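// E.g. 'map(to: scalar)' contributes a ConstantInt size and leaves its
// RuntimeSizes bit clear, whereas a runtime-sized section such as a VLA
// keeps a non-constant size Value and sets the bit.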
7453
7454 if (RuntimeSizes.all()) {
7455 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
7456 Info.RTArgs.SizesArray = Builder.CreateAlloca(
7457 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
7458 Builder.restoreIP(IP: CodeGenIP);
7459 } else {
7460 auto *SizesArrayInit = ConstantArray::get(
7461 T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
7462 std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
7463 auto *SizesArrayGbl =
7464 new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
7465 GlobalValue::PrivateLinkage, SizesArrayInit, Name);
7466 SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
7467
7468 if (!RuntimeSizes.any()) {
7469 Info.RTArgs.SizesArray = SizesArrayGbl;
7470 } else {
7471 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
7472 Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
7473 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
7474 AllocaInst *Buffer = Builder.CreateAlloca(
7475 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
7476 Buffer->setAlignment(OffloadSizeAlign);
7477 Builder.restoreIP(IP: CodeGenIP);
7478 Builder.CreateMemCpy(
7479 Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
7480 Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
7481 Size: Builder.getIntN(
7482 N: IndexSize,
7483 C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));
7484
7485 Info.RTArgs.SizesArray = Buffer;
7486 }
7487 Builder.restoreIP(IP: CodeGenIP);
7488 }
7489
7490 // The map types are always constant so we don't need to generate code to
7491 // fill arrays. Instead, we create an array constant.
7492 SmallVector<uint64_t, 4> Mapping;
7493 for (auto mapFlag : CombinedInfo.Types)
7494 Mapping.push_back(
7495 Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7496 mapFlag));
7497 std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
7498 auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
7499 Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
7500
7501 // The map names array is only built if names were provided.
7502 if (!CombinedInfo.Names.empty()) {
7503 std::string MapnamesName = createPlatformSpecificName(Parts: {"offload_mapnames"});
7504 auto *MapNamesArrayGbl =
7505 createOffloadMapnames(Names&: CombinedInfo.Names, VarName: MapnamesName);
7506 Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
7507 } else {
7508 Info.RTArgs.MapNamesArray =
7509 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
7510 }
7511
7512 // If there's a present map type modifier, it must not be applied to the end
7513 // of a region, so generate a separate map type array in that case.
7514 if (Info.separateBeginEndCalls()) {
7515 bool EndMapTypesDiffer = false;
7516 for (uint64_t &Type : Mapping) {
7517 if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7518 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
7519 Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7520 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
7521 EndMapTypesDiffer = true;
7522 }
7523 }
7524 if (EndMapTypesDiffer) {
7525 MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
7526 Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
7527 }
7528 }
7529
7530 PointerType *PtrTy = Builder.getPtrTy();
7531 for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
7532 Value *BPVal = CombinedInfo.BasePointers[I];
7533 Value *BP = Builder.CreateConstInBoundsGEP2_32(
7534 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
7535 Idx0: 0, Idx1: I);
7536 Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
7537 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
7538
7539 if (Info.requiresDevicePointerInfo()) {
7540 if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
7541 CodeGenIP = Builder.saveIP();
7542 Builder.restoreIP(IP: AllocaIP);
7543 Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
7544 Builder.restoreIP(IP: CodeGenIP);
7545 if (DeviceAddrCB)
7546 DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
7547 } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
7548 Info.DevicePtrInfoMap[BPVal] = {BP, BP};
7549 if (DeviceAddrCB)
7550 DeviceAddrCB(I, BP);
7551 }
7552 }
7553
7554 Value *PVal = CombinedInfo.Pointers[I];
7555 Value *P = Builder.CreateConstInBoundsGEP2_32(
7556 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
7557 Idx1: I);
7558 // TODO: Check that this alignment is correct.
7559 Builder.CreateAlignedStore(Val: PVal, Ptr: P,
7560 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
7561
7562 if (RuntimeSizes.test(Idx: I)) {
7563 Value *S = Builder.CreateConstInBoundsGEP2_32(
7564 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
7565 /*Idx0=*/0,
7566 /*Idx1=*/I);
7567 Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
7568 DestTy: Int64Ty,
7569 /*isSigned=*/true),
7570 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
7571 }
7572 // Fill up the mapper array.
7573 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
7574 Value *MFunc = ConstantPointerNull::get(T: PtrTy);
7575 if (CustomMapperCB)
7576 if (Value *CustomMFunc = CustomMapperCB(I))
7577 MFunc = Builder.CreatePointerCast(V: CustomMFunc, DestTy: PtrTy);
7578 Value *MAddr = Builder.CreateInBoundsGEP(
7579 Ty: MappersArray->getAllocatedType(), Ptr: MappersArray,
7580 IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
7581 Builder.CreateAlignedStore(
7582 Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
7583 }
7584
7585 if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
7586 Info.NumberOfPtrs == 0)
7587 return;
7588 emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
7589}
7590
7591void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
7592 BasicBlock *CurBB = Builder.GetInsertBlock();
7593
7594 if (!CurBB || CurBB->getTerminator()) {
7595 // If there is no insert point or the previous block is already
7596 // terminated, don't touch it.
7597 } else {
7598 // Otherwise, create a fall-through branch.
7599 Builder.CreateBr(Dest: Target);
7600 }
7601
7602 Builder.ClearInsertionPoint();
7603}
7604
7605void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
7606 bool IsFinished) {
7607 BasicBlock *CurBB = Builder.GetInsertBlock();
7608
7609 // Fall out of the current block (if necessary).
7610 emitBranch(Target: BB);
7611
7612 if (IsFinished && BB->use_empty()) {
7613 BB->eraseFromParent();
7614 return;
7615 }
7616
7617 // Place the block after the current block, if possible, or else at
7618 // the end of the function.
7619 if (CurBB && CurBB->getParent())
7620 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
7621 else
7622 CurFn->insert(Position: CurFn->end(), BB);
7623 Builder.SetInsertPoint(BB);
7624}
7625
7626void OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
7627 BodyGenCallbackTy ElseGen,
7628 InsertPointTy AllocaIP) {
7629 // If the condition constant folds and can be elided, try to avoid emitting
7630 // the condition and the dead arm of the if/else.
7631 if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
7632 auto CondConstant = CI->getSExtValue();
7633 if (CondConstant)
7634 ThenGen(AllocaIP, Builder.saveIP());
7635 else
7636 ElseGen(AllocaIP, Builder.saveIP());
7637 return;
7638 }
7639
7640 Function *CurFn = Builder.GetInsertBlock()->getParent();
7641
7642 // Otherwise, the condition did not fold, or we couldn't elide it. Just
7643 // emit the conditional branch.
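// The resulting shape is the usual diamond (sketch, not emitted verbatim):
//   br i1 %cond, label %omp_if.then, label %omp_if.else
// omp_if.then:
//   ...ThenGen...
//   br label %omp_if.end
// omp_if.else:
//   ...ElseGen...
//   br label %omp_if.end
// omp_if.end: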
7644 BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
7645 BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
7646 BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
7647 Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
7648 // Emit the 'then' code.
7649 emitBlock(BB: ThenBlock, CurFn);
7650 ThenGen(AllocaIP, Builder.saveIP());
7651 emitBranch(Target: ContBlock);
7652 // Emit the 'else' code if present.
7653 // There is no need to emit line number for unconditional branch.
7654 emitBlock(BB: ElseBlock, CurFn);
7655 ElseGen(AllocaIP, Builder.saveIP());
7656 // There is no need to emit line number for unconditional branch.
7657 emitBranch(Target: ContBlock);
7658 // Emit the continuation block for code after the if.
7659 emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
7660}
7661
7662bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
7663 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
7664 assert(!(AO == AtomicOrdering::NotAtomic ||
7665 AO == llvm::AtomicOrdering::Unordered) &&
7666 "Unexpected Atomic Ordering.");
7667
7668 bool Flush = false;
7669 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
7670
7671 switch (AK) {
7672 case Read:
7673 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
7674 AO == AtomicOrdering::SequentiallyConsistent) {
7675 FlushAO = AtomicOrdering::Acquire;
7676 Flush = true;
7677 }
7678 break;
7679 case Write:
7680 case Compare:
7681 case Update:
7682 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
7683 AO == AtomicOrdering::SequentiallyConsistent) {
7684 FlushAO = AtomicOrdering::Release;
7685 Flush = true;
7686 }
7687 break;
7688 case Capture:
7689 switch (AO) {
7690 case AtomicOrdering::Acquire:
7691 FlushAO = AtomicOrdering::Acquire;
7692 Flush = true;
7693 break;
7694 case AtomicOrdering::Release:
7695 FlushAO = AtomicOrdering::Release;
7696 Flush = true;
7697 break;
7698 case AtomicOrdering::AcquireRelease:
7699 case AtomicOrdering::SequentiallyConsistent:
7700 FlushAO = AtomicOrdering::AcquireRelease;
7701 Flush = true;
7702 break;
7703 default:
7704 // do nothing - leave silently.
7705 break;
7706 }
7707 }
7708
7709 if (Flush) {
7710 // The flush runtime call does not yet take a memory ordering. We still
7711 // resolve which atomic ordering to use here so it is ready once that
7712 // support is added, but for now we issue a plain flush call.
7713 // TODO: pass `FlushAO` after memory ordering support is added
7714 (void)FlushAO;
7715 emitFlush(Loc);
7716 }
7717
7718 // For AO == AtomicOrdering::Monotonic and all other case combinations,
7719 // do nothing.
7720 return Flush;
7721}
7722
7723OpenMPIRBuilder::InsertPointTy
7724OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
7725 AtomicOpValue &X, AtomicOpValue &V,
7726 AtomicOrdering AO) {
7727 if (!updateToLocation(Loc))
7728 return Loc.IP;
7729
7730 assert(X.Var->getType()->isPointerTy() &&
7731 "OMP Atomic expects a pointer to target memory");
7732 Type *XElemTy = X.ElemTy;
7733 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7734 XElemTy->isPointerTy()) &&
7735 "OMP atomic read expected a scalar type");
7736
7737 Value *XRead = nullptr;
7738
7739 if (XElemTy->isIntegerTy()) {
7740 LoadInst *XLD =
7741 Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
7742 XLD->setAtomic(Ordering: AO);
7743 XRead = cast<Value>(Val: XLD);
7744 } else {
7745 // We need to perform the atomic operation as an integer.
7746 IntegerType *IntCastTy =
7747 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
7748 LoadInst *XLoad =
7749 Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
7750 XLoad->setAtomic(Ordering: AO);
7751 if (XElemTy->isFloatingPointTy()) {
7752 XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
7753 } else {
7754 XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
7755 }
7756 }
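// As a sketch, for a 'float' X the path above emits roughly:
//   %0 = load atomic i32, ptr %x monotonic, align 4  ; width-matched load
//   %1 = bitcast i32 %0 to float                     ; atomic.flt.cast
// with the actual ordering operand taken from AO rather than the
// 'monotonic' shown here.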
7757 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
7758 Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
7759 return Builder.saveIP();
7760}
7761
7762OpenMPIRBuilder::InsertPointTy
7763OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
7764 AtomicOpValue &X, Value *Expr,
7765 AtomicOrdering AO) {
7766 if (!updateToLocation(Loc))
7767 return Loc.IP;
7768
7769 assert(X.Var->getType()->isPointerTy() &&
7770 "OMP Atomic expects a pointer to target memory");
7771 Type *XElemTy = X.ElemTy;
7772 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7773 XElemTy->isPointerTy()) &&
7774 "OMP atomic write expected a scalar type");
7775
7776 if (XElemTy->isIntegerTy()) {
7777 StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
7778 XSt->setAtomic(Ordering: AO);
7779 } else {
    // We need to bitcast and perform the atomic op as an integer.
7781 IntegerType *IntCastTy =
7782 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
7783 Value *ExprCast =
7784 Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
7785 StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
7786 XSt->setAtomic(Ordering: AO);
7787 }
7788
7789 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
7790 return Builder.saveIP();
7791}
7792
7793OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
7794 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
7795 Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
7796 AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
7797 assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
7798 if (!updateToLocation(Loc))
7799 return Loc.IP;
7800
7801 LLVM_DEBUG({
7802 Type *XTy = X.Var->getType();
7803 assert(XTy->isPointerTy() &&
7804 "OMP Atomic expects a pointer to target memory");
7805 Type *XElemTy = X.ElemTy;
7806 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7807 XElemTy->isPointerTy()) &&
7808 "OMP atomic update expected a scalar type");
7809 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
7810 (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
7811 "OpenMP atomic does not support LT or GT operations");
7812 });
7813
7814 emitAtomicUpdate(AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp,
7815 VolatileX: X.IsVolatile, IsXBinopExpr);
7816 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
7817 return Builder.saveIP();
7818}
7819
7820// FIXME: Duplicating AtomicExpand
7821Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
7822 AtomicRMWInst::BinOp RMWOp) {
7823 switch (RMWOp) {
7824 case AtomicRMWInst::Add:
7825 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
7826 case AtomicRMWInst::Sub:
7827 return Builder.CreateSub(LHS: Src1, RHS: Src2);
7828 case AtomicRMWInst::And:
7829 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
  case AtomicRMWInst::Nand:
    // Nand is the bitwise complement of the And result, not its negation.
    return Builder.CreateNot(Builder.CreateAnd(Src1, Src2));
7832 case AtomicRMWInst::Or:
7833 return Builder.CreateOr(LHS: Src1, RHS: Src2);
7834 case AtomicRMWInst::Xor:
7835 return Builder.CreateXor(LHS: Src1, RHS: Src2);
7836 case AtomicRMWInst::Xchg:
7837 case AtomicRMWInst::FAdd:
7838 case AtomicRMWInst::FSub:
7839 case AtomicRMWInst::BAD_BINOP:
7840 case AtomicRMWInst::Max:
7841 case AtomicRMWInst::Min:
7842 case AtomicRMWInst::UMax:
7843 case AtomicRMWInst::UMin:
7844 case AtomicRMWInst::FMax:
7845 case AtomicRMWInst::FMin:
7846 case AtomicRMWInst::UIncWrap:
7847 case AtomicRMWInst::UDecWrap:
7848 llvm_unreachable("Unsupported atomic update operation");
7849 }
7850 llvm_unreachable("Unsupported atomic update operation");
7851}
7852
7853std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
7854 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
7855 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
7856 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
7857 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
7858 // or a complex datatype.
7859 bool emitRMWOp = false;
7860 switch (RMWOp) {
7861 case AtomicRMWInst::Add:
7862 case AtomicRMWInst::And:
7863 case AtomicRMWInst::Nand:
7864 case AtomicRMWInst::Or:
7865 case AtomicRMWInst::Xor:
7866 case AtomicRMWInst::Xchg:
7867 emitRMWOp = XElemTy;
7868 break;
7869 case AtomicRMWInst::Sub:
7870 emitRMWOp = (IsXBinopExpr && XElemTy);
7871 break;
7872 default:
7873 emitRMWOp = false;
7874 }
7875 emitRMWOp &= XElemTy->isIntegerTy();
7876
7877 std::pair<Value *, Value *> Res;
7878 if (emitRMWOp) {
7879 Res.first = Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
    // The non-atomic result is only needed for postfix captures; generate it
    // anyway for consistency with the else branch (any DCE pass will remove
    // it when unused). AtomicRMWInst::Xchg has no corresponding non-atomic
    // instruction, so the value returned by the RMW is reused.
7883 if (RMWOp == AtomicRMWInst::Xchg)
7884 Res.second = Res.first;
7885 else
7886 Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
7887 } else {
7888 IntegerType *IntCastTy =
7889 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
7890 LoadInst *OldVal =
7891 Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
7892 OldVal->setAtomic(Ordering: AO);
7893 // CurBB
7894 // | /---\
7895 // ContBB |
7896 // | \---/
7897 // ExitBB
7898 BasicBlock *CurBB = Builder.GetInsertBlock();
7899 Instruction *CurBBTI = CurBB->getTerminator();
7900 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
7901 BasicBlock *ExitBB =
7902 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
7903 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
7904 BBName: X->getName() + ".atomic.cont");
7905 ContBB->getTerminator()->eraseFromParent();
7906 Builder.restoreIP(IP: AllocaIP);
7907 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
7908 NewAtomicAddr->setName(X->getName() + "x.new.val");
7909 Builder.SetInsertPoint(ContBB);
7910 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
7911 PHI->addIncoming(V: OldVal, BB: CurBB);
7912 bool IsIntTy = XElemTy->isIntegerTy();
7913 Value *OldExprVal = PHI;
7914 if (!IsIntTy) {
7915 if (XElemTy->isFloatingPointTy()) {
7916 OldExprVal = Builder.CreateBitCast(V: PHI, DestTy: XElemTy,
7917 Name: X->getName() + ".atomic.fltCast");
7918 } else {
7919 OldExprVal = Builder.CreateIntToPtr(V: PHI, DestTy: XElemTy,
7920 Name: X->getName() + ".atomic.ptrCast");
7921 }
7922 }
7923
7924 Value *Upd = UpdateOp(OldExprVal, Builder);
7925 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
7926 LoadInst *DesiredVal = Builder.CreateLoad(Ty: IntCastTy, Ptr: NewAtomicAddr);
7927 AtomicOrdering Failure =
7928 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
7929 AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
7930 Ptr: X, Cmp: PHI, New: DesiredVal, Align: llvm::MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
7931 Result->setVolatile(VolatileX);
7932 Value *PreviousVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
7933 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
7934 PHI->addIncoming(V: PreviousVal, BB: Builder.GetInsertBlock());
7935 Builder.CreateCondBr(Cond: SuccessFailureVal, True: ExitBB, False: ContBB);
7936
7937 Res.first = OldExprVal;
7938 Res.second = Upd;
7939
    // Set the insertion point in the exit block. If we created the
    // placeholder unreachable terminator above, remove it again; otherwise
    // keep the preexisting terminator and insert in front of it.
    if (isa<UnreachableInst>(ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      Builder.SetInsertPoint(ExitBB->getTerminator());
    }
7948 }
7949
7950 return Res;
7951}
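
// For the generic (compare-exchange) path above, the emitted IR is roughly
// the following (a sketch for a monotonic update of a float `%x`; value
// names are illustrative):
// ```
//   %old = load atomic i32, ptr %x monotonic, align 4
//   br label %x.atomic.cont
// x.atomic.cont:
//   %phi = phi i32 [ %old, %entry ], [ %prev, %x.atomic.cont ]
//   %oldf = bitcast i32 %phi to float
//   ; ... UpdateOp computes the new value and stores it to %xx.new.val ...
//   %new = load i32, ptr %xx.new.val
//   %pair = cmpxchg ptr %x, i32 %phi, i32 %new monotonic monotonic
//   %prev = extractvalue { i32, i1 } %pair, 0
//   %ok = extractvalue { i32, i1 } %pair, 1
//   br i1 %ok, label %x.atomic.exit, label %x.atomic.cont
// x.atomic.exit:
// ```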
7952
7953OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
7954 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
7955 AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
7956 AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
7957 bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
7958 if (!updateToLocation(Loc))
7959 return Loc.IP;
7960
7961 LLVM_DEBUG({
7962 Type *XTy = X.Var->getType();
7963 assert(XTy->isPointerTy() &&
7964 "OMP Atomic expects a pointer to target memory");
7965 Type *XElemTy = X.ElemTy;
7966 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7967 XElemTy->isPointerTy()) &&
7968 "OMP atomic capture expected a scalar type");
7969 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
7970 "OpenMP atomic does not support LT or GT operations");
7971 });
7972
7973 // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
7974 // 'x' is simply atomically rewritten with 'expr'.
7975 AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
7976 std::pair<Value *, Value *> Result =
7977 emitAtomicUpdate(AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp: AtomicOp, UpdateOp,
7978 VolatileX: X.IsVolatile, IsXBinopExpr);
7979
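  // A postfix capture (`v = x; x = update(x);`) stores the value of 'x' from
  // before the update (Result.first); a prefix capture (`x = update(x);
  // v = x;`) stores the updated value (Result.second).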
7980 Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
7981 Builder.CreateStore(Val: CapturedVal, Ptr: V.Var, isVolatile: V.IsVolatile);
7982
7983 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Capture);
7984 return Builder.saveIP();
7985}
7986
7987OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
7988 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
7989 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
7990 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
7991 bool IsFailOnly) {
7992
7993 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
7994 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
7995 IsPostfixUpdate, IsFailOnly, Failure);
7996}
7997
7998OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
7999 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
8000 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
8001 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
8002 bool IsFailOnly, AtomicOrdering Failure) {
8003
8004 if (!updateToLocation(Loc))
8005 return Loc.IP;
8006
8007 assert(X.Var->getType()->isPointerTy() &&
8008 "OMP atomic expects a pointer to target memory");
  // Compare capture: 'v', if present, must point to the same type as 'x'.
8010 if (V.Var) {
8011 assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
8012 assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
8013 }
8014
8015 bool IsInteger = E->getType()->isIntegerTy();
8016
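  // The equality form, e.g. `if (x == e) { x = d; }`, maps directly onto a
  // cmpxchg with 'e' as the expected and 'd' as the desired value.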
8017 if (Op == OMPAtomicCompareOp::EQ) {
8018 AtomicCmpXchgInst *Result = nullptr;
8019 if (!IsInteger) {
8020 IntegerType *IntCastTy =
8021 IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
8022 Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
8023 Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
8024 Result = Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: EBCast, New: DBCast, Align: MaybeAlign(),
8025 SuccessOrdering: AO, FailureOrdering: Failure);
8026 } else {
8027 Result =
8028 Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: E, New: D, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
8029 }
8030
8031 if (V.Var) {
8032 Value *OldValue = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
8033 if (!IsInteger)
8034 OldValue = Builder.CreateBitCast(V: OldValue, DestTy: X.ElemTy);
8035 assert(OldValue->getType() == V.ElemTy &&
8036 "OldValue and V must be of same type");
8037 if (IsPostfixUpdate) {
8038 Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
8039 } else {
8040 Value *SuccessOrFail = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
8041 if (IsFailOnly) {
8042 // CurBB----
8043 // | |
8044 // v |
8045 // ContBB |
8046 // | |
8047 // v |
8048 // ExitBB <-
8049 //
8050 // where ContBB only contains the store of old value to 'v'.
8051 BasicBlock *CurBB = Builder.GetInsertBlock();
8052 Instruction *CurBBTI = CurBB->getTerminator();
8053 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
8054 BasicBlock *ExitBB = CurBB->splitBasicBlock(
8055 I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
8056 BasicBlock *ContBB = CurBB->splitBasicBlock(
8057 I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
8058 ContBB->getTerminator()->eraseFromParent();
8059 CurBB->getTerminator()->eraseFromParent();
8060
8061 Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);
8062
8063 Builder.SetInsertPoint(ContBB);
8064 Builder.CreateStore(Val: OldValue, Ptr: V.Var);
8065 Builder.CreateBr(Dest: ExitBB);
8066
          // If we created the placeholder unreachable terminator above,
          // remove it again; otherwise keep the preexisting terminator and
          // insert in front of it.
          if (isa<UnreachableInst>(ExitBB->getTerminator())) {
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            Builder.SetInsertPoint(ExitBB->getTerminator());
          }
8074 } else {
8075 Value *CapturedValue =
8076 Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
8077 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
8078 }
8079 }
8080 }
8081 // The comparison result has to be stored.
8082 if (R.Var) {
8083 assert(R.Var->getType()->isPointerTy() &&
8084 "r.var must be of pointer type");
8085 assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
8086
8087 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
8088 Value *ResultCast = R.IsSigned
8089 ? Builder.CreateSExt(V: SuccessFailureVal, DestTy: R.ElemTy)
8090 : Builder.CreateZExt(V: SuccessFailureVal, DestTy: R.ElemTy);
8091 Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
8092 }
8093 } else {
8094 assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
8095 "Op should be either max or min at this point");
8096 assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
8097
    // Reverse the comparison op, as the OpenMP forms differ from the LLVM
    // forms. Take max as an example. The OpenMP form is
    //   x = x > expr ? expr : x;
    // while the LLVM atomicrmw form is
    //   *ptr = *ptr > val ? *ptr : val;
    // Rewriting the OpenMP form to match it yields
    //   x = x <= expr ? x : expr;
    // i.e. the comparison has to be flipped.
8106 AtomicRMWInst::BinOp NewOp;
8107 if (IsXBinopExpr) {
8108 if (IsInteger) {
8109 if (X.IsSigned)
8110 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
8111 : AtomicRMWInst::Max;
8112 else
8113 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
8114 : AtomicRMWInst::UMax;
8115 } else {
8116 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
8117 : AtomicRMWInst::FMax;
8118 }
8119 } else {
8120 if (IsInteger) {
8121 if (X.IsSigned)
8122 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
8123 : AtomicRMWInst::Min;
8124 else
8125 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
8126 : AtomicRMWInst::UMin;
8127 } else {
8128 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
8129 : AtomicRMWInst::FMin;
8130 }
8131 }
8132
8133 AtomicRMWInst *OldValue =
8134 Builder.CreateAtomicRMW(Op: NewOp, Ptr: X.Var, Val: E, Align: MaybeAlign(), Ordering: AO);
8135 if (V.Var) {
8136 Value *CapturedValue = nullptr;
8137 if (IsPostfixUpdate) {
8138 CapturedValue = OldValue;
8139 } else {
8140 CmpInst::Predicate Pred;
8141 switch (NewOp) {
8142 case AtomicRMWInst::Max:
8143 Pred = CmpInst::ICMP_SGT;
8144 break;
8145 case AtomicRMWInst::UMax:
8146 Pred = CmpInst::ICMP_UGT;
8147 break;
8148 case AtomicRMWInst::FMax:
8149 Pred = CmpInst::FCMP_OGT;
8150 break;
8151 case AtomicRMWInst::Min:
8152 Pred = CmpInst::ICMP_SLT;
8153 break;
8154 case AtomicRMWInst::UMin:
8155 Pred = CmpInst::ICMP_ULT;
8156 break;
8157 case AtomicRMWInst::FMin:
8158 Pred = CmpInst::FCMP_OLT;
8159 break;
8160 default:
8161 llvm_unreachable("unexpected comparison op");
8162 }
8163 Value *NonAtomicCmp = Builder.CreateCmp(Pred, LHS: OldValue, RHS: E);
8164 CapturedValue = Builder.CreateSelect(C: NonAtomicCmp, True: E, False: OldValue);
8165 }
8166 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
8167 }
8168 }
8169
8170 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Compare);
8171
8172 return Builder.saveIP();
8173}
8174
8175OpenMPIRBuilder::InsertPointTy
8176OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
8177 BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
8178 Value *NumTeamsUpper, Value *ThreadLimit,
8179 Value *IfExpr) {
8180 if (!updateToLocation(Loc))
8181 return InsertPointTy();
8182
8183 uint32_t SrcLocStrSize;
8184 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8185 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8186 Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
8187
  // Outer allocation basic block is the entry block of the current function.
8189 BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
8190 if (&OuterAllocaBB == Builder.GetInsertBlock()) {
8191 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.entry");
8192 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
8193 }
8194
8195 // The current basic block is split into four basic blocks. After outlining,
8196 // they will be mapped as follows:
8197 // ```
8198 // def current_fn() {
8199 // current_basic_block:
8200 // br label %teams.exit
8201 // teams.exit:
8202 // ; instructions after teams
8203 // }
8204 //
8205 // def outlined_fn() {
8206 // teams.alloca:
8207 // br label %teams.body
8208 // teams.body:
8209 // ; instructions within teams body
8210 // }
8211 // ```
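  // On the host, the temporary call to the outlined function that the
  // outliner leaves behind in `current_fn` is later replaced by a call to
  // `__kmpc_fork_teams(ident, nargs, &outlined_fn, ...)`; see
  // HostPostOutlineCB below.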
8212 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.exit");
8213 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.body");
8214 BasicBlock *AllocaBB =
8215 splitBB(Builder, /*CreateBranch=*/true, Name: "teams.alloca");
8216
8217 bool SubClausesPresent =
8218 (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
8219 // Push num_teams
8220 if (!Config.isTargetDevice() && SubClausesPresent) {
8221 assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
8222 "if lowerbound is non-null, then upperbound must also be non-null "
8223 "for bounds on num_teams");
8224
8225 if (NumTeamsUpper == nullptr)
8226 NumTeamsUpper = Builder.getInt32(C: 0);
8227
8228 if (NumTeamsLower == nullptr)
8229 NumTeamsLower = NumTeamsUpper;
8230
8231 if (IfExpr) {
8232 assert(IfExpr->getType()->isIntegerTy() &&
8233 "argument to if clause must be an integer value");
8234
8235 // upper = ifexpr ? upper : 1
8236 if (IfExpr->getType() != Int1)
8237 IfExpr = Builder.CreateICmpNE(LHS: IfExpr,
8238 RHS: ConstantInt::get(Ty: IfExpr->getType(), V: 0));
8239 NumTeamsUpper = Builder.CreateSelect(
8240 C: IfExpr, True: NumTeamsUpper, False: Builder.getInt32(C: 1), Name: "numTeamsUpper");
8241
8242 // lower = ifexpr ? lower : 1
8243 NumTeamsLower = Builder.CreateSelect(
8244 C: IfExpr, True: NumTeamsLower, False: Builder.getInt32(C: 1), Name: "numTeamsLower");
8245 }
8246
8247 if (ThreadLimit == nullptr)
8248 ThreadLimit = Builder.getInt32(C: 0);
8249
8250 Value *ThreadNum = getOrCreateThreadID(Ident);
8251 Builder.CreateCall(
8252 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_teams_51),
8253 Args: {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
8254 }
8255 // Generate the body of teams.
8256 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
8257 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
8258 BodyGenCB(AllocaIP, CodeGenIP);
8259
8260 OutlineInfo OI;
8261 OI.EntryBB = AllocaBB;
8262 OI.ExitBB = ExitBB;
8263 OI.OuterAllocaBB = &OuterAllocaBB;
8264
8265 // Insert fake values for global tid and bound tid.
8266 SmallVector<Instruction *, 8> ToBeDeleted;
8267 InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
8268 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
8269 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "gid", AsPtr: true));
8270 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
8271 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "tid", AsPtr: true));
8272
8273 auto HostPostOutlineCB = [this, Ident,
8274 ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call instruction
    // for the runtime call with the outlined function.
8277
8278 assert(OutlinedFn.getNumUses() == 1 &&
8279 "there must be a single user for the outlined function");
8280 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
8281 ToBeDeleted.push_back(Elt: StaleCI);
8282
8283 assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
8284 "Outlined function must have two or three arguments only");
8285
8286 bool HasShared = OutlinedFn.arg_size() == 3;
8287
8288 OutlinedFn.getArg(i: 0)->setName("global.tid.ptr");
8289 OutlinedFn.getArg(i: 1)->setName("bound.tid.ptr");
8290 if (HasShared)
8291 OutlinedFn.getArg(i: 2)->setName("data");
8292
8293 // Call to the runtime function for teams in the current function.
8294 assert(StaleCI && "Error while outlining - no CallInst user found for the "
8295 "outlined function.");
8296 Builder.SetInsertPoint(StaleCI);
8297 SmallVector<Value *> Args = {
8298 Ident, Builder.getInt32(C: StaleCI->arg_size() - 2), &OutlinedFn};
8299 if (HasShared)
8300 Args.push_back(Elt: StaleCI->getArgOperand(i: 2));
8301 Builder.CreateCall(Callee: getOrCreateRuntimeFunctionPtr(
8302 FnID: omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
8303 Args);
8304
8305 llvm::for_each(Range: llvm::reverse(C&: ToBeDeleted),
8306 F: [](Instruction *I) { I->eraseFromParent(); });
  };
8309
8310 if (!Config.isTargetDevice())
8311 OI.PostOutlineCB = HostPostOutlineCB;
8312
8313 addOutlineInfo(OI: std::move(OI));
8314
8315 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
8316
8317 return Builder.saveIP();
8318}
8319
8320GlobalVariable *
8321OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
8322 std::string VarName) {
8323 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
8324 T: llvm::ArrayType::get(ElementType: llvm::PointerType::getUnqual(C&: M.getContext()),
8325 NumElements: Names.size()),
8326 V: Names);
8327 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
8328 M, MapNamesArrayInit->getType(),
8329 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
8330 VarName);
8331 return MapNamesArrayGlobal;
8332}
8333
8334// Create all simple and struct types exposed by the runtime and remember
8335// the llvm::PointerTypes of them for easy access later.
8336void OpenMPIRBuilder::initializeTypes(Module &M) {
8337 LLVMContext &Ctx = M.getContext();
8338 StructType *T;
8339#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
8340#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
8341 VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
8342 VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
8343#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
8344 VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
8345 VarName##Ptr = PointerType::getUnqual(VarName);
8346#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
8347 T = StructType::getTypeByName(Ctx, StructName); \
8348 if (!T) \
8349 T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed); \
8350 VarName = T; \
8351 VarName##Ptr = PointerType::getUnqual(T);
8352#include "llvm/Frontend/OpenMP/OMPKinds.def"
8353}
8354
8355void OpenMPIRBuilder::OutlineInfo::collectBlocks(
8356 SmallPtrSetImpl<BasicBlock *> &BlockSet,
8357 SmallVectorImpl<BasicBlock *> &BlockVector) {
8358 SmallVector<BasicBlock *, 32> Worklist;
8359 BlockSet.insert(Ptr: EntryBB);
8360 BlockSet.insert(Ptr: ExitBB);
8361
8362 Worklist.push_back(Elt: EntryBB);
8363 while (!Worklist.empty()) {
8364 BasicBlock *BB = Worklist.pop_back_val();
8365 BlockVector.push_back(Elt: BB);
8366 for (BasicBlock *SuccBB : successors(BB))
8367 if (BlockSet.insert(Ptr: SuccBB).second)
8368 Worklist.push_back(Elt: SuccBB);
8369 }
8370}
8371
8372void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
8373 uint64_t Size, int32_t Flags,
8374 GlobalValue::LinkageTypes,
8375 StringRef Name) {
8376 if (!Config.isGPU()) {
8377 llvm::offloading::emitOffloadingEntry(
8378 M, Addr: ID, Name: Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0,
8379 SectionName: "omp_offloading_entries");
8380 return;
8381 }
8382 // TODO: Add support for global variables on the device after declare target
8383 // support.
8384 Function *Fn = dyn_cast<Function>(Val: Addr);
8385 if (!Fn)
8386 return;
8387
8388 Module &M = *(Fn->getParent());
8389 LLVMContext &Ctx = M.getContext();
8390
8391 // Get "nvvm.annotations" metadata node.
8392 NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "nvvm.annotations");
8393
8394 Metadata *MDVals[] = {
8395 ConstantAsMetadata::get(C: Fn), MDString::get(Context&: Ctx, Str: "kernel"),
8396 ConstantAsMetadata::get(C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: 1))};
8397 // Append metadata to nvvm.annotations.
8398 MD->addOperand(M: MDNode::get(Context&: Ctx, MDs: MDVals));
8399
8400 // Add a function attribute for the kernel.
8401 Fn->addFnAttr(Attr: Attribute::get(Context&: Ctx, Kind: "kernel"));
8402 if (T.isAMDGCN())
8403 Fn->addFnAttr(Kind: "uniform-work-group-size", Val: "true");
8404 Fn->addFnAttr(Kind: Attribute::MustProgress);
8405}
8406
// We only generate metadata for functions that contain target regions.
8408void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
8409 EmitMetadataErrorReportFunctionTy &ErrorFn) {
8410
8411 // If there are no entries, we don't need to do anything.
8412 if (OffloadInfoManager.empty())
8413 return;
8414
8415 LLVMContext &C = M.getContext();
8416 SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
8417 TargetRegionEntryInfo>,
8418 16>
8419 OrderedEntries(OffloadInfoManager.size());
8420
8421 // Auxiliary methods to create metadata values and strings.
8422 auto &&GetMDInt = [this](unsigned V) {
8423 return ConstantAsMetadata::get(C: ConstantInt::get(Ty: Builder.getInt32Ty(), V));
8424 };
8425
8426 auto &&GetMDString = [&C](StringRef V) { return MDString::get(Context&: C, Str: V); };
8427
8428 // Create the offloading info metadata node.
8429 NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "omp_offload.info");
8430 auto &&TargetRegionMetadataEmitter =
8431 [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
8432 const TargetRegionEntryInfo &EntryInfo,
8433 const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
8434 // Generate metadata for target regions. Each entry of this metadata
8435 // contains:
8436 // - Entry 0 -> Kind of this type of metadata (0).
8437 // - Entry 1 -> Device ID of the file where the entry was identified.
8438 // - Entry 2 -> File ID of the file where the entry was identified.
8439 // - Entry 3 -> Mangled name of the function where the entry was
8440 // identified.
8441 // - Entry 4 -> Line in the file where the entry was identified.
8442 // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
8443 // - Entry 6 -> Order the entry was created.
8444 // The first element of the metadata node is the kind.
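        // For example (all values are illustrative):
        //   !{i32 0, i32 <device-id>, i32 <file-id>, !"parent_fn_name",
        //     i32 <line>, i32 <count>, i32 <order>}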
8445 Metadata *Ops[] = {
8446 GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
8447 GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
8448 GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
8449 GetMDInt(E.getOrder())};
8450
8451 // Save this entry in the right position of the ordered entries array.
8452 OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y: EntryInfo);
8453
8454 // Add metadata to the named metadata node.
8455 MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
8456 };
8457
8458 OffloadInfoManager.actOnTargetRegionEntriesInfo(Action: TargetRegionMetadataEmitter);
8459
  // Create a function that emits metadata for each device global variable
  // entry.
8461 auto &&DeviceGlobalVarMetadataEmitter =
8462 [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
8463 StringRef MangledName,
8464 const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
8465 // Generate metadata for global variables. Each entry of this metadata
8466 // contains:
8467 // - Entry 0 -> Kind of this type of metadata (1).
8468 // - Entry 1 -> Mangled name of the variable.
8469 // - Entry 2 -> Declare target kind.
8470 // - Entry 3 -> Order the entry was created.
8471 // The first element of the metadata node is the kind.
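        // For example (all values are illustrative):
        //   !{i32 1, !"mangled_var_name", i32 <declare-target-kind>,
        //     i32 <order>}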
8472 Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
8473 GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
8474
8475 // Save this entry in the right position of the ordered entries array.
8476 TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
8477 OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y&: varInfo);
8478
8479 // Add metadata to the named metadata node.
8480 MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
8481 };
8482
8483 OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
8484 Action: DeviceGlobalVarMetadataEmitter);
8485
8486 for (const auto &E : OrderedEntries) {
8487 assert(E.first && "All ordered entries must exist!");
8488 if (const auto *CE =
8489 dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
8490 Val: E.first)) {
8491 if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
8493 TargetRegionEntryInfo EntryInfo = E.second;
8494 StringRef FnName = EntryInfo.ParentName;
8495 if (!M.getNamedValue(Name: FnName))
8496 continue;
8497 ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
8498 continue;
8499 }
8500 createOffloadEntry(ID: CE->getID(), Addr: CE->getAddress(),
8501 /*Size=*/0, Flags: CE->getFlags(),
8502 GlobalValue::WeakAnyLinkage);
8503 } else if (const auto *CE = dyn_cast<
8504 OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
8505 Val: E.first)) {
8506 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
8507 static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
8508 CE->getFlags());
8509 switch (Flags) {
8510 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
8511 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
8512 if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
8513 continue;
8514 if (!CE->getAddress()) {
8515 ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
8516 continue;
8517 }
      // The variable has no definition - no need to add the entry.
8519 if (CE->getVarSize() == 0)
8520 continue;
8521 break;
8522 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
8523 assert(((Config.isTargetDevice() && !CE->getAddress()) ||
8524 (!Config.isTargetDevice() && CE->getAddress())) &&
             "Declare target link address is set.");
8526 if (Config.isTargetDevice())
8527 continue;
8528 if (!CE->getAddress()) {
8529 ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
8530 continue;
8531 }
8532 break;
8533 default:
8534 break;
8535 }
8536
8537 // Hidden or internal symbols on the device are not externally visible.
8538 // We should not attempt to register them by creating an offloading
8539 // entry. Indirect variables are handled separately on the device.
8540 if (auto *GV = dyn_cast<GlobalValue>(Val: CE->getAddress()))
8541 if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
8542 Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8543 continue;
8544
8545 // Indirect globals need to use a special name that doesn't match the name
8546 // of the associated host global.
8547 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8548 createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
8549 Flags, CE->getLinkage(), Name: CE->getVarName());
8550 else
8551 createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
8552 Flags, CE->getLinkage());
8553
8554 } else {
8555 llvm_unreachable("Unsupported entry kind.");
8556 }
8557 }
8558
8559 // Emit requires directive globals to a special entry so the runtime can
8560 // register them when the device image is loaded.
8561 // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
8562 // entries should be redesigned to better suit this use-case.
8563 if (Config.hasRequiresFlags() && !Config.isTargetDevice())
8564 offloading::emitOffloadingEntry(
8565 M, Addr: Constant::getNullValue(Ty: PointerType::getUnqual(C&: M.getContext())),
8566 /*Name=*/"",
8567 /*Size=*/0, Flags: OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
8568 Data: Config.getRequiresFlags(), SectionName: "omp_offloading_entries");
8569}
8570
8571void TargetRegionEntryInfo::getTargetRegionEntryFnName(
8572 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
8573 unsigned FileID, unsigned Line, unsigned Count) {
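  // For example, DeviceID 0x12, FileID 0xabcd, ParentName "foo", and Line 42
  // produce "__omp_offloading_12_abcd_foo_l42" (with "_<count>" appended for
  // a non-zero Count).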
8574 raw_svector_ostream OS(Name);
8575 OS << "__omp_offloading" << llvm::format(Fmt: "_%x", Vals: DeviceID)
8576 << llvm::format(Fmt: "_%x_", Vals: FileID) << ParentName << "_l" << Line;
8577 if (Count)
8578 OS << "_" << Count;
8579}
8580
8581void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
8582 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
8583 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
8584 TargetRegionEntryInfo::getTargetRegionEntryFnName(
8585 Name, ParentName: EntryInfo.ParentName, DeviceID: EntryInfo.DeviceID, FileID: EntryInfo.FileID,
8586 Line: EntryInfo.Line, Count: NewCount);
8587}
8588
8589TargetRegionEntryInfo
8590OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
8591 StringRef ParentName) {
8592 sys::fs::UniqueID ID;
8593 auto FileIDInfo = CallBack();
8594 if (auto EC = sys::fs::getUniqueID(Path: std::get<0>(t&: FileIDInfo), Result&: ID)) {
8595 report_fatal_error(reason: ("Unable to get unique ID for file, during "
8596 "getTargetEntryUniqueInfo, error message: " +
8597 EC.message())
8598 .c_str());
8599 }
8600
8601 return TargetRegionEntryInfo(ParentName, ID.getDevice(), ID.getFile(),
8602 std::get<1>(t&: FileIDInfo));
8603}
8604
8605unsigned OpenMPIRBuilder::getFlagMemberOffset() {
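  // Count the trailing zero bits of the OMP_MAP_MEMBER_OF mask, i.e. the bit
  // offset at which the MEMBER_OF field starts. For example, with the mask
  // occupying the 16 most significant bits of a 64-bit flag
  // (0xFFFF000000000000), this returns 48.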
8606 unsigned Offset = 0;
8607 for (uint64_t Remain =
8608 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8609 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
8610 !(Remain & 1); Remain = Remain >> 1)
8611 Offset++;
8612 return Offset;
8613}
8614
8615omp::OpenMPOffloadMappingFlags
8616OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
8617 // Rotate by getFlagMemberOffset() bits.
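  // Position is zero-based, so the encoded MEMBER_OF value is Position + 1;
  // e.g. the first member (Position 0) is encoded as
  // 1 << getFlagMemberOffset().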
8618 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
8619 << getFlagMemberOffset());
8620}
8621
8622void OpenMPIRBuilder::setCorrectMemberOfFlag(
8623 omp::OpenMPOffloadMappingFlags &Flags,
8624 omp::OpenMPOffloadMappingFlags MemberOfFlag) {
8625 // If the entry is PTR_AND_OBJ but has not been marked with the special
8626 // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
8627 // marked as MEMBER_OF.
8628 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8629 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
8630 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8631 (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
8632 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
8633 return;
8634
8635 // Reset the placeholder value to prepare the flag for the assignment of the
8636 // proper MEMBER_OF value.
8637 Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
8638 Flags |= MemberOfFlag;
8639}
8640
8641Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
8642 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
8643 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
8644 bool IsDeclaration, bool IsExternallyVisible,
8645 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
8646 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
8647 std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
8648 std::function<Constant *()> GlobalInitializer,
8649 std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
8650 // TODO: convert this to utilise the IRBuilder Config rather than
8651 // a passed down argument.
8652 if (OpenMPSIMD)
8653 return nullptr;
8654
8655 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
8656 ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
8657 CaptureClause ==
8658 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
8659 Config.hasRequiresUnifiedSharedMemory())) {
8660 SmallString<64> PtrName;
8661 {
8662 raw_svector_ostream OS(PtrName);
8663 OS << MangledName;
8664 if (!IsExternallyVisible)
8665 OS << format(Fmt: "_%x", Vals: EntryInfo.FileID);
8666 OS << "_decl_tgt_ref_ptr";
8667 }
8668
8669 Value *Ptr = M.getNamedValue(Name: PtrName);
8670
8671 if (!Ptr) {
8672 GlobalValue *GlobalValue = M.getNamedValue(Name: MangledName);
8673 Ptr = getOrCreateInternalVariable(Ty: LlvmPtrTy, Name: PtrName);
8674
8675 auto *GV = cast<GlobalVariable>(Val: Ptr);
8676 GV->setLinkage(GlobalValue::WeakAnyLinkage);
8677
8678 if (!Config.isTargetDevice()) {
8679 if (GlobalInitializer)
8680 GV->setInitializer(GlobalInitializer());
8681 else
8682 GV->setInitializer(GlobalValue);
8683 }
8684
8685 registerTargetGlobalVariable(
8686 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
8687 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
8688 GlobalInitializer, VariableLinkage, LlvmPtrTy, Addr: cast<Constant>(Val: Ptr));
8689 }
8690
8691 return cast<Constant>(Val: Ptr);
8692 }
8693
8694 return nullptr;
8695}
8696
8697void OpenMPIRBuilder::registerTargetGlobalVariable(
8698 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
8699 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
8700 bool IsDeclaration, bool IsExternallyVisible,
8701 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
8702 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
8703 std::vector<Triple> TargetTriple,
8704 std::function<Constant *()> GlobalInitializer,
8705 std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
8706 Constant *Addr) {
8707 if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
8708 (TargetTriple.empty() && !Config.isTargetDevice()))
8709 return;
8710
8711 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
8712 StringRef VarName;
8713 int64_t VarSize;
8714 GlobalValue::LinkageTypes Linkage;
8715
8716 if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
8717 CaptureClause ==
8718 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
8719 !Config.hasRequiresUnifiedSharedMemory()) {
8720 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
8721 VarName = MangledName;
8722 GlobalValue *LlvmVal = M.getNamedValue(Name: VarName);
8723
8724 if (!IsDeclaration)
8725 VarSize = divideCeil(
8726 Numerator: M.getDataLayout().getTypeSizeInBits(Ty: LlvmVal->getValueType()), Denominator: 8);
8727 else
8728 VarSize = 0;
8729 Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
8730
8731 // This is a workaround carried over from Clang which prevents undesired
8732 // optimisation of internal variables.
8733 if (Config.isTargetDevice() &&
8734 (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
8735 // Do not create a "ref-variable" if the original is not also available
8736 // on the host.
8737 if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
8738 return;
8739
8740 std::string RefName = createPlatformSpecificName(Parts: {VarName, "ref"});
8741
8742 if (!M.getNamedValue(Name: RefName)) {
8743 Constant *AddrRef =
8744 getOrCreateInternalVariable(Ty: Addr->getType(), Name: RefName);
8745 auto *GvAddrRef = cast<GlobalVariable>(Val: AddrRef);
8746 GvAddrRef->setConstant(true);
8747 GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
8748 GvAddrRef->setInitializer(Addr);
8749 GeneratedRefs.push_back(x: GvAddrRef);
8750 }
8751 }
8752 } else {
8753 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
8754 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
8755 else
8756 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
8757
8758 if (Config.isTargetDevice()) {
8759 VarName = (Addr) ? Addr->getName() : "";
8760 Addr = nullptr;
8761 } else {
8762 Addr = getAddrOfDeclareTargetVar(
8763 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
8764 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
8765 LlvmPtrTy, GlobalInitializer, VariableLinkage);
8766 VarName = (Addr) ? Addr->getName() : "";
8767 }
8768 VarSize = M.getDataLayout().getPointerSize();
8769 Linkage = GlobalValue::WeakAnyLinkage;
8770 }
8771
8772 OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
8773 Flags, Linkage);
8774}
8775
8776/// Loads all the offload entries information from the host IR
8777/// metadata.
8778void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
8779 // If we are in target mode, load the metadata from the host IR. This code has
8780 // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
8781
8782 NamedMDNode *MD = M.getNamedMetadata(Name: ompOffloadInfoName);
8783 if (!MD)
8784 return;
8785
8786 for (MDNode *MN : MD->operands()) {
8787 auto &&GetMDInt = [MN](unsigned Idx) {
8788 auto *V = cast<ConstantAsMetadata>(Val: MN->getOperand(I: Idx));
8789 return cast<ConstantInt>(Val: V->getValue())->getZExtValue();
8790 };
8791
8792 auto &&GetMDString = [MN](unsigned Idx) {
8793 auto *V = cast<MDString>(Val: MN->getOperand(I: Idx));
8794 return V->getString();
8795 };
8796
8797 switch (GetMDInt(0)) {
8798 default:
8799 llvm_unreachable("Unexpected metadata!");
8800 break;
8801 case OffloadEntriesInfoManager::OffloadEntryInfo::
8802 OffloadingEntryInfoTargetRegion: {
8803 TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
8804 /*DeviceID=*/GetMDInt(1),
8805 /*FileID=*/GetMDInt(2),
8806 /*Line=*/GetMDInt(4),
8807 /*Count=*/GetMDInt(5));
8808 OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
8809 /*Order=*/GetMDInt(6));
8810 break;
8811 }
8812 case OffloadEntriesInfoManager::OffloadEntryInfo::
8813 OffloadingEntryInfoDeviceGlobalVar:
8814 OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
8815 /*MangledName=*/Name: GetMDString(1),
8816 Flags: static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
8817 /*Flags=*/GetMDInt(2)),
8818 /*Order=*/GetMDInt(3));
8819 break;
8820 }
8821 }
8822}
8823
8824void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
8825 if (HostFilePath.empty())
8826 return;
8827
8828 auto Buf = MemoryBuffer::getFile(Filename: HostFilePath);
8829 if (std::error_code Err = Buf.getError()) {
8830 report_fatal_error(reason: ("error opening host file from host file path inside of "
8831 "OpenMPIRBuilder: " +
8832 Err.message())
8833 .c_str());
8834 }
8835
8836 LLVMContext Ctx;
8837 auto M = expectedToErrorOrAndEmitErrors(
8838 Ctx, Val: parseBitcodeFile(Buffer: Buf.get()->getMemBufferRef(), Context&: Ctx));
8839 if (std::error_code Err = M.getError()) {
8840 report_fatal_error(
8841 reason: ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
8842 .c_str());
8843 }
8844
8845 loadOffloadInfoMetadata(M&: *M.get());
8846}
8847
8848//===----------------------------------------------------------------------===//
8849// OffloadEntriesInfoManager
8850//===----------------------------------------------------------------------===//
8851
8852bool OffloadEntriesInfoManager::empty() const {
8853 return OffloadEntriesTargetRegion.empty() &&
8854 OffloadEntriesDeviceGlobalVar.empty();
8855}
8856
8857unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
8858 const TargetRegionEntryInfo &EntryInfo) const {
8859 auto It = OffloadEntriesTargetRegionCount.find(
8860 x: getTargetRegionEntryCountKey(EntryInfo));
8861 if (It == OffloadEntriesTargetRegionCount.end())
8862 return 0;
8863 return It->second;
8864}
8865
8866void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
8867 const TargetRegionEntryInfo &EntryInfo) {
8868 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
8869 EntryInfo.Count + 1;
8870}
8871
8872/// Initialize target region entry.
8873void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
8874 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
8875 OffloadEntriesTargetRegion[EntryInfo] =
8876 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
8877 OMPTargetRegionEntryTargetRegion);
8878 ++OffloadingEntriesNum;
8879}
8880
8881void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
8882 TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
8883 OMPTargetRegionEntryKind Flags) {
8884 assert(EntryInfo.Count == 0 && "expected default EntryInfo");
8885
8886 // Update the EntryInfo with the next available count for this location.
8887 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
8888
  // If we are emitting code for a target, the entry is already initialized
  // and only has to be registered.
8891 if (OMPBuilder->Config.isTargetDevice()) {
8892 // This could happen if the device compilation is invoked standalone.
8893 if (!hasTargetRegionEntryInfo(EntryInfo)) {
8894 return;
8895 }
8896 auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
8897 Entry.setAddress(Addr);
8898 Entry.setID(ID);
8899 Entry.setFlags(Flags);
8900 } else {
8901 if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
8902 hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
8903 return;
8904 assert(!hasTargetRegionEntryInfo(EntryInfo) &&
8905 "Target region entry already registered!");
8906 OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
8907 OffloadEntriesTargetRegion[EntryInfo] = Entry;
8908 ++OffloadingEntriesNum;
8909 }
8910 incrementTargetRegionEntryInfoCount(EntryInfo);
8911}
8912
8913bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
8914 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
8915
8916 // Update the EntryInfo with the next available count for this location.
8917 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
8918
8919 auto It = OffloadEntriesTargetRegion.find(x: EntryInfo);
8920 if (It == OffloadEntriesTargetRegion.end()) {
8921 return false;
8922 }
8923 // Fail if this entry is already registered.
8924 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
8925 return false;
8926 return true;
8927}
8928
8929void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
8930 const OffloadTargetRegionEntryInfoActTy &Action) {
8931 // Scan all target region entries and perform the provided action.
8932 for (const auto &It : OffloadEntriesTargetRegion) {
8933 Action(It.first, It.second);
8934 }
8935}
8936
8937void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
8938 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
8939 OffloadEntriesDeviceGlobalVar.try_emplace(Key: Name, Args&: Order, Args&: Flags);
8940 ++OffloadingEntriesNum;
8941}
8942
8943void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
8944 StringRef VarName, Constant *Addr, int64_t VarSize,
8945 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
8946 if (OMPBuilder->Config.isTargetDevice()) {
8947 // This could happen if the device compilation is invoked standalone.
8948 if (!hasDeviceGlobalVarEntryInfo(VarName))
8949 return;
8950 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
8951 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
8952 if (Entry.getVarSize() == 0) {
8953 Entry.setVarSize(VarSize);
8954 Entry.setLinkage(Linkage);
8955 }
8956 return;
8957 }
8958 Entry.setVarSize(VarSize);
8959 Entry.setLinkage(Linkage);
8960 Entry.setAddress(Addr);
8961 } else {
8962 if (hasDeviceGlobalVarEntryInfo(VarName)) {
8963 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
8964 assert(Entry.isValid() && Entry.getFlags() == Flags &&
8965 "Entry not initialized!");
8966 if (Entry.getVarSize() == 0) {
8967 Entry.setVarSize(VarSize);
8968 Entry.setLinkage(Linkage);
8969 }
8970 return;
8971 }
8972 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8973 OffloadEntriesDeviceGlobalVar.try_emplace(Key: VarName, Args&: OffloadingEntriesNum,
8974 Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage,
8975 Args: VarName.str());
8976 else
8977 OffloadEntriesDeviceGlobalVar.try_emplace(
8978 Key: VarName, Args&: OffloadingEntriesNum, Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage, Args: "");
8979 ++OffloadingEntriesNum;
8980 }
8981}
8982
8983void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
8984 const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
8985 // Scan all target region entries and perform the provided action.
8986 for (const auto &E : OffloadEntriesDeviceGlobalVar)
8987 Action(E.getKey(), E.getValue());
8988}
8989
8990//===----------------------------------------------------------------------===//
8991// CanonicalLoopInfo
8992//===----------------------------------------------------------------------===//
8993
8994void CanonicalLoopInfo::collectControlBlocks(
8995 SmallVectorImpl<BasicBlock *> &BBs) {
  // We only count those BBs as control blocks for which we do not need to
  // traverse the CFG, i.e. not the loop body, which can contain arbitrary
  // control flow. For consistency, this also means we do not add the Body
  // block, which is just the entry to the body code.
9000 BBs.reserve(N: BBs.size() + 6);
9001 BBs.append(IL: {getPreheader(), Header, Cond, Latch, Exit, getAfter()});
9002}
9003
9004BasicBlock *CanonicalLoopInfo::getPreheader() const {
9005 assert(isValid() && "Requires a valid canonical loop");
9006 for (BasicBlock *Pred : predecessors(BB: Header)) {
9007 if (Pred != Latch)
9008 return Pred;
9009 }
9010 llvm_unreachable("Missing preheader");
9011}
9012
9013void CanonicalLoopInfo::setTripCount(Value *TripCount) {
9014 assert(isValid() && "Requires a valid canonical loop");
9015
9016 Instruction *CmpI = &getCond()->front();
9017 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
9018 CmpI->setOperand(i: 1, Val: TripCount);
9019
9020#ifndef NDEBUG
9021 assertOK();
9022#endif
9023}
9024
9025void CanonicalLoopInfo::mapIndVar(
9026 llvm::function_ref<Value *(Instruction *)> Updater) {
9027 assert(isValid() && "Requires a valid canonical loop");
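  // A typical Updater maps the IV to a derived value, e.g. (a sketch; `Start`
  // is a hypothetical loop-invariant offset and `Builder` the caller's
  // IRBuilder):
  //   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
  //     Builder.SetInsertPoint(&*std::next(OldIV->getIterator()));
  //     return Builder.CreateAdd(OldIV, Start, "shifted.iv");
  //   });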
9028
9029 Instruction *OldIV = getIndVar();
9030
9031 // Record all uses excluding those introduced by the updater. Uses by the
9032 // CanonicalLoopInfo itself to keep track of the number of iterations are
9033 // excluded.
9034 SmallVector<Use *> ReplacableUses;
9035 for (Use &U : OldIV->uses()) {
9036 auto *User = dyn_cast<Instruction>(Val: U.getUser());
9037 if (!User)
9038 continue;
9039 if (User->getParent() == getCond())
9040 continue;
9041 if (User->getParent() == getLatch())
9042 continue;
9043 ReplacableUses.push_back(Elt: &U);
9044 }
9045
9046 // Run the updater that may introduce new uses
9047 Value *NewIV = Updater(OldIV);
9048
9049 // Replace the old uses with the value returned by the updater.
9050 for (Use *U : ReplacableUses)
9051 U->set(NewIV);
9052
9053#ifndef NDEBUG
9054 assertOK();
9055#endif
9056}
9057
9058void CanonicalLoopInfo::assertOK() const {
9059#ifndef NDEBUG
9060 // No constraints if this object currently does not describe a loop.
9061 if (!isValid())
9062 return;
9063
9064 BasicBlock *Preheader = getPreheader();
9065 BasicBlock *Body = getBody();
9066 BasicBlock *After = getAfter();
9067
9068 // Verify standard control-flow we use for OpenMP loops.
9069 assert(Preheader);
9070 assert(isa<BranchInst>(Preheader->getTerminator()) &&
9071 "Preheader must terminate with unconditional branch");
9072 assert(Preheader->getSingleSuccessor() == Header &&
9073 "Preheader must jump to header");
9074
9075 assert(Header);
9076 assert(isa<BranchInst>(Header->getTerminator()) &&
9077 "Header must terminate with unconditional branch");
9078 assert(Header->getSingleSuccessor() == Cond &&
9079 "Header must jump to exiting block");
9080
9081 assert(Cond);
9082 assert(Cond->getSinglePredecessor() == Header &&
9083 "Exiting block only reachable from header");
9084
9085 assert(isa<BranchInst>(Cond->getTerminator()) &&
9086 "Exiting block must terminate with conditional branch");
9087 assert(size(successors(Cond)) == 2 &&
9088 "Exiting block must have two successors");
9089 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
9091 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
9092 "Exiting block's second successor must exit the loop");
9093
9094 assert(Body);
9095 assert(Body->getSinglePredecessor() == Cond &&
9096 "Body only reachable from exiting block");
9097 assert(!isa<PHINode>(Body->front()));
9098
9099 assert(Latch);
9100 assert(isa<BranchInst>(Latch->getTerminator()) &&
9101 "Latch must terminate with unconditional branch");
9102 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of the body code that has
  // multiple predecessors, introduce another auxiliary basic block like
  // preheader and after.
9105 assert(Latch->getSinglePredecessor() != nullptr);
9106 assert(!isa<PHINode>(Latch->front()));
9107
9108 assert(Exit);
9109 assert(isa<BranchInst>(Exit->getTerminator()) &&
9110 "Exit block must terminate with unconditional branch");
9111 assert(Exit->getSingleSuccessor() == After &&
9112 "Exit block must jump to after block");
9113
9114 assert(After);
9115 assert(After->getSinglePredecessor() == Exit &&
9116 "After block only reachable from exit block");
9117 assert(After->empty() || !isa<PHINode>(After->front()));
9118
9119 Instruction *IndVar = getIndVar();
9120 assert(IndVar && "Canonical induction variable not found?");
9121 assert(isa<IntegerType>(IndVar->getType()) &&
9122 "Induction variable must be an integer");
9123 assert(cast<PHINode>(IndVar)->getParent() == Header &&
9124 "Induction variable must be a PHI in the loop header");
9125 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
9126 assert(
9127 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
9128 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
9129
9130 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
9131 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
9132 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
9133 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
9134 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
9135 ->isOne());
9136
9137 Value *TripCount = getTripCount();
9138 assert(TripCount && "Loop trip count not found?");
9139 assert(IndVar->getType() == TripCount->getType() &&
9140 "Trip count and induction variable must have the same type");
9141
9142 auto *CmpI = cast<CmpInst>(&Cond->front());
9143 assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
9145 assert(CmpI->getOperand(0) == IndVar &&
9146 "Exit condition must compare the induction variable");
9147 assert(CmpI->getOperand(1) == TripCount &&
9148 "Exit condition must compare with the trip count");
9149#endif
9150}
9151
9152void CanonicalLoopInfo::invalidate() {
9153 Header = nullptr;
9154 Cond = nullptr;
9155 Latch = nullptr;
9156 Exit = nullptr;
9157}
9158