//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));

#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. whether inserting
/// instructions at position IP1 may change the meaning of IP2 or vice-versa.
/// This is because an InsertPoint stores the instruction before which
/// something is inserted. For instance, if both point to the same instruction,
/// two IRBuilders alternately creating instructions will cause the
/// instructions to be interleaved.
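/// For example, if IP1 and IP2 both denote the point before an instruction
/// %I, a builder emitting %a at IP1 and another emitting %b at IP2 will both
/// insert before %I, so the final order of %a and %b depends only on emission
/// order rather than on the two insert points.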
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
  case OMPScheduleType::OrderedDistributeChunked:
  case OMPScheduleType::OrderedDistribute:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
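  // (OMPScheduleType::MonotonicityMask is the union of the ModifierMonotonic
  // and ModifierNonmonotonic bits, so comparing equal to the mask means both
  // modifiers are set.)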
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

/// This is a wrapper over IRBuilderBase::restoreIP that also restores the
/// current debug location to the last instruction in the specified basic
/// block if the insert point points to the end of the block.
static void restoreIPandDebugLoc(llvm::IRBuilderBase &Builder,
                                 llvm::IRBuilderBase::InsertPoint IP) {
  Builder.restoreIP(IP);
  llvm::BasicBlock *BB = Builder.GetInsertBlock();
  llvm::BasicBlock::iterator I = Builder.GetInsertPoint();
  if (!BB->empty() && I == BB->end())
    Builder.SetCurrentDebugLocation(BB->back().getStableDebugLoc());
}

static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  if (T.isSPIRV())
    return omp::SPIRVGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}

/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
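/// For example, schedule(static) maps to BaseStatic, schedule(static, 4) to
/// BaseStaticChunked, and schedule(dynamic) to BaseDynamicChunked.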
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasDistScheduleChunks) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  case OMP_SCHEDULE_Distribute:
    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
                                 : OMPScheduleType::BaseDistribute;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering nor monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations: fall back to the corresponding non-simd ordered
  // schedule type.
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // The monotonic modifier is used by default in the OpenMP runtime
      // library, so there is no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause,
                          bool HasDistScheduleChunks) {
  OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
      ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///   the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch, DebugLoc DL) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to the new block.
  BasicBlock *Old = IP.getBlock();
  // If the `Old` block is empty then there are no instructions to move. But in
  // the new debug scheme, it could have trailing debug records which would be
  // moved to `New` in `spliceDebugInfoEmptyBlock`. We don't want that for two
  // reasons:
  // 1. If `New` is also empty, `BasicBlock::splice` crashes.
  // 2. Even if `New` is not empty, the rationale for moving those records to
  //    `New` (in `spliceDebugInfoEmptyBlock`) does not apply here. That
  //    function assumes that `Old` is optimized out and is going away. This is
  //    not the case here: the `Old` block is still being used, e.g. a branch
  //    instruction is added to it later in this function.
  // So we call `BasicBlock::splice` only when `Old` is not empty.
  if (!Old->empty())
    New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch) {
    auto *NewBr = BranchInst::Create(New, Old);
    NewBr->setDebugLoc(DL);
  }
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch, DebugLoc);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          DebugLoc DL, llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch, DL);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
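// For example, with Name == "tid" and AsPtr == true, it creates
// "%tid.addr = alloca i32" at the outer alloca insert point and a load of it
// at the inner insert point; both instructions are recorded in ToBeDeleted so
// the caller can erase them after outlining.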
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true,
                        bool Is64Bit = false) {
  Builder.restoreIP(OuterAllocaIP);
  IntegerType *IntTy = Is64Bit ? Builder.getInt64Ty() : Builder.getInt32Ty();
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(IntTy, nullptr, Name + ".addr");
  ToBeDeleted.push_back(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal = Builder.CreateLoad(IntTy, FakeValAddr, Name + ".val");
    ToBeDeleted.push_back(FakeVal);
  }

  // Generate a fake use of this value.
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal = Builder.CreateLoad(IntTy, FakeVal, Name + ".use");
  } else {
    UseFakeVal = cast<BinaryOperator>(Builder.CreateAdd(
        FakeVal, Is64Bit ? Builder.getInt64(10) : Builder.getInt32(10)));
  }
  ToBeDeleted.push_back(UseFakeVal);
  return FakeVal;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  constexpr size_t MaxDim = 3;
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));

  Value *HasNoWaitFlag = Builder.getInt64(KernelArgs.HasNoWait);

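  // Flags layout used below: bit 0 carries the nowait flag, and the dynamic
  // cgroup memory fallback kind is shifted into the bits starting at bit 2.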
  Value *DynCGroupMemFallbackFlag =
      Builder.getInt64(static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
  DynCGroupMemFallbackFlag = Builder.CreateShl(DynCGroupMemFallbackFlag, 2);
  Value *Flags = Builder.CreateOr(HasNoWaitFlag, DynCGroupMemFallbackFlag);

  assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());

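  // The kernel arguments carry the number of teams and threads as 3-D arrays;
  // any dimension not provided by the caller stays zero.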
  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0});
  for (unsigned I :
       seq<unsigned>(1, std::min(KernelArgs.NumTeams.size(), MaxDim)))
    NumTeams3D =
        Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
  for (unsigned I :
       seq<unsigned>(1, std::min(KernelArgs.NumThreads.size(), MaxDim)))
    NumThreads3D =
        Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumThreads[I], {I});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGroupMem};
}

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();
  // Target triple used to query the integer extension attributes below.
  Triple T(M.getTargetTriple());

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }
    Fn->setCallingConv(Config.getRuntimeCC());
    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}

Expected<BasicBlock *>
OpenMPIRBuilder::FinalizationInfo::getFiniBB(IRBuilderBase &Builder) {
  if (!FiniBB) {
    Function *ParentFunc = Builder.GetInsertBlock()->getParent();
    IRBuilderBase::InsertPointGuard Guard(Builder);
    FiniBB = BasicBlock::Create(Builder.getContext(), ".fini", ParentFunc);
    Builder.SetInsertPoint(FiniBB);
    // FiniCB adds the branch to the exit stub.
    if (Error Err = FiniCB(Builder.saveIP()))
      return Err;
  }
  return FiniBB;
}

Error OpenMPIRBuilder::FinalizationInfo::mergeFiniBB(IRBuilderBase &Builder,
                                                     BasicBlock *OtherFiniBB) {
  // Simple case: FiniBB does not exist yet; re-use OtherFiniBB.
  if (!FiniBB) {
    FiniBB = OtherFiniBB;

    Builder.SetInsertPoint(FiniBB->getFirstNonPHIIt());
    if (Error Err = FiniCB(Builder.saveIP()))
      return Err;

    return Error::success();
  }

  // Move instructions from FiniBB to the start of OtherFiniBB.
  auto EndIt = FiniBB->end();
  if (FiniBB->size() >= 1)
    if (auto Prev = std::prev(EndIt); Prev->isTerminator())
      EndIt = Prev;
  OtherFiniBB->splice(OtherFiniBB->getFirstNonPHIIt(), FiniBB, FiniBB->begin(),
                      EndIt);

  FiniBB->replaceAllUsesWith(OtherFiniBB);
  FiniBB->eraseFromParent();
  FiniBB = OtherFiniBB;
  return Error::success();
}

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
                                                     ArrayRef<Value *> Args,
                                                     StringRef Name) {
  CallInst *Call = Builder.CreateCall(Callee, Args, Name);
  Call->setCallingConv(Config.getRuntimeCC());
  return Call;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
       Block++) {
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
        Inst++;
        if (!isa<ConstantData>(AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}

static void hoistNonEntryAllocasToEntryBlock(llvm::BasicBlock &Block) {
  llvm::SmallVector<llvm::Instruction *> AllocasToMove;

  auto ShouldHoistAlloca = [](const llvm::AllocaInst &AllocaInst) {
    // TODO: For now, we support simple static allocations; we might need to
    // move non-static ones as well. However, that will need further analysis
    // to move the length arguments as well.
    return !AllocaInst.isArrayAllocation();
  };

  for (llvm::Instruction &Inst : Block)
    if (auto *AllocaInst = llvm::dyn_cast<llvm::AllocaInst>(&Inst))
      if (ShouldHoistAlloca(*AllocaInst))
        AllocasToMove.push_back(AllocaInst);

  auto InsertPoint =
      Block.getParent()->getEntryBlock().getTerminator()->getIterator();

  for (llvm::Instruction *AllocaInst : AllocasToMove)
    AllocaInst->moveBefore(InsertPoint);
}

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate the
    // struct for aggregate params in the device default alloca address space.
    // The OpenMP runtime requires that the params of the extracted functions
    // are passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by the OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn =
        Extractor.extractCodeRegion(CEAC, OI.Inputs, OI.Outputs);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator()) {
          // Absorb any debug value that the terminator may have.
          if (OI.EntryBB->getTerminator())
            OI.EntryBB->getTerminator()->adoptDbgRecords(
                &ArtificialEntry, I.getIterator(), false);
          continue;
        }

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->hasNUses(1));

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);

    if (OI.FixUpNonEntryAllocas) {
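      // Only hoist allocas out of blocks that post-dominate the entry block,
      // i.e. blocks that are reached on every path through the outlined
      // function, so the hoisted allocas would have been executed in all
      // cases anyway.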
      PostDominatorTree PostDomTree(*OutlinedFn);
      for (llvm::BasicBlock &BB : *OutlinedFn)
        if (PostDomTree.properlyDominates(&BB, &OutlinedFn->getEntryBlock()))
          hoistNonEntryAllocasToEntryBlock(BB);
    }
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embed user-written code into the target region,
  // which may inject allocas that need to be moved to the entry block of our
  // target or risk malformed optimisations by later passes. This is only
  // relevant for the device pass, which appears to be a little more delicate
  // when it comes to optimisations (however, we do not block on that here;
  // it is up to the inserter to the list to do so).
  // This notably has to occur after the OutlinedInfo candidates have been
  // extracted, so we have an end product that will not be implicitly adversely
  // affected by any raises unless intentionally appended to the list.
  // NOTE: This only does so for ConstantData; it could be extended to
  // ConstantExprs with further effort, however, they should largely be folded
  // when they get here. Extending it to runtime defined/read+writeable
  // allocation sizes would be non-trivial (we would need to factor in movement
  // of any stores to variables the allocation size depends on, as well as the
  // usual loads, otherwise it will yield the wrong result after movement) and
  // would likely be more suitable as an LLVM optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed("llvm.compiler.used", LLVMCompilerUsed);
  }

  IsFinalized = true;
}

bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
  if (List.empty())
    return;

  // Convert List to what ConstantArray needs.
  SmallVector<Constant *, 8> UsedArray;
  UsedArray.resize(List.size());
  for (unsigned I = 0, E = List.size(); I != E; ++I)
    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        cast<Constant>(&*List[I]), Builder.getPtrTy());

  if (UsedArray.empty())
    return;
  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());

  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
                                ConstantArray::get(ATy, UsedArray), Name);

  GV->setSection("llvm.metadata");
}

GlobalVariable *
OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
                                         OMPTgtExecModeFlags Mode) {
  auto *Int8Ty = Builder.getInt8Ty();
  auto *GVMode = new GlobalVariable(
      M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
      ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
  GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
  return GVMode;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

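  // Cache idents by source location string and flags; the key packs the
  // location flags into the upper bits and the reserve_2 flags into the lower
  // bits of a single 64-bit value.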
  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};

    size_t SrcLocStrArgIdx = 4;
    if (OpenMPIRBuilder::Ident->getElementType(SrcLocStrArgIdx)
            ->getPointerAddressSpace() !=
        IdentData[SrcLocStrArgIdx]->getType()->getPointerAddressSpace())
      IdentData[SrcLocStrArgIdx] = ConstantExpr::getAddrSpaceCast(
          SrcLocStr, OpenMPIRBuilder::Ident->getElementType(SrcLocStrArgIdx));
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for an existing encoding of the location + flags; not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for an existing encoding of the location; not needed but minimizes
    // the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalString(
        LocStr, /*Name=*/"", M.getDataLayout().getDefaultGlobalsAddressSpace(),
        &M);
  }
  return SrcLocStr;
}

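// Builds a source location string in the format the OpenMP runtime expects,
// ";file;function;line;column;;", e.g. ";example.c;foo;12;3;;" (the fallback
// for unknown locations is ";unknown;unknown;0;0;;", see
// getOrCreateDefaultSrcLocStr below).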
1067Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
1068 StringRef FileName,
1069 unsigned Line, unsigned Column,
1070 uint32_t &SrcLocStrSize) {
1071 SmallString<128> Buffer;
1072 Buffer.push_back(Elt: ';');
1073 Buffer.append(RHS: FileName);
1074 Buffer.push_back(Elt: ';');
1075 Buffer.append(RHS: FunctionName);
1076 Buffer.push_back(Elt: ';');
1077 Buffer.append(RHS: std::to_string(val: Line));
1078 Buffer.push_back(Elt: ';');
1079 Buffer.append(RHS: std::to_string(val: Column));
1080 Buffer.push_back(Elt: ';');
1081 Buffer.push_back(Elt: ';');
1082 return getOrCreateSrcLocStr(LocStr: Buffer.str(), SrcLocStrSize);
1083}
1084
1085Constant *
1086OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
1087 StringRef UnknownLoc = ";unknown;unknown;0;0;;";
1088 return getOrCreateSrcLocStr(LocStr: UnknownLoc, SrcLocStrSize);
1089}
1090
1091Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
1092 uint32_t &SrcLocStrSize,
1093 Function *F) {
1094 DILocation *DIL = DL.get();
1095 if (!DIL)
1096 return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1097 StringRef FileName = M.getName();
1098 if (DIFile *DIF = DIL->getFile())
1099 if (std::optional<StringRef> Source = DIF->getSource())
1100 FileName = *Source;
1101 StringRef Function = DIL->getScope()->getSubprogram()->getName();
1102 if (Function.empty() && F)
1103 Function = F->getName();
1104 return getOrCreateSrcLocStr(FunctionName: Function, FileName, Line: DIL->getLine(),
1105 Column: DIL->getColumn(), SrcLocStrSize);
1106}
1107
1108Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
1109 uint32_t &SrcLocStrSize) {
1110 return getOrCreateSrcLocStr(DL: Loc.DL, SrcLocStrSize,
1111 F: Loc.IP.getBlock()->getParent());
1112}
1113
1114Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
1115 return createRuntimeFunctionCall(
1116 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num), Args: Ident,
1117 Name: "omp_global_thread_num");
1118}
1119
1120OpenMPIRBuilder::InsertPointOrErrorTy
1121OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
1122 bool ForceSimpleCall, bool CheckCancelFlag) {
1123 if (!updateToLocation(Loc))
1124 return Loc.IP;
1125
1126 // Build call __kmpc_cancel_barrier(loc, thread_id) or
1127 // __kmpc_barrier(loc, thread_id);
1128
1129 IdentFlag BarrierLocFlags;
1130 switch (Kind) {
1131 case OMPD_for:
1132 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
1133 break;
1134 case OMPD_sections:
1135 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
1136 break;
1137 case OMPD_single:
1138 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
1139 break;
1140 case OMPD_barrier:
1141 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
1142 break;
1143 default:
1144 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
1145 break;
1146 }
1147
1148 uint32_t SrcLocStrSize;
1149 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1150 Value *Args[] = {
1151 getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: BarrierLocFlags),
1152 getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize))};
1153
1154 // If we are in a cancellable parallel region, barriers are cancellation
1155 // points.
1156 // TODO: Check why we would force simple calls or to ignore the cancel flag.
1157 bool UseCancelBarrier =
1158 !ForceSimpleCall && isLastFinalizationInfoCancellable(DK: OMPD_parallel);
1159
1160 Value *Result = createRuntimeFunctionCall(
1161 Callee: getOrCreateRuntimeFunctionPtr(FnID: UseCancelBarrier
1162 ? OMPRTL___kmpc_cancel_barrier
1163 : OMPRTL___kmpc_barrier),
1164 Args);
1165
1166 if (UseCancelBarrier && CheckCancelFlag)
1167 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective: OMPD_parallel))
1168 return Err;
1169
1170 return Builder.saveIP();
1171}
1172
1173OpenMPIRBuilder::InsertPointOrErrorTy
1174OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
1175 Value *IfCondition,
1176 omp::Directive CanceledDirective) {
1177 if (!updateToLocation(Loc))
1178 return Loc.IP;
1179
1180 // LLVM utilities like blocks with terminators.
1181 auto *UI = Builder.CreateUnreachable();
1182
1183 Instruction *ThenTI = UI, *ElseTI = nullptr;
1184 if (IfCondition) {
1185 SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: UI, ThenTerm: &ThenTI, ElseTerm: &ElseTI);
1186
1187 // Even if the if condition evaluates to false, this should count as a
1188 // cancellation point
1189 Builder.SetInsertPoint(ElseTI);
1190 auto ElseIP = Builder.saveIP();
1191
1192 InsertPointOrErrorTy IPOrErr = createCancellationPoint(
1193 Loc: LocationDescription{ElseIP, Loc.DL}, CanceledDirective);
1194 if (!IPOrErr)
1195 return IPOrErr;
1196 }
1197
1198 Builder.SetInsertPoint(ThenTI);
1199
1200 Value *CancelKind = nullptr;
1201 switch (CanceledDirective) {
1202#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
1203 case DirectiveEnum: \
1204 CancelKind = Builder.getInt32(Value); \
1205 break;
1206#include "llvm/Frontend/OpenMP/OMPKinds.def"
1207 default:
1208 llvm_unreachable("Unknown cancel kind!");
1209 }
1210
1211 uint32_t SrcLocStrSize;
1212 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1213 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1214 Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
1215 Value *Result = createRuntimeFunctionCall(
1216 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancel), Args);
1217
1218 // The actual cancel logic is shared with others, e.g., cancel_barriers.
1219 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
1220 return Err;
1221
1222 // Update the insertion point and remove the terminator we introduced.
1223 Builder.SetInsertPoint(UI->getParent());
1224 UI->eraseFromParent();
1225
1226 return Builder.saveIP();
1227}
1228
1229OpenMPIRBuilder::InsertPointOrErrorTy
1230OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
1231 omp::Directive CanceledDirective) {
1232 if (!updateToLocation(Loc))
1233 return Loc.IP;
1234
1235 // LLVM utilities like blocks with terminators.
1236 auto *UI = Builder.CreateUnreachable();
1237 Builder.SetInsertPoint(UI);
1238
1239 Value *CancelKind = nullptr;
1240 switch (CanceledDirective) {
1241#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
1242 case DirectiveEnum: \
1243 CancelKind = Builder.getInt32(Value); \
1244 break;
1245#include "llvm/Frontend/OpenMP/OMPKinds.def"
1246 default:
1247 llvm_unreachable("Unknown cancel kind!");
1248 }
1249
1250 uint32_t SrcLocStrSize;
1251 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1252 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1253 Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
1254 Value *Result = createRuntimeFunctionCall(
1255 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancellationpoint), Args);
1256
1257 // The actual cancel logic is shared with others, e.g., cancel_barriers.
1258 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
1259 return Err;
1260
1261 // Update the insertion point and remove the terminator we introduced.
1262 Builder.SetInsertPoint(UI->getParent());
1263 UI->eraseFromParent();
1264
1265 return Builder.saveIP();
1266}
1267
1268OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1269 const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
1270 Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
1271 Value *HostPtr, ArrayRef<Value *> KernelArgs) {
1272 if (!updateToLocation(Loc))
1273 return Loc.IP;
1274
1275 Builder.restoreIP(IP: AllocaIP);
1276 auto *KernelArgsPtr =
1277 Builder.CreateAlloca(Ty: OpenMPIRBuilder::KernelArgs, ArraySize: nullptr, Name: "kernel_args");
1278 updateToLocation(Loc);
1279
1280 for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
1281 llvm::Value *Arg =
1282 Builder.CreateStructGEP(Ty: OpenMPIRBuilder::KernelArgs, Ptr: KernelArgsPtr, Idx: I);
1283 Builder.CreateAlignedStore(
1284 Val: KernelArgs[I], Ptr: Arg,
1285 Align: M.getDataLayout().getPrefTypeAlign(Ty: KernelArgs[I]->getType()));
1286 }
1287
1288 SmallVector<Value *> OffloadingArgs{Ident, DeviceID, NumTeams,
1289 NumThreads, HostPtr, KernelArgsPtr};
1290
1291 Return = createRuntimeFunctionCall(
1292 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_target_kernel),
1293 Args: OffloadingArgs);
1294
1295 return Builder.saveIP();
1296}
1297
1298OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
1299 const LocationDescription &Loc, Value *OutlinedFnID,
1300 EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
1301 Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1302
1303 if (!updateToLocation(Loc))
1304 return Loc.IP;
1305
1306 // On top of the arrays that were filled up, the target offloading call
1307 // takes as arguments the device id as well as the host pointer. The host
1308 // pointer is used by the runtime library to identify the current target
1309 // region, so it only has to be unique and not necessarily point to
1310 // anything. It could be the pointer to the outlined function that
1311 // implements the target region, but we aren't using that so that the
1312 // compiler doesn't need to keep that, and could therefore inline the host
1313 // function if proven worthwhile during optimization.
1314
1315 // From this point on, we need to have an ID of the target region defined.
1316 assert(OutlinedFnID && "Invalid outlined function ID!");
1317 (void)OutlinedFnID;
1318
1319 // Return value of the runtime offloading call.
1320 Value *Return = nullptr;
1321
1322 // Arguments for the target kernel.
1323 SmallVector<Value *> ArgsVector;
1324 getKernelArgsVector(KernelArgs&: Args, Builder, ArgsVector);
1325
1326 // The target region is an outlined function launched by the runtime
1327 // via calls to __tgt_target_kernel().
1328 //
1329 // Note that on the host and CPU targets, the runtime implementation of
1330 // these calls simply call the outlined function without forking threads.
1331 // The outlined functions themselves have runtime calls to
1332 // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
1333 // the compiler in emitTeamsCall() and emitParallelCall().
1334 //
1335 // In contrast, on the NVPTX target, the implementation of
1336 // __tgt_target_teams() launches a GPU kernel with the requested number
1337 // of teams and threads so no additional calls to the runtime are required.
1338 // Check the error code and execute the host version if required.
1339 Builder.restoreIP(IP: emitTargetKernel(
1340 Loc: Builder, AllocaIP, Return, Ident: RTLoc, DeviceID, NumTeams: Args.NumTeams.front(),
1341 NumThreads: Args.NumThreads.front(), HostPtr: OutlinedFnID, KernelArgs: ArgsVector));
1342
1343 BasicBlock *OffloadFailedBlock =
1344 BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.failed");
1345 BasicBlock *OffloadContBlock =
1346 BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
1347 Value *Failed = Builder.CreateIsNotNull(Arg: Return);
1348 Builder.CreateCondBr(Cond: Failed, True: OffloadFailedBlock, False: OffloadContBlock);
1349
1350 auto CurFn = Builder.GetInsertBlock()->getParent();
1351 emitBlock(BB: OffloadFailedBlock, CurFn);
1352 InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
1353 if (!AfterIP)
1354 return AfterIP.takeError();
1355 Builder.restoreIP(IP: *AfterIP);
1356 emitBranch(Target: OffloadContBlock);
1357 emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
1358 return Builder.saveIP();
1359}
1360
1361Error OpenMPIRBuilder::emitCancelationCheckImpl(
1362 Value *CancelFlag, omp::Directive CanceledDirective) {
1363 assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
1364 "Unexpected cancellation!");
1365
1366 // For a cancel barrier we create two new blocks.
1367 BasicBlock *BB = Builder.GetInsertBlock();
1368 BasicBlock *NonCancellationBlock;
1369 if (Builder.GetInsertPoint() == BB->end()) {
1370 // TODO: This branch will not be needed once we moved to the
1371 // OpenMPIRBuilder codegen completely.
1372 NonCancellationBlock = BasicBlock::Create(
1373 Context&: BB->getContext(), Name: BB->getName() + ".cont", Parent: BB->getParent());
1374 } else {
1375 NonCancellationBlock = SplitBlock(Old: BB, SplitPt: &*Builder.GetInsertPoint());
1376 BB->getTerminator()->eraseFromParent();
1377 Builder.SetInsertPoint(BB);
1378 }
1379 BasicBlock *CancellationBlock = BasicBlock::Create(
1380 Context&: BB->getContext(), Name: BB->getName() + ".cncl", Parent: BB->getParent());
1381
1382 // Jump to them based on the return value.
1383 Value *Cmp = Builder.CreateIsNull(Arg: CancelFlag);
1384 Builder.CreateCondBr(Cond: Cmp, True: NonCancellationBlock, False: CancellationBlock,
1385 /* TODO weight */ BranchWeights: nullptr, Unpredictable: nullptr);
1386
1387 // From the cancellation block we finalize all variables and go to the
1388 // post finalization block that is known to the FiniCB callback.
1389 auto &FI = FinalizationStack.back();
1390 Expected<BasicBlock *> FiniBBOrErr = FI.getFiniBB(Builder);
1391 if (!FiniBBOrErr)
1392 return FiniBBOrErr.takeError();
1393 Builder.SetInsertPoint(CancellationBlock);
1394 Builder.CreateBr(Dest: *FiniBBOrErr);
1395
1396 // The continuation block is where code generation continues.
1397 Builder.SetInsertPoint(TheBB: NonCancellationBlock, IP: NonCancellationBlock->begin());
1398 return Error::success();
1399}
1400
1401// Callback used to create OpenMP runtime calls to support
1402// omp parallel clause for the device.
1403// We need to use this callback to replace call to the OutlinedFn in OuterFn
1404// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_60)
1405static void targetParallelCallback(
1406 OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1407 BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1408 Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1409 Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1410 // Add some known attributes.
1411 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1412 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
1413 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
1414 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
1415 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
1416 OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);
1417
1418 assert(OutlinedFn.arg_size() >= 2 &&
1419 "Expected at least tid and bounded tid as arguments");
1420 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1421
1422 CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
1423 assert(CI && "Expected call instruction to outlined function");
1424 CI->getParent()->setName("omp_parallel");
1425
1426 Builder.SetInsertPoint(CI);
1427 Type *PtrTy = OMPIRBuilder->VoidPtr;
1428 Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
1429
1430 // Add alloca for kernel args
1431 OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1432 Builder.SetInsertPoint(TheBB: OuterAllocaBB, IP: OuterAllocaBB->getFirstInsertionPt());
1433 AllocaInst *ArgsAlloca =
1434 Builder.CreateAlloca(Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars));
1435 Value *Args = ArgsAlloca;
  // Add an address space cast if the array for storing arguments is not
  // allocated in address space 0.
1438 if (ArgsAlloca->getAddressSpace())
1439 Args = Builder.CreatePointerCast(V: ArgsAlloca, DestTy: PtrTy);
1440 Builder.restoreIP(IP: CurrentIP);
1441
  // Store the captured vars which are used by __kmpc_parallel_60.
1443 for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1444 Value *V = *(CI->arg_begin() + 2 + Idx);
1445 Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1446 Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars), Ptr: Args, Idx0: 0, Idx1: Idx);
1447 Builder.CreateStore(Val: V, Ptr: StoreAddress);
1448 }
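  // For two captured pointers %a and %b the loop above produces, roughly
  // (a sketch; names are illustrative):
  //   %args = alloca [2 x ptr]
  //   store ptr %a, ptr %args                                  ; element 0
  //   %gep = getelementptr inbounds [2 x ptr], ptr %args, i64 0, i64 1
  //   store ptr %b, ptr %gep                                   ; element 1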
1449
1450 Value *Cond =
1451 IfCondition ? Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32)
1452 : Builder.getInt32(C: 1);
1453
1454 // Build kmpc_parallel_60 call
1455 Value *Parallel60CallArgs[] = {
1456 /* identifier*/ Ident,
1457 /* global thread num*/ ThreadID,
1458 /* if expression */ Cond,
1459 /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(C: -1),
1460 /* Proc bind */ Builder.getInt32(C: -1),
1461 /* outlined function */ &OutlinedFn,
1462 /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function */ Args,
1464 /* number of arguments */ Builder.getInt64(C: NumCapturedVars),
1465 /* strict for number of threads */ Builder.getInt32(C: 0)};
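  // The resulting runtime call is roughly (a sketch; the exact prototype is
  // defined by the runtime):
  //   __kmpc_parallel_60(ident, gtid, if_cond, num_threads, /*proc_bind=*/-1,
  //                      outlined_fn, /*wrapper=*/null, args, nargs,
  //                      /*strict=*/0)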
1466
1467 FunctionCallee RTLFn =
1468 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_parallel_60);
1469
1470 OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: Parallel60CallArgs);
1471
1472 LLVM_DEBUG(dbgs() << "With kmpc_parallel_60 placed: "
1473 << *Builder.GetInsertBlock()->getParent() << "\n");
1474
1475 // Initialize the local TID stack location with the argument value.
1476 Builder.SetInsertPoint(PrivTID);
1477 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1478 Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
1479 Ptr: PrivTIDAddr);
1480
1481 // Remove redundant call to the outlined function.
1482 CI->eraseFromParent();
1483
1484 for (Instruction *I : ToBeDeleted) {
1485 I->eraseFromParent();
1486 }
1487}
1488
1489// Callback used to create OpenMP runtime calls to support
1490// omp parallel clause for the host.
// We need to use this callback to replace the call to the OutlinedFn in
// OuterFn with a call to the OpenMP host runtime function
// (__kmpc_fork_call[_if]).
1493static void
1494hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1495 Function *OuterFn, Value *Ident, Value *IfCondition,
1496 Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1497 const SmallVector<Instruction *, 4> &ToBeDeleted) {
1498 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1499 FunctionCallee RTLFn;
1500 if (IfCondition) {
1501 RTLFn =
1502 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call_if);
1503 } else {
1504 RTLFn =
1505 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call);
1506 }
1507 if (auto *F = dyn_cast<Function>(Val: RTLFn.getCallee())) {
1508 if (!F->hasMetadata(KindID: LLVMContext::MD_callback)) {
1509 LLVMContext &Ctx = F->getContext();
1510 MDBuilder MDB(Ctx);
1511 // Annotate the callback behavior of the __kmpc_fork_call:
1512 // - The callback callee is argument number 2 (microtask).
1513 // - The first two arguments of the callback callee are unknown (-1).
1514 // - All variadic arguments to the __kmpc_fork_call are passed to the
1515 // callback callee.
1516 F->addMetadata(KindID: LLVMContext::MD_callback,
1517 MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
1518 CalleeArgNo: 2, Arguments: {-1, -1},
1519 /* VarArgsArePassed */ true)}));
1520 }
1521 }
1522 // Add some known attributes.
1523 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
1524 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
1525 OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);
1526
1527 assert(OutlinedFn.arg_size() >= 2 &&
1528 "Expected at least tid and bounded tid as arguments");
1529 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1530
1531 CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
1532 CI->getParent()->setName("omp_parallel");
1533 Builder.SetInsertPoint(CI);
1534
1535 // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1536 Value *ForkCallArgs[] = {Ident, Builder.getInt32(C: NumCapturedVars),
1537 &OutlinedFn};
1538
1539 SmallVector<Value *, 16> RealArgs;
1540 RealArgs.append(in_start: std::begin(arr&: ForkCallArgs), in_end: std::end(arr&: ForkCallArgs));
1541 if (IfCondition) {
1542 Value *Cond = Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32);
1543 RealArgs.push_back(Elt: Cond);
1544 }
1545 RealArgs.append(in_start: CI->arg_begin() + /* tid & bound tid */ 2, in_end: CI->arg_end());
1546
  // __kmpc_fork_call_if always expects a void ptr as the last argument.
  // If there are no captured arguments, pass a null pointer.
1549 auto PtrTy = OMPIRBuilder->VoidPtr;
1550 if (IfCondition && NumCapturedVars == 0) {
1551 Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
1552 RealArgs.push_back(Elt: NullPtrValue);
1553 }
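  // For the _if variant the final argument list is therefore, roughly:
  //   __kmpc_fork_call_if(ident, nargs, microtask, cond, var1, ..., varn)
  // with a trailing null pointer when no variables are captured.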
1554
1555 OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
1556
1557 LLVM_DEBUG(dbgs() << "With fork_call placed: "
1558 << *Builder.GetInsertBlock()->getParent() << "\n");
1559
1560 // Initialize the local TID stack location with the argument value.
1561 Builder.SetInsertPoint(PrivTID);
1562 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1563 Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
1564 Ptr: PrivTIDAddr);
1565
1566 // Remove redundant call to the outlined function.
1567 CI->eraseFromParent();
1568
1569 for (Instruction *I : ToBeDeleted) {
1570 I->eraseFromParent();
1571 }
1572}
1573
1574OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
1575 const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1576 BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1577 FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1578 omp::ProcBindKind ProcBind, bool IsCancellable) {
1579 assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1580
1581 if (!updateToLocation(Loc))
1582 return Loc.IP;
1583
1584 uint32_t SrcLocStrSize;
1585 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1586 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1587 Value *ThreadID = getOrCreateThreadID(Ident);
  // If we generate code for the target device, we need to allocate the
  // struct for aggregate params in the device's default alloca address space.
  // The OpenMP runtime requires that the params of the extracted functions
  // are passed as zero address space pointers. This flag ensures that the
  // extracted function arguments are declared in the zero address space.
1593 bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1594
1595 // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1596 // only if we compile for host side.
1597 if (NumThreads && !Config.isTargetDevice()) {
1598 Value *Args[] = {
1599 Ident, ThreadID,
1600 Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
1601 createRuntimeFunctionCall(
1602 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
1603 }
1604
1605 if (ProcBind != OMP_PROC_BIND_default) {
1606 // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1607 Value *Args[] = {
1608 Ident, ThreadID,
1609 ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
1610 createRuntimeFunctionCall(
1611 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
1612 }
1613
1614 BasicBlock *InsertBB = Builder.GetInsertBlock();
1615 Function *OuterFn = InsertBB->getParent();
1616
1617 // Save the outer alloca block because the insertion iterator may get
1618 // invalidated and we still need this later.
1619 BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1620
1621 // Vector to remember instructions we used only during the modeling but which
1622 // we want to delete at the end.
1623 SmallVector<Instruction *, 4> ToBeDeleted;
1624
1625 // Change the location to the outer alloca insertion point to create and
1626 // initialize the allocas we pass into the parallel region.
1627 InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1628 Builder.restoreIP(IP: NewOuter);
1629 AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
1630 AllocaInst *ZeroAddrAlloca =
1631 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
1632 Instruction *TIDAddr = TIDAddrAlloca;
1633 Instruction *ZeroAddr = ZeroAddrAlloca;
1634 if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1635 // Add additional casts to enforce pointers in zero address space
1636 TIDAddr = new AddrSpaceCastInst(
1637 TIDAddrAlloca, PointerType ::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
1638 TIDAddr->insertAfter(InsertPos: TIDAddrAlloca->getIterator());
1639 ToBeDeleted.push_back(Elt: TIDAddr);
1640 ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1641 PointerType ::get(C&: M.getContext(), AddressSpace: 0),
1642 "zero.addr.ascast");
1643 ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca->getIterator());
1644 ToBeDeleted.push_back(Elt: ZeroAddr);
1645 }
1646
1647 // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1648 // associated arguments in the outlined function, so we delete them later.
1649 ToBeDeleted.push_back(Elt: TIDAddrAlloca);
1650 ToBeDeleted.push_back(Elt: ZeroAddrAlloca);
1651
  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split do not become degenerate.
1654 auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1655
1656 BasicBlock *EntryBB = UI->getParent();
1657 BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
1658 BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
1659 BasicBlock *PRegPreFiniBB =
1660 PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
1661 BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");
1662
1663 auto FiniCBWrapper = [&](InsertPointTy IP) {
1664 // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1665 // target to the region exit block.
1666 if (IP.getBlock()->end() == IP.getPoint()) {
1667 IRBuilder<>::InsertPointGuard IPG(Builder);
1668 Builder.restoreIP(IP);
1669 Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
1670 IP = InsertPointTy(I->getParent(), I->getIterator());
1671 }
1672 assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1673 IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1674 "Unexpected insertion point for finalization call!");
1675 return FiniCB(IP);
1676 };
1677
1678 FinalizationStack.push_back(Elt: {FiniCBWrapper, OMPD_parallel, IsCancellable});
1679
1680 // Generate the privatization allocas in the block that will become the entry
1681 // of the outlined function.
1682 Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1683 InsertPointTy InnerAllocaIP = Builder.saveIP();
1684
1685 AllocaInst *PrivTIDAddr =
1686 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
1687 Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");
1688
1689 // Add some fake uses for OpenMP provided arguments.
1690 ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
1691 Instruction *ZeroAddrUse =
1692 Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
1693 ToBeDeleted.push_back(Elt: ZeroAddrUse);
1694
1695 // EntryBB
1696 // |
1697 // V
1698 // PRegionEntryBB <- Privatization allocas are placed here.
1699 // |
1700 // V
  //       PRegionBodyBB    <- BodyGen is invoked here.
1702 // |
1703 // V
1704 // PRegPreFiniBB <- The block we will start finalization from.
1705 // |
1706 // V
1707 // PRegionExitBB <- A common exit to simplify block collection.
1708 //
1709
1710 LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1711
1712 // Let the caller create the body.
1713 assert(BodyGenCB && "Expected body generation callback!");
1714 InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1715 if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
1716 return Err;
1717
1718 LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
1719
1720 OutlineInfo OI;
1721 if (Config.isTargetDevice()) {
1722 // Generate OpenMP target specific runtime call
1723 OI.PostOutlineCB = [=, ToBeDeletedVec =
1724 std::move(ToBeDeleted)](Function &OutlinedFn) {
1725 targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
1726 IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1727 ThreadID, ToBeDeleted: ToBeDeletedVec);
1728 };
1729 OI.FixUpNonEntryAllocas = true;
1730 } else {
1731 // Generate OpenMP host runtime call
1732 OI.PostOutlineCB = [=, ToBeDeletedVec =
1733 std::move(ToBeDeleted)](Function &OutlinedFn) {
1734 hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
1735 PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
1736 };
1737 OI.FixUpNonEntryAllocas = true;
1738 }
1739
1740 OI.OuterAllocaBB = OuterAllocaBlock;
1741 OI.EntryBB = PRegEntryBB;
1742 OI.ExitBB = PRegExitBB;
1743
1744 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1745 SmallVector<BasicBlock *, 32> Blocks;
1746 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
1747
1748 CodeExtractorAnalysisCache CEAC(*OuterFn);
1749 CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1750 /* AggregateArgs */ false,
1751 /* BlockFrequencyInfo */ nullptr,
1752 /* BranchProbabilityInfo */ nullptr,
1753 /* AssumptionCache */ nullptr,
1754 /* AllowVarArgs */ true,
1755 /* AllowAlloca */ true,
1756 /* AllocationBlock */ OuterAllocaBlock,
1757 /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1758
1759 // Find inputs to, outputs from the code region.
1760 BasicBlock *CommonExit = nullptr;
1761 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1762 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
1763
1764 Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands,
1765 /*CollectGlobalInputs=*/true);
1766
1767 Inputs.remove_if(P: [&](Value *I) {
1768 if (auto *GV = dyn_cast_if_present<GlobalVariable>(Val: I))
1769 return GV->getValueType() == OpenMPIRBuilder::Ident;
1770
1771 return false;
1772 });
1773
1774 LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1775
1776 FunctionCallee TIDRTLFn =
1777 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);
1778
1779 auto PrivHelper = [&](Value &V) -> Error {
1780 if (&V == TIDAddr || &V == ZeroAddr) {
1781 OI.ExcludeArgsFromAggregate.push_back(Elt: &V);
1782 return Error::success();
1783 }
1784
1785 SetVector<Use *> Uses;
1786 for (Use &U : V.uses())
1787 if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
1788 if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
1789 Uses.insert(X: &U);
1790
1791 // __kmpc_fork_call expects extra arguments as pointers. If the input
1792 // already has a pointer type, everything is fine. Otherwise, store the
1793 // value onto stack and load it back inside the to-be-outlined region. This
1794 // will ensure only the pointer will be passed to the function.
1795 // FIXME: if there are more than 15 trailing arguments, they must be
1796 // additionally packed in a struct.
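    // For a non-pointer input %x this emits, roughly (a sketch; names are
    // illustrative):
    //   %x.reloaded = alloca i32              ; at OuterAllocaIP
    //   store i32 %x, ptr %x.reloaded         ; before entering the region
    //   %x.inner = load i32, ptr %x.reloaded  ; at InnerAllocaIP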
1797 Value *Inner = &V;
1798 if (!V.getType()->isPointerTy()) {
1799 IRBuilder<>::InsertPointGuard Guard(Builder);
1800 LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1801
1802 Builder.restoreIP(IP: OuterAllocaIP);
1803 Value *Ptr =
1804 Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr, Name: V.getName() + ".reloaded");
1805
1806 // Store to stack at end of the block that currently branches to the entry
1807 // block of the to-be-outlined region.
1808 Builder.SetInsertPoint(TheBB: InsertBB,
1809 IP: InsertBB->getTerminator()->getIterator());
1810 Builder.CreateStore(Val: &V, Ptr);
1811
1812 // Load back next to allocations in the to-be-outlined region.
1813 Builder.restoreIP(IP: InnerAllocaIP);
1814 Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
1815 }
1816
1817 Value *ReplacementValue = nullptr;
1818 CallInst *CI = dyn_cast<CallInst>(Val: &V);
1819 if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1820 ReplacementValue = PrivTID;
1821 } else {
1822 InsertPointOrErrorTy AfterIP =
1823 PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
1824 if (!AfterIP)
1825 return AfterIP.takeError();
1826 Builder.restoreIP(IP: *AfterIP);
1827 InnerAllocaIP = {
1828 InnerAllocaIP.getBlock(),
1829 InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1830
1831 assert(ReplacementValue &&
1832 "Expected copy/create callback to set replacement value!");
1833 if (ReplacementValue == &V)
1834 return Error::success();
1835 }
1836
1837 for (Use *UPtr : Uses)
1838 UPtr->set(ReplacementValue);
1839
1840 return Error::success();
1841 };
1842
  // Reset the inner alloca insertion point as it will be used for loading the
  // values wrapped into pointers before passing them into the to-be-outlined
  // region. Configure it to insert immediately after the fake use of the zero
  // address so that the reloaded values are available in the generated body
  // and so that the OpenMP-related values (thread ID and zero address
  // pointers) remain leading in the argument list.
1849 InnerAllocaIP = IRBuilder<>::InsertPoint(
1850 ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1851
1852 // Reset the outer alloca insertion point to the entry of the relevant block
1853 // in case it was invalidated.
1854 OuterAllocaIP = IRBuilder<>::InsertPoint(
1855 OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1856
1857 for (Value *Input : Inputs) {
1858 LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1859 if (Error Err = PrivHelper(*Input))
1860 return Err;
1861 }
1862 LLVM_DEBUG({
1863 for (Value *Output : Outputs)
1864 LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1865 });
1866 assert(Outputs.empty() &&
1867 "OpenMP outlining should not produce live-out values!");
1868
1869 LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
1870 LLVM_DEBUG({
1871 for (auto *BB : Blocks)
1872 dbgs() << " PBR: " << BB->getName() << "\n";
1873 });
1874
1875 // Adjust the finalization stack, verify the adjustment, and call the
1876 // finalize function a last time to finalize values between the pre-fini
1877 // block and the exit block if we left the parallel "the normal way".
1878 auto FiniInfo = FinalizationStack.pop_back_val();
1879 (void)FiniInfo;
1880 assert(FiniInfo.DK == OMPD_parallel &&
1881 "Unexpected finalization stack state!");
1882
1883 Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1884
1885 InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1886 Expected<BasicBlock *> FiniBBOrErr = FiniInfo.getFiniBB(Builder);
1887 if (!FiniBBOrErr)
1888 return FiniBBOrErr.takeError();
1889 {
1890 IRBuilderBase::InsertPointGuard Guard(Builder);
1891 Builder.restoreIP(IP: PreFiniIP);
1892 Builder.CreateBr(Dest: *FiniBBOrErr);
    // There's currently a branch to omp.par.exit. Delete it. We will get
    // there via the fini block.
1895 if (Instruction *Term = Builder.GetInsertBlock()->getTerminator())
1896 Term->eraseFromParent();
1897 }
1898
1899 // Register the outlined info.
1900 addOutlineInfo(OI: std::move(OI));
1901
1902 InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1903 UI->eraseFromParent();
1904
1905 return AfterIP;
1906}
1907
1908void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1909 // Build call void __kmpc_flush(ident_t *loc)
1910 uint32_t SrcLocStrSize;
1911 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1912 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1913
1914 createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush),
1915 Args);
1916}
1917
1918void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1919 if (!updateToLocation(Loc))
1920 return;
1921 emitFlush(Loc);
1922}
1923
1924void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1925 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1926 // global_tid);
1927 uint32_t SrcLocStrSize;
1928 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1929 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1930 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1931
1932 // Ignore return result until untied tasks are supported.
1933 createRuntimeFunctionCall(
1934 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait), Args);
1935}
1936
1937void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1938 if (!updateToLocation(Loc))
1939 return;
1940 emitTaskwaitImpl(Loc);
1941}
1942
1943void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1944 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1945 uint32_t SrcLocStrSize;
1946 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1947 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1948 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1949 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1950
1951 createRuntimeFunctionCall(
1952 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield), Args);
1953}
1954
1955void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1956 if (!updateToLocation(Loc))
1957 return;
1958 emitTaskyieldImpl(Loc);
1959}
1960
// Processes the dependencies in Dependencies and does the following:
// - Allocates stack space for an array of DependInfo objects.
// - Populates each DependInfo object with the relevant information about
//   the corresponding dependence.
// - All code is inserted in the entry block of the current function.
1966static Value *emitTaskDependencies(
1967 OpenMPIRBuilder &OMPBuilder,
1968 const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1969 // Early return if we have no dependencies to process
1970 if (Dependencies.empty())
1971 return nullptr;
1972
1973 // Given a vector of DependData objects, in this function we create an
1974 // array on the stack that holds kmp_dep_info objects corresponding
1975 // to each dependency. This is then passed to the OpenMP runtime.
  // For example, if there are 'n' dependencies then the following pseudo
  // code is generated. Assume the first dependence is on a variable 'a'.
  //
  // \code{c}
  // DepArray = alloc(n x sizeof(kmp_depend_info));
  // idx = 0;
  // DepArray[idx].base_addr = ptrtoint(&a);
  // DepArray[idx].len = 8;
  // DepArray[idx].flags = Dep.DepKind; /*(See OMPConstants.h for DepKind)*/
1985 // ++idx;
1986 // DepArray[idx].base_addr = ...;
1987 // \endcode
1988
1989 IRBuilderBase &Builder = OMPBuilder.Builder;
1990 Type *DependInfo = OMPBuilder.DependInfo;
1991 Module &M = OMPBuilder.M;
1992
1993 Value *DepArray = nullptr;
1994 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1995 Builder.SetInsertPoint(
1996 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1997
1998 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
1999 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
2000
2001 Builder.restoreIP(IP: OldIP);
2002
2003 for (const auto &[DepIdx, Dep] : enumerate(First: Dependencies)) {
2004 Value *Base =
2005 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
2006 // Store the pointer to the variable
2007 Value *Addr = Builder.CreateStructGEP(
2008 Ty: DependInfo, Ptr: Base,
2009 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
2010 Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
2011 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
2012 // Store the size of the variable
2013 Value *Size = Builder.CreateStructGEP(
2014 Ty: DependInfo, Ptr: Base, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
2015 Builder.CreateStore(
2016 Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
2017 Ptr: Size);
2018 // Store the dependency kind
2019 Value *Flags = Builder.CreateStructGEP(
2020 Ty: DependInfo, Ptr: Base,
2021 Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
2022 Builder.CreateStore(
2023 Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
2024 V: static_cast<unsigned int>(Dep.DepKind)),
2025 Ptr: Flags);
2026 }
2027 return DepArray;
2028}
2029
/// Create the task duplication function passed to __kmpc_taskloop.
2031Expected<Value *> OpenMPIRBuilder::createTaskDuplicationFunction(
2032 Type *PrivatesTy, int32_t PrivatesIndex, TaskDupCallbackTy DupCB) {
2033 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2034 if (!DupCB)
2035 return Constant::getNullValue(
2036 Ty: PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace));
2037
  // From the OpenMP runtime's p_task_dup_t:
  // Routine optionally generated by the compiler for setting the lastprivate
  // flag and calling needed constructors for private/firstprivate objects
  // (used to form taskloop tasks from the pattern task).
  // Parameters: dest task, src task, lastprivate flag.
  // typedef void (*p_task_dup_t)(kmp_task_t *, kmp_task_t *, kmp_int32);
2044
2045 auto *VoidPtrTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2046
2047 FunctionType *DupFuncTy = FunctionType::get(
2048 Result: Builder.getVoidTy(), Params: {VoidPtrTy, VoidPtrTy, Builder.getInt32Ty()},
2049 /*isVarArg=*/false);
2050
2051 Function *DupFunction = Function::Create(Ty: DupFuncTy, Linkage: Function::InternalLinkage,
2052 N: "omp_taskloop_dup", M);
2053 Value *DestTaskArg = DupFunction->getArg(i: 0);
2054 Value *SrcTaskArg = DupFunction->getArg(i: 1);
2055 Value *LastprivateFlagArg = DupFunction->getArg(i: 2);
2056 DestTaskArg->setName("dest_task");
2057 SrcTaskArg->setName("src_task");
2058 LastprivateFlagArg->setName("lastprivate_flag");
2059
2060 IRBuilderBase::InsertPointGuard Guard(Builder);
2061 Builder.SetInsertPoint(
2062 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: DupFunction));
2063
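  // Both task arguments are raw task pointers with the task-with-privates
  // layout assumed below, i.e. { kmp_task_t, PrivatesTy }; the context
  // pointer is read from field PrivatesIndex of the privates part, roughly
  // (a sketch):
  //   %priv = gep { task, privates }, ptr %task, 0, 1
  //   %ctx  = gep privates, ptr %priv, 0, PrivatesIndex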
2064 auto GetTaskContextPtrFromArg = [&](Value *Arg) -> Value * {
2065 Type *TaskWithPrivatesTy =
2066 StructType::get(Context&: Builder.getContext(), Elements: {Task, PrivatesTy});
2067 Value *TaskPrivates = Builder.CreateGEP(
2068 Ty: TaskWithPrivatesTy, Ptr: Arg, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2069 Value *ContextPtr = Builder.CreateGEP(
2070 Ty: PrivatesTy, Ptr: TaskPrivates,
2071 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: PrivatesIndex)});
2072 return ContextPtr;
2073 };
2074
2075 Value *DestTaskContextPtr = GetTaskContextPtrFromArg(DestTaskArg);
2076 Value *SrcTaskContextPtr = GetTaskContextPtrFromArg(SrcTaskArg);
2077
2078 DestTaskContextPtr->setName("destPtr");
2079 SrcTaskContextPtr->setName("srcPtr");
2080
2081 InsertPointTy AllocaIP(&DupFunction->getEntryBlock(),
2082 DupFunction->getEntryBlock().begin());
2083 InsertPointTy CodeGenIP = Builder.saveIP();
2084 Expected<IRBuilderBase::InsertPoint> AfterIPOrError =
2085 DupCB(AllocaIP, CodeGenIP, DestTaskContextPtr, SrcTaskContextPtr);
2086 if (!AfterIPOrError)
2087 return AfterIPOrError.takeError();
2088 Builder.restoreIP(IP: *AfterIPOrError);
2089
2090 Builder.CreateRetVoid();
2091
2092 return DupFunction;
2093}
2094
2095OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
2096 const LocationDescription &Loc, InsertPointTy AllocaIP,
2097 BodyGenCallbackTy BodyGenCB,
2098 llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> LoopInfo,
2099 Value *LBVal, Value *UBVal, Value *StepVal, bool Untied, Value *IfCond,
2100 Value *GrainSize, bool NoGroup, int Sched, Value *Final, bool Mergeable,
2101 Value *Priority, TaskDupCallbackTy DupCB, Value *TaskContextStructPtrVal) {
2102
2103 if (!updateToLocation(Loc))
2104 return InsertPointTy();
2105
2106 uint32_t SrcLocStrSize;
2107 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2108 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2109
2110 BasicBlock *TaskloopExitBB =
2111 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.exit");
2112 BasicBlock *TaskloopBodyBB =
2113 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.body");
2114 BasicBlock *TaskloopAllocaBB =
2115 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.alloca");
2116
2117 InsertPointTy TaskloopAllocaIP =
2118 InsertPointTy(TaskloopAllocaBB, TaskloopAllocaBB->begin());
2119 InsertPointTy TaskloopBodyIP =
2120 InsertPointTy(TaskloopBodyBB, TaskloopBodyBB->begin());
2121
2122 if (Error Err = BodyGenCB(TaskloopAllocaIP, TaskloopBodyIP))
2123 return Err;
2124
2125 llvm::Expected<llvm::CanonicalLoopInfo *> result = LoopInfo();
2126 if (!result) {
2127 return result.takeError();
2128 }
2129
2130 llvm::CanonicalLoopInfo *CLI = result.get();
2131 OutlineInfo OI;
2132 OI.EntryBB = TaskloopAllocaBB;
2133 OI.OuterAllocaBB = AllocaIP.getBlock();
2134 OI.ExitBB = TaskloopExitBB;
2135
2136 // Add the thread ID argument.
2137 SmallVector<Instruction *> ToBeDeleted;
  // Dummy instruction to be used as a fake argument.
2139 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2140 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskloopAllocaIP, Name: "global.tid", AsPtr: false));
2141 Value *FakeLB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2142 InnerAllocaIP: TaskloopAllocaIP, Name: "lb", AsPtr: false, Is64Bit: true);
2143 Value *FakeUB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2144 InnerAllocaIP: TaskloopAllocaIP, Name: "ub", AsPtr: false, Is64Bit: true);
2145 Value *FakeStep = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2146 InnerAllocaIP: TaskloopAllocaIP, Name: "step", AsPtr: false, Is64Bit: true);
  // For taskloop, we want to force the bounds to be the first 3 inputs in the
  // aggregate struct.
2149 OI.Inputs.insert(X: FakeLB);
2150 OI.Inputs.insert(X: FakeUB);
2151 OI.Inputs.insert(X: FakeStep);
2152 if (TaskContextStructPtrVal)
2153 OI.Inputs.insert(X: TaskContextStructPtrVal);
2154 assert(((TaskContextStructPtrVal && DupCB) ||
2155 (!TaskContextStructPtrVal && !DupCB)) &&
2156 "Task context struct ptr and duplication callback must be both set "
2157 "or both null");
2158
  // It isn't safe to run the duplication bodygen callback inside the
  // post-outlining callback, so it has to be run now, before the real task
  // shareds structure type is known.
2162 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2163 Type *PointerTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2164 Type *FakeSharedsTy = StructType::get(
2165 Context&: Builder.getContext(),
2166 Elements: {FakeLB->getType(), FakeUB->getType(), FakeStep->getType(), PointerTy});
2167 Expected<Value *> TaskDupFnOrErr = createTaskDuplicationFunction(
2168 PrivatesTy: FakeSharedsTy,
2169 /*PrivatesIndex: the pointer after the three indices above*/ PrivatesIndex: 3, DupCB);
2170 if (!TaskDupFnOrErr) {
2171 return TaskDupFnOrErr.takeError();
2172 }
2173 Value *TaskDupFn = *TaskDupFnOrErr;
2174
2175 OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
2176 TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
2177 IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
2178 FakeStep, Final, Mergeable,
2179 Priority](Function &OutlinedFn) mutable {
2180 // Replace the Stale CI by appropriate RTL function call.
2181 assert(OutlinedFn.hasOneUse() &&
2182 "there must be a single user for the outlined function");
2183 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2184
    // Create casts for the bounds values so that, after outlining, the uses
    // of the fake values can be replaced with the real ones.
2187 BasicBlock *CodeReplBB = StaleCI->getParent();
2188 IRBuilderBase::InsertPoint CurrentIp = Builder.saveIP();
2189 Builder.SetInsertPoint(CodeReplBB->getFirstInsertionPt());
2190 Value *CastedLBVal =
2191 Builder.CreateIntCast(V: LBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "lb64");
2192 Value *CastedUBVal =
2193 Builder.CreateIntCast(V: UBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "ub64");
2194 Value *CastedStepVal =
2195 Builder.CreateIntCast(V: StepVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "step64");
2196 Builder.restoreIP(IP: CurrentIp);
2197
2198 Builder.SetInsertPoint(StaleCI);
2199
2200 // Gather the arguments for emitting the runtime call for
2201 // @__kmpc_omp_task_alloc
2202 Function *TaskAllocFn =
2203 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2204
2205 Value *ThreadID = getOrCreateThreadID(Ident);
2206
2207 if (!NoGroup) {
2208 // Emit runtime call for @__kmpc_taskgroup
2209 Function *TaskgroupFn =
2210 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2211 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2212 }
2213
2214 // `flags` Argument Configuration
2215 // Task is tied if (Flags & 1) == 1.
2216 // Task is untied if (Flags & 1) == 0.
2217 // Task is final if (Flags & 2) == 2.
2218 // Task is not final if (Flags & 2) == 0.
2219 // Task is mergeable if (Flags & 4) == 4.
2220 // Task is not mergeable if (Flags & 4) == 0.
2221 // Task is priority if (Flags & 32) == 32.
2222 // Task is not priority if (Flags & 32) == 0.
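    // For example, a tied (bit 0) and final (bit 1) task yields Flags == 3.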
2223 Value *Flags = Builder.getInt32(C: Untied ? 0 : 1);
2224 if (Final)
2225 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 2), RHS: Flags);
2226 if (Mergeable)
2227 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2228 if (Priority)
2229 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2230
2231 Value *TaskSize = Builder.getInt64(
2232 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2233
2234 AllocaInst *ArgStructAlloca =
2235 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2236 assert(ArgStructAlloca &&
2237 "Unable to find the alloca instruction corresponding to arguments "
2238 "for extracted function");
2239 StructType *ArgStructType =
2240 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
2241 assert(ArgStructType && "Unable to find struct type corresponding to "
2242 "arguments for extracted function");
2243 Value *SharedsSize =
2244 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
2245
2246 // Emit the @__kmpc_omp_task_alloc runtime call
2247 // The runtime call returns a pointer to an area where the task captured
2248 // variables must be copied before the task is run (TaskData)
2249 CallInst *TaskData = Builder.CreateCall(
2250 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2251 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2252 /*task_func=*/&OutlinedFn});
2253
2254 Value *Shareds = StaleCI->getArgOperand(i: 1);
2255 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2256 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2257 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2258 Size: SharedsSize);
    // Get the pointers to the loop lb, ub, and step from the task ptr and set
    // up the lower bound, upper bound, and step values.
2261 llvm::Value *Lb = Builder.CreateGEP(
2262 Ty: ArgStructType, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
2263
2264 llvm::Value *Ub = Builder.CreateGEP(
2265 Ty: ArgStructType, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2266
2267 llvm::Value *Step = Builder.CreateGEP(
2268 Ty: ArgStructType, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2)});
2269 llvm::Value *Loadstep = Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: Step);
2270
    // Set up the arguments for emitting the kmpc_taskloop runtime call,
    // setting values for ifval, nogroup, sched, grainsize, and task_dup.
2273 Value *IfCondVal =
2274 IfCond ? Builder.CreateIntCast(V: IfCond, DestTy: Builder.getInt32Ty(), isSigned: true)
2275 : Builder.getInt32(C: 1);
    // As __kmpc_taskgroup is called manually in OMPIRBuilder, NoGroupVal
    // should always be 1 when calling __kmpc_taskloop to ensure it is not
    // called again.
2278 Value *NoGroupVal = Builder.getInt32(C: 1);
2279 Value *SchedVal = Builder.getInt32(C: Sched);
2280 Value *GrainSizeVal =
2281 GrainSize ? Builder.CreateIntCast(V: GrainSize, DestTy: Builder.getInt64Ty(), isSigned: true)
2282 : Builder.getInt64(C: 0);
2283 Value *TaskDup = TaskDupFn;
2284
2285 Value *Args[] = {Ident, ThreadID, TaskData, IfCondVal, Lb, Ub,
2286 Loadstep, NoGroupVal, SchedVal, GrainSizeVal, TaskDup};
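    // The resulting runtime call is roughly (a sketch):
    //   __kmpc_taskloop(ident, gtid, task, if_val, lb_ptr, ub_ptr, step,
    //                   /*nogroup=*/1, sched, grainsize, task_dup)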
2287
2288 // taskloop runtime call
2289 Function *TaskloopFn =
2290 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskloop);
2291 Builder.CreateCall(Callee: TaskloopFn, Args);
2292
2293 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup if
2294 // nogroup is not defined
2295 if (!NoGroup) {
2296 Function *EndTaskgroupFn =
2297 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2298 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2299 }
2300
2301 StaleCI->eraseFromParent();
2302
2303 Builder.SetInsertPoint(TheBB: TaskloopAllocaBB, IP: TaskloopAllocaBB->begin());
2304
2305 LoadInst *SharedsOutlined =
2306 Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2307 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2308 New: SharedsOutlined,
2309 ShouldReplace: [SharedsOutlined](Use &U) { return U.getUser() != SharedsOutlined; });
2310
2311 Value *IV = CLI->getIndVar();
2312 Type *IVTy = IV->getType();
2313 Constant *One = ConstantInt::get(Ty: Builder.getInt64Ty(), V: 1);
2314
2315 // When outlining, CodeExtractor will create GEP's to the LowerBound and
2316 // UpperBound. These GEP's can be reused for loading the tasks respective
2317 // bounds.
2318 Value *TaskLB = nullptr;
2319 Value *TaskUB = nullptr;
2320 Value *LoadTaskLB = nullptr;
2321 Value *LoadTaskUB = nullptr;
2322 for (Instruction &I : *TaskloopAllocaBB) {
2323 if (I.getOpcode() == Instruction::GetElementPtr) {
2324 GetElementPtrInst &Gep = cast<GetElementPtrInst>(Val&: I);
2325 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: Gep.getOperand(i_nocapture: 2))) {
2326 switch (CI->getZExtValue()) {
2327 case 0:
2328 TaskLB = &I;
2329 break;
2330 case 1:
2331 TaskUB = &I;
2332 break;
2333 }
2334 }
2335 } else if (I.getOpcode() == Instruction::Load) {
2336 LoadInst &Load = cast<LoadInst>(Val&: I);
2337 if (Load.getPointerOperand() == TaskLB) {
2338 assert(TaskLB != nullptr && "Expected value for TaskLB");
2339 LoadTaskLB = &I;
2340 } else if (Load.getPointerOperand() == TaskUB) {
2341 assert(TaskUB != nullptr && "Expected value for TaskUB");
2342 LoadTaskUB = &I;
2343 }
2344 }
2345 }
2346
2347 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2348
2349 assert(LoadTaskLB != nullptr && "Expected value for LoadTaskLB");
2350 assert(LoadTaskUB != nullptr && "Expected value for LoadTaskUB");
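    // Derive the task-local trip count from the task's own bounds:
    //   trip_cnt = (ub - lb) / step + 1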
2351 Value *TripCountMinusOne =
2352 Builder.CreateSDiv(LHS: Builder.CreateSub(LHS: LoadTaskUB, RHS: LoadTaskLB), RHS: FakeStep);
2353 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One, Name: "trip_cnt");
2354 Value *CastedTripCount = Builder.CreateIntCast(V: TripCount, DestTy: IVTy, isSigned: true);
2355 Value *CastedTaskLB = Builder.CreateIntCast(V: LoadTaskLB, DestTy: IVTy, isSigned: true);
    // Set the trip count in the CLI.
2357 CLI->setTripCount(CastedTripCount);
2358
2359 Builder.SetInsertPoint(TheBB: CLI->getBody(),
2360 IP: CLI->getBody()->getFirstInsertionPt());
2361
2362 // The canonical loop is generated with a fixed lower bound. We need to
2363 // update the index calculation code to use the task's lower bound. The
2364 // generated code looks like this:
2365 // %omp_loop.iv = phi ...
2366 // ...
2367 // %tmp = mul [type] %omp_loop.iv, step
2368 // %user_index = add [type] tmp, lb
2369 // OpenMPIRBuilder constructs canonical loops to have exactly three uses of
2370 // the normalised induction variable:
2371 // 1. This one: converting the normalised IV to the user IV
2372 // 2. The increment (add)
2373 // 3. The comparison against the trip count (icmp)
2374 // (1) is the only use that is a mul followed by an add so this cannot match
2375 // other IR.
2376 assert(CLI->getIndVar()->getNumUses() == 3 &&
2377 "Canonical loop should have exactly three uses of the ind var");
2378 for (User *IVUser : CLI->getIndVar()->users()) {
2379 if (auto *Mul = dyn_cast<BinaryOperator>(Val: IVUser)) {
2380 if (Mul->getOpcode() == Instruction::Mul) {
2381 for (User *MulUser : Mul->users()) {
2382 if (auto *Add = dyn_cast<BinaryOperator>(Val: MulUser)) {
2383 if (Add->getOpcode() == Instruction::Add) {
2384 Add->setOperand(i_nocapture: 1, Val_nocapture: CastedTaskLB);
2385 }
2386 }
2387 }
2388 }
2389 }
2390 }
2391
2392 FakeLB->replaceAllUsesWith(V: CastedLBVal);
2393 FakeUB->replaceAllUsesWith(V: CastedUBVal);
2394 FakeStep->replaceAllUsesWith(V: CastedStepVal);
2395 for (Instruction *I : llvm::reverse(C&: ToBeDeleted)) {
2396 I->eraseFromParent();
2397 }
2398 };
2399
2400 addOutlineInfo(OI: std::move(OI));
2401 Builder.SetInsertPoint(TheBB: TaskloopExitBB, IP: TaskloopExitBB->begin());
2402 return Builder.saveIP();
2403}
2404
2405OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
2406 const LocationDescription &Loc, InsertPointTy AllocaIP,
2407 BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
2408 SmallVector<DependData> Dependencies, bool Mergeable, Value *EventHandle,
2409 Value *Priority) {
2410
2411 if (!updateToLocation(Loc))
2412 return InsertPointTy();
2413
2414 uint32_t SrcLocStrSize;
2415 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2416 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2417 // The current basic block is split into four basic blocks. After outlining,
2418 // they will be mapped as follows:
2419 // ```
2420 // def current_fn() {
2421 // current_basic_block:
2422 // br label %task.exit
2423 // task.exit:
2424 // ; instructions after task
2425 // }
2426 // def outlined_fn() {
2427 // task.alloca:
2428 // br label %task.body
2429 // task.body:
2430 // ret void
2431 // }
2432 // ```
2433 BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
2434 BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
2435 BasicBlock *TaskAllocaBB =
2436 splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");
2437
2438 InsertPointTy TaskAllocaIP =
2439 InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
2440 InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
2441 if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
2442 return Err;
2443
2444 OutlineInfo OI;
2445 OI.EntryBB = TaskAllocaBB;
2446 OI.OuterAllocaBB = AllocaIP.getBlock();
2447 OI.ExitBB = TaskExitBB;
2448
2449 // Add the thread ID argument.
2450 SmallVector<Instruction *, 4> ToBeDeleted;
2451 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2452 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));
2453
2454 OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
2455 Mergeable, Priority, EventHandle, TaskAllocaBB,
2456 ToBeDeleted](Function &OutlinedFn) mutable {
2457 // Replace the Stale CI by appropriate RTL function call.
2458 assert(OutlinedFn.hasOneUse() &&
2459 "there must be a single user for the outlined function");
2460 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2461
2462 // HasShareds is true if any variables are captured in the outlined region,
2463 // false otherwise.
2464 bool HasShareds = StaleCI->arg_size() > 1;
2465 Builder.SetInsertPoint(StaleCI);
2466
2467 // Gather the arguments for emitting the runtime call for
2468 // @__kmpc_omp_task_alloc
2469 Function *TaskAllocFn =
2470 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2471
2472 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
2473 // call.
2474 Value *ThreadID = getOrCreateThreadID(Ident);
2475
2476 // Argument - `flags`
2477 // Task is tied iff (Flags & 1) == 1.
2478 // Task is untied iff (Flags & 1) == 0.
2479 // Task is final iff (Flags & 2) == 2.
2480 // Task is not final iff (Flags & 2) == 0.
2481 // Task is mergeable iff (Flags & 4) == 4.
2482 // Task is not mergeable iff (Flags & 4) == 0.
2483 // Task is priority iff (Flags & 32) == 32.
2484 // Task is not priority iff (Flags & 32) == 0.
2485 // TODO: Handle the other flags.
2486 Value *Flags = Builder.getInt32(C: Tied);
2487 if (Final) {
2488 Value *FinalFlag =
2489 Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
2490 Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
2491 }
2492
2493 if (Mergeable)
2494 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2495 if (Priority)
2496 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2497
2498 // Argument - `sizeof_kmp_task_t` (TaskSize)
2499 // Tasksize refers to the size in bytes of kmp_task_t data structure
2500 // including private vars accessed in task.
2501 // TODO: add kmp_task_t_with_privates (privates)
2502 Value *TaskSize = Builder.getInt64(
2503 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2504
2505 // Argument - `sizeof_shareds` (SharedsSize)
2506 // SharedsSize refers to the shareds array size in the kmp_task_t data
2507 // structure.
2508 Value *SharedsSize = Builder.getInt64(C: 0);
2509 if (HasShareds) {
2510 AllocaInst *ArgStructAlloca =
2511 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2512 assert(ArgStructAlloca &&
2513 "Unable to find the alloca instruction corresponding to arguments "
2514 "for extracted function");
2515 StructType *ArgStructType =
2516 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
2517 assert(ArgStructType && "Unable to find struct type corresponding to "
2518 "arguments for extracted function");
2519 SharedsSize =
2520 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
2521 }
2522 // Emit the @__kmpc_omp_task_alloc runtime call
2523 // The runtime call returns a pointer to an area where the task captured
2524 // variables must be copied before the task is run (TaskData)
2525 CallInst *TaskData = createRuntimeFunctionCall(
2526 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2527 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2528 /*task_func=*/&OutlinedFn});
2529
2530 // Emit detach clause initialization.
2531 // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
2532 // task_descriptor);
2533 if (EventHandle) {
2534 Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
2535 FnID: OMPRTL___kmpc_task_allow_completion_event);
2536 llvm::Value *EventVal =
2537 createRuntimeFunctionCall(Callee: TaskDetachFn, Args: {Ident, ThreadID, TaskData});
2538 llvm::Value *EventHandleAddr =
2539 Builder.CreatePointerBitCastOrAddrSpaceCast(V: EventHandle,
2540 DestTy: Builder.getPtrTy(AddrSpace: 0));
2541 EventVal = Builder.CreatePtrToInt(V: EventVal, DestTy: Builder.getInt64Ty());
2542 Builder.CreateStore(Val: EventVal, Ptr: EventHandleAddr);
2543 }
    // Copy the arguments for the outlined function.
2545 if (HasShareds) {
2546 Value *Shareds = StaleCI->getArgOperand(i: 1);
2547 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2548 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2549 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2550 Size: SharedsSize);
2551 }
2552
2553 if (Priority) {
2554 //
2555 // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
2556 // we populate the priority information into the "kmp_task_t" here
2557 //
2558 // The struct "kmp_task_t" definition is available in kmp.h
2559 // kmp_task_t = { shareds, routine, part_id, data1, data2 }
2560 // data2 is used for priority
2561 //
2562 Type *Int32Ty = Builder.getInt32Ty();
2563 Constant *Zero = ConstantInt::get(Ty: Int32Ty, V: 0);
2564 // kmp_task_t* => { ptr }
2565 Type *TaskPtr = StructType::get(elt1: VoidPtr);
2566 Value *TaskGEP =
2567 Builder.CreateInBoundsGEP(Ty: TaskPtr, Ptr: TaskData, IdxList: {Zero, Zero});
2568 // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
2569 Type *TaskStructType = StructType::get(
2570 elt1: VoidPtr, elts: VoidPtr, elts: Builder.getInt32Ty(), elts: VoidPtr, elts: VoidPtr);
2571 Value *PriorityData = Builder.CreateInBoundsGEP(
2572 Ty: TaskStructType, Ptr: TaskGEP, IdxList: {Zero, ConstantInt::get(Ty: Int32Ty, V: 4)});
2573 // kmp_cmplrdata_t => { ptr, ptr }
2574 Type *CmplrStructType = StructType::get(elt1: VoidPtr, elts: VoidPtr);
2575 Value *CmplrData = Builder.CreateInBoundsGEP(Ty: CmplrStructType,
2576 Ptr: PriorityData, IdxList: {Zero, Zero});
2577 Builder.CreateStore(Val: Priority, Ptr: CmplrData);
2578 }
2579
2580 Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);
2581
2582 // In the presence of the `if` clause, the following IR is generated:
2583 // ...
2584 // %data = call @__kmpc_omp_task_alloc(...)
2585 // br i1 %if_condition, label %then, label %else
2586 // then:
2587 // call @__kmpc_omp_task(...)
2588 // br label %exit
2589 // else:
2590 // ;; Wait for resolution of dependencies, if any, before
2591 // ;; beginning the task
2592 // call @__kmpc_omp_wait_deps(...)
2593 // call @__kmpc_omp_task_begin_if0(...)
2594 // call @outlined_fn(...)
2595 // call @__kmpc_omp_task_complete_if0(...)
2596 // br label %exit
2597 // exit:
2598 // ...
2599 if (IfCondition) {
2600 // `SplitBlockAndInsertIfThenElse` requires the block to have a
2601 // terminator.
2602 splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
2603 Instruction *IfTerminator =
2604 Builder.GetInsertPoint()->getParent()->getTerminator();
2605 Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
2606 Builder.SetInsertPoint(IfTerminator);
2607 SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
2608 ElseTerm: &ElseTI);
2609 Builder.SetInsertPoint(ElseTI);
2610
2611 if (Dependencies.size()) {
2612 Function *TaskWaitFn =
2613 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
2614 createRuntimeFunctionCall(
2615 Callee: TaskWaitFn,
2616 Args: {Ident, ThreadID, Builder.getInt32(C: Dependencies.size()), DepArray,
2617 ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2618 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2619 }
2620 Function *TaskBeginFn =
2621 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
2622 Function *TaskCompleteFn =
2623 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
2624 createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
2625 CallInst *CI = nullptr;
2626 if (HasShareds)
2627 CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
2628 else
2629 CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID});
2630 CI->setDebugLoc(StaleCI->getDebugLoc());
2631 createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
2632 Builder.SetInsertPoint(ThenTI);
2633 }
2634
2635 if (Dependencies.size()) {
2636 Function *TaskFn =
2637 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
2638 createRuntimeFunctionCall(
2639 Callee: TaskFn,
2640 Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
2641 DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2642 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2643
2644 } else {
2645 // Emit the @__kmpc_omp_task runtime call to spawn the task
2646 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
2647 createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
2648 }
2649
2650 StaleCI->eraseFromParent();
2651
2652 Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
2653 if (HasShareds) {
2654 LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2655 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2656 New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
2657 }
2658
2659 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
2660 I->eraseFromParent();
2661 };
2662
2663 addOutlineInfo(OI: std::move(OI));
2664 Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());
2665
2666 return Builder.saveIP();
2667}
2668
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
                                 InsertPointTy AllocaIP,
                                 BodyGenCallbackTy BodyGenCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);

  // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
  Function *TaskgroupFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
  createRuntimeFunctionCall(TaskgroupFn, {Ident, ThreadID});

  BasicBlock *TaskgroupExitBB =
      splitBB(Builder, /*CreateBranch=*/true, "taskgroup.exit");
  if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
    return Err;

  Builder.SetInsertPoint(TaskgroupExitBB);
  // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
  Function *EndTaskgroupFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
  createRuntimeFunctionCall(EndTaskgroupFn, {Ident, ThreadID});

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
  assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");

  if (!updateToLocation(Loc))
    return Loc.IP;

  FinalizationStack.push_back({FiniCB, OMPD_sections, IsCancellable});

  // Each section is emitted as a switch case
  // Each finalization callback is handled from clang.EmitOMPSectionDirective()
  // -> OMP.createSection() which generates the IR for each section
  // Iterate through all sections and emit a switch construct:
  // switch (IV) {
  // case 0:
  //   <SectionStmt[0]>;
  //   break;
  // ...
  // case <NumSection> - 1:
  //   <SectionStmt[<NumSection> - 1]>;
  //   break;
  // }
  // ...
  // section_loop.after:
  // <FiniCB>;
  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
    Builder.restoreIP(CodeGenIP);
    BasicBlock *Continue =
        splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
    Function *CurFn = Continue->getParent();
    SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);

    unsigned CaseNumber = 0;
    for (auto SectionCB : SectionCBs) {
      BasicBlock *CaseBB = BasicBlock::Create(
          M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
      SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
      Builder.SetInsertPoint(CaseBB);
      BranchInst *CaseEndBr = Builder.CreateBr(Continue);
      if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
                                                  CaseEndBr->getIterator()}))
        return Err;
      CaseNumber++;
    }
    // The switch created above is now the body BB's terminator; nothing may
    // follow the switch/case structure in this block.
    return Error::success();
  };
  // Loop body ends here
  // LowerBound, UpperBound, and Stride for createCanonicalLoop
  Type *I32Ty = Type::getInt32Ty(M.getContext());
  Value *LB = ConstantInt::get(I32Ty, 0);
  Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
  Value *ST = ConstantInt::get(I32Ty, 1);
  Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
      Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
  if (!LoopInfo)
    return LoopInfo.takeError();

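  // The canonical loop runs its induction variable over [0, NumSections);
  // statically worksharing it below hands each section case to one thread.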
  InsertPointOrErrorTy WsloopIP =
      applyStaticWorkshareLoop(Loc.DL, *LoopInfo, AllocaIP,
                               WorksharingLoopType::ForStaticLoop, !IsNowait);
  if (!WsloopIP)
    return WsloopIP.takeError();
  InsertPointTy AfterIP = *WsloopIP;

  BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
  assert(LoopFini && "Bad structure of static workshare loop finalization");

  // Apply the finalization callback in LoopAfterBB
  auto FiniInfo = FinalizationStack.pop_back_val();
  assert(FiniInfo.DK == OMPD_sections &&
         "Unexpected finalization stack state!");
  if (Error Err = FiniInfo.mergeFiniBB(Builder, LoopFini))
    return Err;

  return AfterIP;
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at the cancellation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from the cancellation
    // block to the exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    auto *CaseBB = Loc.IP.getBlock();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
    Instruction *I = Builder.CreateBr(ExitBB);
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  Directive OMPD = Directive::OMPD_sections;
  // Since we are using the finalization callback here, HasFinalize
  // and IsCancellable have to be true
  return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
                              /*Conditional*/ false, /*hasFinalize*/ true,
                              /*IsCancellable*/ true);
}

static OpenMPIRBuilder::InsertPointTy
getInsertPointAfterInstr(Instruction *I) {
  BasicBlock::iterator IT(I);
  IT++;
  return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
}

Value *OpenMPIRBuilder::getGPUThreadID() {
  return createRuntimeFunctionCall(
      getOrCreateRuntimeFunction(M,
                                 OMPRTL___kmpc_get_hardware_thread_id_in_block),
      {});
}

Value *OpenMPIRBuilder::getGPUWarpSize() {
  return createRuntimeFunctionCall(
      getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
}

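// The warp id is the hardware thread id divided by the warp size; since the
// warp size is a power of two (e.g. 32), that division is an arithmetic shift
// right by log2(warp size), i.e. tid >> 5 for 32-lane warps.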
Value *OpenMPIRBuilder::getNVPTXWarpID() {
  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
  return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
}

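// The lane id is the thread id modulo the warp size; with a power-of-two warp
// size this is a bitwise AND with (warp size - 1), i.e. tid & 31 for 32 lanes.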
Value *OpenMPIRBuilder::getNVPTXLaneID() {
  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
  assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
  unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
  return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
                           "nvptx_lane_id");
}

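// Reinterpret a value as another type of the same store size. Same-size types
// are bitcast, integer-to-integer conversions use a (signed) int cast, and the
// general case round-trips the value through a stack temporary: store as the
// source type, load back as the destination type.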
Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
                                        Type *ToType) {
  Type *FromType = From->getType();
  uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
  uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
  assert(FromSize > 0 && "From size must be greater than zero");
  assert(ToSize > 0 && "To size must be greater than zero");
  if (FromType == ToType)
    return From;
  if (FromSize == ToSize)
    return Builder.CreateBitCast(From, ToType);
  if (ToType->isIntegerTy() && FromType->isIntegerTy())
    return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
  InsertPointTy SaveIP = Builder.saveIP();
  Builder.restoreIP(AllocaIP);
  Value *CastItem = Builder.CreateAlloca(ToType);
  Builder.restoreIP(SaveIP);

  Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
      CastItem, Builder.getPtrTy(0));
  Builder.CreateStore(From, ValCastItem);
  return Builder.CreateLoad(ToType, CastItem);
}

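// The runtime shuffle entry points operate on exactly 32- or 64-bit payloads,
// so the element is first widened to the matching integer width and the
// result is cast back afterwards.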
Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
                                                     Value *Element,
                                                     Type *ElementType,
                                                     Value *Offset) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
  assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");

  // Cast all types to 32- or 64-bit values before calling shuffle routines.
  Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
  Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
  Value *WarpSize =
      Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
  Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
      Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
                : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
  Value *WarpSizeCast =
      Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
  Value *ShuffleCall =
      createRuntimeFunctionCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
  return castValueToType(AllocaIP, ShuffleCall, CastTy);
}

void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
                                      Value *DstAddr, Type *ElemType,
                                      Value *Offset, Type *ReductionArrayTy,
                                      bool IsByRefElem) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
  // Create the loop over the big sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *ElemPtr = DstAddr;
  Value *Ptr = SrcAddr;
  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
    if (Size < IntSize)
      continue;
    Type *IntType = Builder.getIntNTy(IntSize * 8);
    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        Ptr, Builder.getPtrTy(0), Ptr->getName() + ".ascast");
    Value *SrcAddrGEP =
        Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        ElemPtr, Builder.getPtrTy(0), ElemPtr->getName() + ".ascast");

    Function *CurFunc = Builder.GetInsertBlock()->getParent();
    if ((Size / IntSize) > 1) {
      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
          SrcAddrGEP, Builder.getPtrTy());
      BasicBlock *PreCondBB =
          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
      BasicBlock *CurrentBB = Builder.GetInsertBlock();
      emitBlock(PreCondBB, CurFunc);
      PHINode *PhiSrc =
          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
      PhiSrc->addIncoming(Ptr, CurrentBB);
      PHINode *PhiDest =
          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
      PhiDest->addIncoming(ElemPtr, CurrentBB);
      Ptr = PhiSrc;
      ElemPtr = PhiDest;
      Value *PtrDiff = Builder.CreatePtrDiff(
          Builder.getInt8Ty(), PtrEnd,
          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr,
                                                      Builder.getPtrTy()));
      Builder.CreateCondBr(
          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
          ExitBB);
      emitBlock(ThenBB, CurFunc);
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP,
          Builder.CreateAlignedLoad(
              IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
          IntType, Offset);
      Builder.CreateAlignedStore(Res, ElemPtr,
                                 M.getDataLayout().getPrefTypeAlign(ElemType));
      Value *LocalPtr =
          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
      Value *LocalElemPtr =
          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
      PhiSrc->addIncoming(LocalPtr, ThenBB);
      PhiDest->addIncoming(LocalElemPtr, ThenBB);
      emitBranch(PreCondBB);
      emitBlock(ExitBB, CurFunc);
    } else {
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
                                         Res->getType()->getScalarSizeInBits())
        Res = Builder.CreateTrunc(Res, ElemType);
      Builder.CreateStore(Res, ElemPtr);
      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
      ElemPtr =
          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
    }
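    // Whatever is left over after this width is handled by a later iteration
    // with the next smaller integer width.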
    Size = Size % IntSize;
  }
}

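// Copy a reduce list element by element. For RemoteLaneToThread each element
// is shuffled in from a remote lane into a freshly allocated stack temporary
// (and the destination list is redirected to it); for ThreadCopy the element
// is copied into the destination list's existing storage.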
Error OpenMPIRBuilder::emitReductionListCopy(
    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
    ArrayRef<bool> IsByRef, CopyOptionsTy CopyOptions) {
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;

  // Iterate, element by element, through the source Reduce list and
  // make a copy.
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *SrcElementAddr = nullptr;
    AllocaInst *DestAlloca = nullptr;
    Value *DestElementAddr = nullptr;
    Value *DestElementPtrAddr = nullptr;
    // Should we shuffle in an element from a remote lane?
    bool ShuffleInElement = false;
    // Set to true to update the pointer in the dest Reduce list to a
    // newly created element.
    bool UpdateDestListPtr = false;

    // Step 1.1: Get the address for the src element in the Reduce list.
    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
        ReductionArrayTy, SrcBase,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);

    // Step 1.2: Create a temporary to store the element in the destination
    // Reduce list.
    DestElementPtrAddr = Builder.CreateInBoundsGEP(
        ReductionArrayTy, DestBase,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    bool IsByRefElem = (!IsByRef.empty() && IsByRef[En.index()]);
    switch (Action) {
    case CopyAction::RemoteLaneToThread: {
      InsertPointTy CurIP = Builder.saveIP();
      Builder.restoreIP(AllocaIP);

      Type *DestAllocaType =
          IsByRefElem ? RI.ByRefAllocatedType : RI.ElementType;
      DestAlloca = Builder.CreateAlloca(DestAllocaType, nullptr,
                                        ".omp.reduction.element");
      DestAlloca->setAlignment(
          M.getDataLayout().getPrefTypeAlign(DestAllocaType));
      DestElementAddr = DestAlloca;
      DestElementAddr =
          Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
                                      DestElementAddr->getName() + ".ascast");
      Builder.restoreIP(CurIP);
      ShuffleInElement = true;
      UpdateDestListPtr = true;
      break;
    }
    case CopyAction::ThreadCopy: {
      DestElementAddr =
          Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
      break;
    }
    }

    // Now that all active lanes have read the element in the
    // Reduce list, shuffle over the value from the remote lane.
    if (ShuffleInElement) {
      Type *ShuffleType = RI.ElementType;
      Value *ShuffleSrcAddr = SrcElementAddr;
      Value *ShuffleDestAddr = DestElementAddr;
      AllocaInst *LocalStorage = nullptr;

      if (IsByRefElem) {
        assert(RI.ByRefElementType && "Expected by-ref element type to be set");
        assert(RI.ByRefAllocatedType &&
               "Expected by-ref allocated type to be set");
        // For by-ref reductions, we need to copy from the remote lane the
        // actual value of the partial reduction computed by that remote lane;
        // rather than, for example, a pointer to that data or, even worse, a
        // pointer to the descriptor of the by-ref reduction element.
        ShuffleType = RI.ByRefElementType;

        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ShuffleSrcAddr, ShuffleSrcAddr);

        if (!GenResult)
          return GenResult.takeError();

        ShuffleSrcAddr = Builder.CreateLoad(Builder.getPtrTy(), ShuffleSrcAddr);

        {
          InsertPointTy OldIP = Builder.saveIP();
          Builder.restoreIP(AllocaIP);

          LocalStorage = Builder.CreateAlloca(ShuffleType);
          Builder.restoreIP(OldIP);
          ShuffleDestAddr = LocalStorage;
        }
      }

      shuffleAndStore(AllocaIP, ShuffleSrcAddr, ShuffleDestAddr, ShuffleType,
                      RemoteLaneOffset, ReductionArrayTy, IsByRefElem);

      if (IsByRefElem) {
        Value *GEP;
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(),
                             Builder.CreatePointerBitCastOrAddrSpaceCast(
                                 DestAlloca, Builder.getPtrTy(), ".ascast"),
                             GEP);

        if (!GenResult)
          return GenResult.takeError();

        Builder.CreateStore(Builder.CreatePointerBitCastOrAddrSpaceCast(
                                LocalStorage, Builder.getPtrTy(), ".ascast"),
                            GEP);
      }
    } else {
      switch (RI.EvaluationKind) {
      case EvalKind::Scalar: {
        Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
        // Store the source element value to the dest element address.
        Builder.CreateStore(Elem, DestElementAddr);
        break;
      }
      case EvalKind::Complex: {
        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, SrcElementAddr, 0, 0, ".realp");
        Value *SrcReal = Builder.CreateLoad(
            RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
        Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
        Value *SrcImg = Builder.CreateLoad(
            RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, DestElementAddr, 0, 0, ".realp");
        Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, DestElementAddr, 0, 1, ".imagp");
        Builder.CreateStore(SrcReal, DestRealPtr);
        Builder.CreateStore(SrcImg, DestImgPtr);
        break;
      }
      case EvalKind::Aggregate: {
        Value *SizeVal = Builder.getInt64(
            M.getDataLayout().getTypeStoreSize(RI.ElementType));
        Builder.CreateMemCpy(
            DestElementAddr,
            M.getDataLayout().getPrefTypeAlign(RI.ElementType), SrcElementAddr,
            M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal,
            false);
        break;
      }
      }
    }

    // Step 3.1: Modify reference in dest Reduce list as needed.
    // Modifying the reference in Reduce list to point to the newly
    // created element. The element is live in the current function
    // scope and that of functions it invokes (i.e., reduce_function).
    // RemoteReduceData[i] = (void*)&RemoteElem
    if (UpdateDestListPtr) {
      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          DestElementAddr, Builder.getPtrTy(),
          DestElementAddr->getName() + ".ascast");
      Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
    }
  }

  return Error::success();
}

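// Build the inter-warp copy helper. A sketch of what is generated:
//   void _omp_reduction_inter_warp_copy_func(ptr %reduce_list, i32 %num_warps)
// For every reduce element, lane 0 of each warp publishes its value into a
// shared-memory transfer array indexed by warp id; after a barrier, the first
// %num_warps threads of warp 0 read it back into their reduce list.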
Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  InsertPointTy SavedIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
      /* IsVarArg */ false);
  Function *WcFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_inter_warp_copy_func", &M);
  WcFunc->setAttributes(FuncAttrs);
  WcFunc->addParamAttr(0, Attribute::NoUndef);
  WcFunc->addParamAttr(1, Attribute::NoUndef);
  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
  Builder.SetInsertPoint(EntryBB);

  // ReduceList: thread local Reduce list.
  // At the stage of the computation when this function is called, partially
  // aggregated values reside in the first lane of every active warp.
  Argument *ReduceListArg = WcFunc->getArg(0);
  // NumWarps: number of warps active in the parallel region. This could
  // be smaller than 32 (max warps in a CTA) for partial block reduction.
  Argument *NumWarpsArg = WcFunc->getArg(1);

  // This array is used as a medium to transfer, one reduce element at a time,
  // the data from the first lane of every warp to lanes in the first warp
  // in order to perform the final step of a reduction in a parallel region
  // (reduction across warps). The array is placed in NVPTX __shared__ memory
  // for reduced latency, as well as to have a distinct copy for concurrently
  // executing target regions. The array is declared with common linkage so
  // as to be shared across compilation units.
  StringRef TransferMediumName =
      "__openmp_nvptx_data_transfer_temporary_storage";
  GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
  ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
  if (!TransferMedium) {
    TransferMedium = new GlobalVariable(
        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
        UndefValue::get(ArrayTy), TransferMediumName,
        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
        /*AddressSpace=*/3);
  }

  // Get the CUDA thread id of the current OpenMP thread on the GPU.
  Value *GPUThreadID = getGPUThreadID();
  // nvptx_lane_id = nvptx_id % warpsize
  Value *LaneID = getNVPTXLaneID();
  // nvptx_warp_id = nvptx_id / warpsize
  Value *WarpID = getNVPTXWarpID();

  InsertPointTy AllocaIP =
      InsertPointTy(Builder.GetInsertBlock(),
                    Builder.GetInsertBlock()->getFirstInsertionPt());
  Type *Arg0Type = ReduceListArg->getType();
  Type *Arg1Type = NumWarpsArg->getType();
  Builder.restoreIP(AllocaIP);
  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
      Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
  AllocaInst *NumWarpsAlloca =
      Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      NumWarpsAlloca, Builder.getPtrTy(0),
      NumWarpsAlloca->getName() + ".ascast");
  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
  Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
  AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
  InsertPointTy CodeGenIP =
      getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
  Builder.restoreIP(CodeGenIP);

  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);

  for (auto En : enumerate(ReductionInfos)) {
    //
    // Warp master copies reduce element to transfer medium in __shared__
    // memory.
    //
    const ReductionInfo &RI = En.value();
    bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(
        IsByRefElem ? RI.ByRefElementType : RI.ElementType);
    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
      Type *CType = Builder.getIntNTy(TySize * 8);

      unsigned NumIters = RealTySize / TySize;
      if (NumIters == 0)
        continue;
      Value *Cnt = nullptr;
      Value *CntAddr = nullptr;
      BasicBlock *PrecondBB = nullptr;
      BasicBlock *ExitBB = nullptr;
      if (NumIters > 1) {
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(AllocaIP);
        CntAddr =
            Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");

        CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
                                              CntAddr->getName() + ".ascast");
        Builder.restoreIP(CodeGenIP);
        Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
                            CntAddr,
                            /*Volatile=*/false);
        PrecondBB = BasicBlock::Create(Ctx, "precond");
        ExitBB = BasicBlock::Create(Ctx, "exit");
        BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
        emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
        Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
                                 /*Volatile=*/false);
        Value *Cmp = Builder.CreateICmpULT(
            Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
        Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
        emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
      }

      // kmpc_barrier.
      InsertPointOrErrorTy BarrierIP1 =
          createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                        omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP1)
        return BarrierIP1.takeError();
      BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
      BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
      BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");

      // if (lane_id == 0)
      Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
      Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
      emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());

      // Reduce element = LocalReduceList[i]
      auto *RedListArrayTy =
          ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
      Type *IndexTy = Builder.getIndexTy(
          M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
      Value *ElemPtrPtr =
          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
                                    {ConstantInt::get(IndexTy, 0),
                                     ConstantInt::get(IndexTy, En.index())});
      // elemptr = ((CopyType*)(elemptrptr)) + I
      Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);

      if (IsByRefElem) {
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
      }

      if (NumIters > 1)
        ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);

      // Get pointer to location in transfer medium.
      // MediumPtr = &medium[warp_id]
      Value *MediumPtr = Builder.CreateInBoundsGEP(
          ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
      // elem = *elemptr
      // *MediumPtr = elem
      Value *Elem = Builder.CreateLoad(CType, ElemPtr);
      // Store the source element value to the dest element address.
      Builder.CreateStore(Elem, MediumPtr,
                          /*IsVolatile*/ true);
      Builder.CreateBr(MergeBB);

      // else
      emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(MergeBB);

      // endif
      emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
      InsertPointOrErrorTy BarrierIP2 =
          createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                        omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP2)
        return BarrierIP2.takeError();

      // Warp 0 copies reduce element from transfer medium
      BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
      BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
      BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");

      Value *NumWarpsVal =
          Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
      // Up to 32 threads in warp 0 are active.
      Value *IsActiveThread =
          Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
      Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);

      emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());

      // SrcMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
          ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
      // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
      Value *TargetElemPtrPtr =
          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
                                    {ConstantInt::get(IndexTy, 0),
                                     ConstantInt::get(IndexTy, En.index())});
      Value *TargetElemPtrVal =
          Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
      Value *TargetElemPtr = TargetElemPtrVal;

      if (IsByRefElem) {
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), TargetElemPtr, TargetElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        TargetElemPtr = Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtr);
      }

      if (NumIters > 1)
        TargetElemPtr =
            Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);

      // *TargetElemPtr = SrcMediumVal;
      Value *SrcMediumValue =
          Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
      Builder.CreateStore(SrcMediumValue, TargetElemPtr);
      Builder.CreateBr(W0MergeBB);

      emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(W0MergeBB);

      emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());

      if (NumIters > 1) {
        Cnt = Builder.CreateNSWAdd(
            Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
        Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);

        auto *CurFn = Builder.GetInsertBlock()->getParent();
        emitBranch(PrecondBB);
        emitBlock(ExitBB, CurFn);
      }
      RealTySize %= TySize;
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(SavedIP);

  return WcFunc;
}

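// Build the shuffle-and-reduce helper. A sketch of what is generated:
//   void _omp_reduction_shuffle_and_reduce_func(ptr %reduce_list,
//                                               i16 %lane_id,
//                                               i16 %remote_lane_offset,
//                                               i16 %algo_version)
// It copies the remote lane's reduce list onto the stack via warp shuffles
// and then, depending on the algorithm version and the lane's position,
// either combines it through ReduceFn or copies it back over the local list.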
Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy =
      FunctionType::get(Builder.getVoidTy(),
                        {Builder.getPtrTy(), Builder.getInt16Ty(),
                         Builder.getInt16Ty(), Builder.getInt16Ty()},
                        /* IsVarArg */ false);
  Function *SarFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_shuffle_and_reduce_func", &M);
  SarFunc->setAttributes(FuncAttrs);
  SarFunc->addParamAttr(0, Attribute::NoUndef);
  SarFunc->addParamAttr(1, Attribute::NoUndef);
  SarFunc->addParamAttr(2, Attribute::NoUndef);
  SarFunc->addParamAttr(3, Attribute::NoUndef);
  SarFunc->addParamAttr(1, Attribute::SExt);
  SarFunc->addParamAttr(2, Attribute::SExt);
  SarFunc->addParamAttr(3, Attribute::SExt);
  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
  Builder.SetInsertPoint(EntryBB);

  // Thread local Reduce list used to host the values of data to be reduced.
  Argument *ReduceListArg = SarFunc->getArg(0);
  // Current lane id; could be logical.
  Argument *LaneIDArg = SarFunc->getArg(1);
  // Offset of the remote source lane relative to the current lane.
  Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
  // Algorithm version. This is expected to be known at compile time.
  Argument *AlgoVerArg = SarFunc->getArg(3);

  Type *ReduceListArgType = ReduceListArg->getType();
  Type *LaneIDArgType = LaneIDArg->getType();
  Type *LaneIDArgPtrType = Builder.getPtrTy(0);
  Value *ReduceListAlloca = Builder.CreateAlloca(
      ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
  Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
                                             LaneIDArg->getName() + ".addr");
  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
      LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
  Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
                                              AlgoVerArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // Create a local thread-private variable to host the Reduce list
  // from a remote lane.
  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
      RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");

  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListAlloca, ReduceListArgType,
      ReduceListAlloca->getName() + ".ascast");
  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteLaneOffsetAlloca, LaneIDArgPtrType,
      RemoteLaneOffsetAlloca->getName() + ".ascast");
  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteReductionListAlloca, Builder.getPtrTy(),
      RemoteReductionListAlloca->getName() + ".ascast");

  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
  Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
  Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
  Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);

  Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
  Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
  Value *RemoteLaneOffset =
      Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
  Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);

  InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);

  // This loop iterates through the list of reduce elements and copies,
  // element by element, from a remote lane in the warp to RemoteReduceList,
  // hosted on the thread's stack.
  Error EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
      ReduceList, RemoteListAddrCast, IsByRef,
      {RemoteLaneOffset, nullptr, nullptr});

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  // The action to be performed on the Remote Reduce list depends on the
  // algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread local variable Reduce list is mutated in place to host the
  // reduced data, which is the aggregated value produced from local and
  // remote lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first conjunction evaluates to true, making
  // the entire predicate true during compile time.
  // When AlgoVer==1, only the second part of the second conjunction needs to
  // be evaluated at runtime; the other conjunctions evaluate to false at
  // compile time.
  // When AlgoVer==2, only the second part of the third conjunction needs to
  // be evaluated at runtime; the other conjunctions evaluate to false at
  // compile time.
  Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
  Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
  Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
  Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
  Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
  Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
  Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
  Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
  Value *RemoteOffsetComp =
      Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
  Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
  Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
  Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);

  BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
  BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
  BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");

  Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
  emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceList, Builder.getPtrTy());
  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteListAddrCast, Builder.getPtrTy());
  createRuntimeFunctionCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateBr(MergeBB);

  emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(MergeBB);

  emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());

  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
  // Reduce list.
  Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
  Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);

  BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
  BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
  BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
  Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);

  emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());

  EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, CopyAction::ThreadCopy, RedListArrayTy, ReductionInfos,
      RemoteListAddrCast, ReduceList, IsByRef);

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  Builder.CreateBr(CpyMergeBB);

  emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(CpyMergeBB);

  emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());

  Builder.CreateRetVoid();

  return SarFunc;
}

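// Build the list-to-global copy helper, which publishes a thread's reduce
// list into the team reduction buffer:
//   void _omp_reduction_list_to_global_copy_func(ptr %buffer, i32 %idx,
//                                                ptr %reduce_list)
// Element i of %reduce_list is stored into field i of Buffer[Idx].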
Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGCFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_list_to_global_copy_func", &M);
  LtGCFunc->setAttributes(FuncAttrs);
  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
  LtGCFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGCFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGCFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGCFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Value *BufferArgVal =
      Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);

    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
    Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Value *TargetElement;

      if (IsByRef.empty() || !IsByRef[En.index()]) {
        TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
      } else {
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenResult)
          return GenResult.takeError();

        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
        TargetElement = Builder.CreateLoad(RI.ByRefElementType, ElemPtr);
      }

      Builder.CreateStore(TargetElement, GlobVal);
      break;
    }
    case EvalKind::Complex: {
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 0, ".realp");
      Value *SrcReal = Builder.CreateLoad(
          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 1, ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobVal, 0, 0, ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobVal, 0, 1, ".imagp");
      Builder.CreateStore(SrcReal, DestRealPtr);
      Builder.CreateStore(SrcImg, DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      Value *SizeVal =
          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
      Builder.CreateMemCpy(
          GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
          M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGCFunc;
}

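// Build the list-to-global reduce helper:
//   void _omp_reduction_list_to_global_reduce_func(ptr %buffer, i32 %idx,
//                                                  ptr %reduce_list)
// It materializes a reduce list whose elements point directly into
// Buffer[Idx] (through the by-ref descriptor where applicable) and combines
// %reduce_list into it via ReduceFn.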
Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGRFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_list_to_global_reduce_func", &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
  LtGRFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  auto *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");

  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LocalReduceList, Builder.getPtrTy(),
      LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceListAddrCast,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    // Global = Buffer.VD[Idx];
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      Value *ByRefDataPtr;

      InsertPointOrErrorTy GenResult =
          RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr);

      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
      Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
    } else {
      Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
    }
  }

  // Call reduce_function(GlobalReduceList, ReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  createRuntimeFunctionCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGRFunc;
}

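// Build the global-to-list copy helper, the inverse of the list-to-global
// copy above: it reads the fields of Buffer[Idx] back into the thread's
// reduce list.
//   void _omp_reduction_global_to_list_copy_func(ptr %buffer, i32 %idx,
//                                                ptr %reduce_list)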
Expected<Function *> OpenMPIRBuilder::emitGlobalToListCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *GtLCFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_global_to_list_copy_func", &M);
  GtLCFunc->setAttributes(FuncAttrs);
  GtLCFunc->addParamAttr(0, Attribute::NoUndef);
  GtLCFunc->addParamAttr(1, Attribute::NoUndef);
  GtLCFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLCFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLCFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLCFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Type *ElemType = RI.ElementType;

      if (!IsByRef.empty() && IsByRef[En.index()]) {
        ElemType = RI.ByRefElementType;
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenResult)
          return GenResult.takeError();

        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
      }

      Value *TargetElement = Builder.CreateLoad(ElemType, GlobValPtr);
      Builder.CreateStore(TargetElement, ElemPtr);
      break;
    }
    case EvalKind::Complex: {
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobValPtr, 0, 0, ".realp");
      Value *SrcReal = Builder.CreateLoad(
          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobValPtr, 0, 1, ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 0, ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 1, ".imagp");
      Builder.CreateStore(SrcReal, DestRealPtr);
      Builder.CreateStore(SrcImg, DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      Value *SizeVal =
          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
      Builder.CreateMemCpy(
          ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
          GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
          SizeVal, false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return GtLCFunc;
}

3938Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
3939 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3940 Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3941 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3942 LLVMContext &Ctx = M.getContext();
3943 auto *FuncTy = FunctionType::get(
3944 Result: Builder.getVoidTy(),
3945 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3946 /* IsVarArg */ isVarArg: false);
3947 Function *GtLRFunc =
3948 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3949 N: "_omp_reduction_global_to_list_reduce_func", M: &M);
3950 GtLRFunc->setAttributes(FuncAttrs);
3951 GtLRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3952 GtLRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3953 GtLRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3954
3955 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLRFunc);
3956 Builder.SetInsertPoint(EntryBlock);
3957
3958 // Buffer: global reduction buffer.
3959 Argument *BufferArg = GtLRFunc->getArg(i: 0);
3960 // Idx: index of the buffer.
3961 Argument *IdxArg = GtLRFunc->getArg(i: 1);
3962 // ReduceList: thread local Reduce list.
3963 Argument *ReduceListArg = GtLRFunc->getArg(i: 2);
3964
3965 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3966 Name: BufferArg->getName() + ".addr");
3967 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3968 Name: IdxArg->getName() + ".addr");
3969 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3970 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3971 ArrayType *RedListArrayTy =
3972 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3973
3974 // 1. Build a list of reduction variables.
3975 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3976 Value *LocalReduceList =
3977 Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
3978
3979 InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
3980
3981 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3982 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3983 Name: BufferArgAlloca->getName() + ".ascast");
3984 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3985 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3986 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3987 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3988 Name: ReduceListArgAlloca->getName() + ".ascast");
3989 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3990 V: LocalReduceList, DestTy: Builder.getPtrTy(),
3991 Name: LocalReduceList->getName() + ".ascast");
3992
3993 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3994 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3995 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3996
3997 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3998 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3999 Type *IndexTy = Builder.getIndexTy(
4000 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4001 for (auto En : enumerate(First&: ReductionInfos)) {
4002 const ReductionInfo &RI = En.value();
4003 Value *ByRefAlloc;
4004
4005 if (!IsByRef.empty() && IsByRef[En.index()]) {
4006 InsertPointTy OldIP = Builder.saveIP();
4007 Builder.restoreIP(IP: AllocaIP);
4008
4009 ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
4010 ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
4011 V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");
4012
4013 Builder.restoreIP(IP: OldIP);
4014 }
4015
4016 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
4017 Ty: RedListArrayTy, Ptr: ReductionList,
4018 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4019 // Global = Buffer.VD[Idx];
4020 Value *BufferVD =
4021 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
4022 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
4023 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
4024
4025 if (!IsByRef.empty() && IsByRef[En.index()]) {
4026 Value *ByRefDataPtr;
4027 InsertPointOrErrorTy GenResult =
4028 RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr);
4029 if (!GenResult)
4030 return GenResult.takeError();
4031
4032 Builder.CreateStore(Val: GlobValPtr, Ptr: ByRefDataPtr);
4033 Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
4034 } else {
4035 Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
4036 }
4037 }
4038
4039 // Call reduce_function(ReduceList, GlobalReduceList)
4040 Value *ReduceList =
4041 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
4042 createRuntimeFunctionCall(Callee: ReduceFn, Args: {ReduceList, ReductionList})
4043 ->addFnAttr(Kind: Attribute::NoUnwind);
4044 Builder.CreateRetVoid();
4045 Builder.restoreIP(IP: OldIP);
4046 return GtLRFunc;
4047}
4048
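// For illustration: with a "." separator this produces a name of the form
// "<ReducerName>.omp.reduction.reduction_func". The exact separator comes from
// createPlatformSpecificName and is target-dependent, so the suffix shown here
// is an example rather than a guarantee.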
4049std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
4050 std::string Suffix =
4051 createPlatformSpecificName(Parts: {"omp", "reduction", "reduction_func"});
4052 return (Name + Suffix).str();
4053}
4054
4055Expected<Function *> OpenMPIRBuilder::createReductionFunction(
4056 StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
4057 ArrayRef<bool> IsByRef, ReductionGenCBKind ReductionGenCBKind,
4058 AttributeList FuncAttrs) {
4059 auto *FuncTy = FunctionType::get(Result: Builder.getVoidTy(),
4060 Params: {Builder.getPtrTy(), Builder.getPtrTy()},
4061 /* IsVarArg */ isVarArg: false);
4062 std::string Name = getReductionFuncName(Name: ReducerName);
4063 Function *ReductionFunc =
4064 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage, N: Name, M: &M);
4065 ReductionFunc->setAttributes(FuncAttrs);
4066 ReductionFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
4067 ReductionFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
4068 BasicBlock *EntryBB =
4069 BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: ReductionFunc);
4070 Builder.SetInsertPoint(EntryBB);
4071
4072 // Need to alloca memory here and deal with the pointers before getting
4073 // LHS/RHS pointers out
4074 Value *LHSArrayPtr = nullptr;
4075 Value *RHSArrayPtr = nullptr;
4076 Argument *Arg0 = ReductionFunc->getArg(i: 0);
4077 Argument *Arg1 = ReductionFunc->getArg(i: 1);
4078 Type *Arg0Type = Arg0->getType();
4079 Type *Arg1Type = Arg1->getType();
4080
4081 Value *LHSAlloca =
4082 Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
4083 Value *RHSAlloca =
4084 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
4085 Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4086 V: LHSAlloca, DestTy: Arg0Type, Name: LHSAlloca->getName() + ".ascast");
4087 Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4088 V: RHSAlloca, DestTy: Arg1Type, Name: RHSAlloca->getName() + ".ascast");
4089 Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
4090 Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
4091 LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
4092 RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
4093
4094 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
4095 Type *IndexTy = Builder.getIndexTy(
4096 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4097 SmallVector<Value *> LHSPtrs, RHSPtrs;
4098 for (auto En : enumerate(First&: ReductionInfos)) {
4099 const ReductionInfo &RI = En.value();
4100 Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
4101 Ty: RedArrayTy, Ptr: RHSArrayPtr,
4102 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4103 Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
4104 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4105 V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType(),
4106 Name: RHSI8Ptr->getName() + ".ascast");
4107
4108 Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
4109 Ty: RedArrayTy, Ptr: LHSArrayPtr,
4110 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4111 Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
4112 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4113 V: LHSI8Ptr, DestTy: RI.Variable->getType(), Name: LHSI8Ptr->getName() + ".ascast");
4114
4115 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4116 LHSPtrs.emplace_back(Args&: LHSPtr);
4117 RHSPtrs.emplace_back(Args&: RHSPtr);
4118 } else {
4119 Value *LHS = LHSPtr;
4120 Value *RHS = RHSPtr;
4121
4122 if (!IsByRef.empty() && !IsByRef[En.index()]) {
4123 LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
4124 RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
4125 }
4126
4127 Value *Reduced;
4128 InsertPointOrErrorTy AfterIP =
4129 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
4130 if (!AfterIP)
4131 return AfterIP.takeError();
4132 if (!Builder.GetInsertBlock())
4133 return ReductionFunc;
4134
4135 Builder.restoreIP(IP: *AfterIP);
4136
4137 if (!IsByRef.empty() && !IsByRef[En.index()])
4138 Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
4139 }
4140 }
4141
4142 if (ReductionGenCBKind == ReductionGenCBKind::Clang)
4143 for (auto En : enumerate(First&: ReductionInfos)) {
4144 unsigned Index = En.index();
4145 const ReductionInfo &RI = En.value();
4146 Value *LHSFixupPtr, *RHSFixupPtr;
4147 Builder.restoreIP(IP: RI.ReductionGenClang(
4148 Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
4149
4150 // Fix the callback code generated to use the correct Values for the LHS
4151 // and RHS.
4152 LHSFixupPtr->replaceUsesWithIf(
4153 New: LHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
4154 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4155 ReductionFunc;
4156 });
4157 RHSFixupPtr->replaceUsesWithIf(
4158 New: RHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
4159 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4160 ReductionFunc;
4161 });
4162 }
4163
4164 Builder.CreateRetVoid();
4165 return ReductionFunc;
4166}
4167
4168static void
4169checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4170 bool IsGPU) {
4171 for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
4172 (void)RI;
4173 assert(RI.Variable && "expected non-null variable");
4174 assert(RI.PrivateVariable && "expected non-null private variable");
4175 assert((RI.ReductionGen || RI.ReductionGenClang) &&
4176 "expected non-null reduction generator callback");
4177 if (!IsGPU) {
4178 assert(
4179 RI.Variable->getType() == RI.PrivateVariable->getType() &&
4180 "expected variables and their private equivalents to have the same "
4181 "type");
4182 }
4183 assert(RI.Variable->getType()->isPointerTy() &&
4184 "expected variables to be pointers");
4185 }
4186}
4187
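// A sketch of the code emitted for the GPU (non-teams, nowait) case, with
// the runtime arguments abbreviated:
//
//   void *RedList[<n>] = {<priv_0>, ..., <priv_n-1>};
//   int res = __kmpc_nvptx_parallel_reduce_nowait_v2(
//       loc, reduce_data_size, RedList, shuffle_reduce_func,
//       interwarp_copy_func);
//   if (res == 1) {
//     <combine RedList into the original list items>;
//   }
//
// The teams case additionally threads the list-to-global/global-to-list
// copy and reduce helpers through __kmpc_nvptx_teams_reduce_nowait_v2.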
4188OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
4189 const LocationDescription &Loc, InsertPointTy AllocaIP,
4190 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
4191 ArrayRef<bool> IsByRef, bool IsNoWait, bool IsTeamsReduction,
4192 ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
4193 unsigned ReductionBufNum, Value *SrcLocInfo) {
4194 if (!updateToLocation(Loc))
4195 return InsertPointTy();
4196 Builder.restoreIP(IP: CodeGenIP);
4197 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
4198 LLVMContext &Ctx = M.getContext();
4199
4200 // Source location for the ident struct
4201 if (!SrcLocInfo) {
4202 uint32_t SrcLocStrSize;
4203 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4204 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4205 }
4206
4207 if (ReductionInfos.size() == 0)
4208 return Builder.saveIP();
4209
4210 BasicBlock *ContinuationBlock = nullptr;
4211 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
4212 // Copied code from createReductions
4213 BasicBlock *InsertBlock = Loc.IP.getBlock();
4214 ContinuationBlock =
4215 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4216 InsertBlock->getTerminator()->eraseFromParent();
4217 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4218 }
4219
4220 Function *CurFunc = Builder.GetInsertBlock()->getParent();
4221 AttributeList FuncAttrs;
4222 AttrBuilder AttrBldr(Ctx);
4223 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
4224 AttrBldr.addAttribute(A: Attr);
4225 AttrBldr.removeAttribute(Val: Attribute::OptimizeNone);
4226 FuncAttrs = FuncAttrs.addFnAttributes(C&: Ctx, B: AttrBldr);
4227
4228 CodeGenIP = Builder.saveIP();
4229 Expected<Function *> ReductionResult = createReductionFunction(
4230 ReducerName: Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos, IsByRef,
4231 ReductionGenCBKind, FuncAttrs);
4232 if (!ReductionResult)
4233 return ReductionResult.takeError();
4234 Function *ReductionFunc = *ReductionResult;
4235 Builder.restoreIP(IP: CodeGenIP);
4236
4237 // Set the grid value in the config; it is needed during lowering later on.
4238 if (GridValue.has_value())
4239 Config.setGridValue(GridValue.value());
4240 else
4241 Config.setGridValue(getGridValue(T, Kernel: ReductionFunc));
4242
4243 // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
4244 // RedList, shuffle_reduce_func, interwarp_copy_func);
4245 // or
4246 // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
4247 Value *Res;
4248
4249 // 1. Build a list of reduction variables.
4250 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4251 auto Size = ReductionInfos.size();
4252 Type *PtrTy = PointerType::get(C&: Ctx, AddressSpace: Config.getDefaultTargetAS());
4253 Type *FuncPtrTy =
4254 Builder.getPtrTy(AddrSpace: M.getDataLayout().getProgramAddressSpace());
4255 Type *RedArrayTy = ArrayType::get(ElementType: PtrTy, NumElements: Size);
4256 CodeGenIP = Builder.saveIP();
4257 Builder.restoreIP(IP: AllocaIP);
4258 Value *ReductionListAlloca =
4259 Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4260 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
4261 V: ReductionListAlloca, DestTy: PtrTy, Name: ReductionListAlloca->getName() + ".ascast");
4262 Builder.restoreIP(IP: CodeGenIP);
4263 Type *IndexTy = Builder.getIndexTy(
4264 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4265 for (auto En : enumerate(First&: ReductionInfos)) {
4266 const ReductionInfo &RI = En.value();
4267 Value *ElemPtr = Builder.CreateInBoundsGEP(
4268 Ty: RedArrayTy, Ptr: ReductionList,
4269 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4270
4271 Value *PrivateVar = RI.PrivateVariable;
4272 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
4273 if (IsByRefElem)
4274 PrivateVar = Builder.CreateLoad(Ty: RI.ElementType, Ptr: PrivateVar);
4275
4276 Value *CastElem =
4277 Builder.CreatePointerBitCastOrAddrSpaceCast(V: PrivateVar, DestTy: PtrTy);
4278 Builder.CreateStore(Val: CastElem, Ptr: ElemPtr);
4279 }
4280 CodeGenIP = Builder.saveIP();
4281 Expected<Function *> SarFunc = emitShuffleAndReduceFunction(
4282 ReductionInfos, ReduceFn: ReductionFunc, FuncAttrs, IsByRef);
4283
4284 if (!SarFunc)
4285 return SarFunc.takeError();
4286
4287 Expected<Function *> CopyResult =
4288 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs, IsByRef);
4289 if (!CopyResult)
4290 return CopyResult.takeError();
4291 Function *WcFunc = *CopyResult;
4292 Builder.restoreIP(IP: CodeGenIP);
4293
4294 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(V: ReductionList, DestTy: PtrTy);
4295
4296 unsigned MaxDataSize = 0;
4297 SmallVector<Type *> ReductionTypeArgs;
4298 for (auto En : enumerate(First&: ReductionInfos)) {
4299 auto Size = M.getDataLayout().getTypeStoreSize(Ty: En.value().ElementType);
4300 if (Size > MaxDataSize)
4301 MaxDataSize = Size;
4302 Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
4303 ? En.value().ByRefElementType
4304 : En.value().ElementType;
4305 ReductionTypeArgs.emplace_back(Args&: RedTypeArg);
4306 }
4307 Value *ReductionDataSize =
4308 Builder.getInt64(C: MaxDataSize * ReductionInfos.size());
4309 if (!IsTeamsReduction) {
4310 Value *SarFuncCast =
4311 Builder.CreatePointerBitCastOrAddrSpaceCast(V: *SarFunc, DestTy: FuncPtrTy);
4312 Value *WcFuncCast =
4313 Builder.CreatePointerBitCastOrAddrSpaceCast(V: WcFunc, DestTy: FuncPtrTy);
4314 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
4315 WcFuncCast};
4316 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
4317 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
4318 Res = createRuntimeFunctionCall(Callee: Pv2Ptr, Args);
4319 } else {
4320 CodeGenIP = Builder.saveIP();
4321 StructType *ReductionsBufferTy = StructType::create(
4322 Context&: Ctx, Elements: ReductionTypeArgs, Name: "struct._globalized_locals_ty");
4323 Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
4324 FnID: RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
4325
4326 Expected<Function *> LtGCFunc = emitListToGlobalCopyFunction(
4327 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4328 if (!LtGCFunc)
4329 return LtGCFunc.takeError();
4330
4331 Expected<Function *> LtGRFunc = emitListToGlobalReduceFunction(
4332 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4333 if (!LtGRFunc)
4334 return LtGRFunc.takeError();
4335
4336 Expected<Function *> GtLCFunc = emitGlobalToListCopyFunction(
4337 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4338 if (!GtLCFunc)
4339 return GtLCFunc.takeError();
4340
4341 Expected<Function *> GtLRFunc = emitGlobalToListReduceFunction(
4342 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4343 if (!GtLRFunc)
4344 return GtLRFunc.takeError();
4345
4346 Builder.restoreIP(IP: CodeGenIP);
4347
4348 Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
4349 Callee: RedFixedBufferFn, Args: {}, Name: "_openmp_teams_reductions_buffer_$_$ptr");
4350
4351 Value *Args3[] = {SrcLocInfo,
4352 KernelTeamsReductionPtr,
4353 Builder.getInt32(C: ReductionBufNum),
4354 ReductionDataSize,
4355 RL,
4356 *SarFunc,
4357 WcFunc,
4358 *LtGCFunc,
4359 *LtGRFunc,
4360 *GtLCFunc,
4361 *GtLRFunc};
4362
4363 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
4364 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
4365 Res = createRuntimeFunctionCall(Callee: TeamsReduceFn, Args: Args3);
4366 }
4367
4368 // 5. Build if (res == 1)
4369 BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.done");
4370 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.then");
4371 Value *Cond = Builder.CreateICmpEQ(LHS: Res, RHS: Builder.getInt32(C: 1));
4372 Builder.CreateCondBr(Cond, True: ThenBB, False: ExitBB);
4373
4374 // 6. Build then branch: where we have reduced values in the master
4375 // thread in each team.
4376 // __kmpc_end_reduce{_nowait}(<gtid>);
4377 // break;
4378 emitBlock(BB: ThenBB, CurFn: CurFunc);
4379
4380 // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
4381 for (auto En : enumerate(First&: ReductionInfos)) {
4382 const ReductionInfo &RI = En.value();
4383 Type *ValueType = RI.ElementType;
4384 Value *RedValue = RI.Variable;
4385 Value *RHS =
4386 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
4387
4388 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4389 Value *LHSPtr, *RHSPtr;
4390 Builder.restoreIP(IP: RI.ReductionGenClang(Builder.saveIP(), En.index(),
4391 &LHSPtr, &RHSPtr, CurFunc));
4392
4393 // Fix the callback code generated to use the correct Values for the LHS
4394 // and RHS.
4395 LHSPtr->replaceUsesWithIf(New: RedValue, ShouldReplace: [ReductionFunc](const Use &U) {
4396 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4397 ReductionFunc;
4398 });
4399 RHSPtr->replaceUsesWithIf(New: RHS, ShouldReplace: [ReductionFunc](const Use &U) {
4400 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4401 ReductionFunc;
4402 });
4403 } else {
4404 if (IsByRef.empty() || !IsByRef[En.index()]) {
4405 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
4406 Name: "red.value." + Twine(En.index()));
4407 }
4408 Value *PrivateRedValue = Builder.CreateLoad(
4409 Ty: ValueType, Ptr: RHS, Name: "red.private.value" + Twine(En.index()));
4410 Value *Reduced;
4411 InsertPointOrErrorTy AfterIP =
4412 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
4413 if (!AfterIP)
4414 return AfterIP.takeError();
4415 Builder.restoreIP(IP: *AfterIP);
4416
4417 if (!IsByRef.empty() && !IsByRef[En.index()])
4418 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
4419 }
4420 }
4421 emitBlock(BB: ExitBB, CurFn: CurFunc);
4422 if (ContinuationBlock) {
4423 Builder.CreateBr(Dest: ContinuationBlock);
4424 Builder.SetInsertPoint(ContinuationBlock);
4425 }
4426 Config.setEmitLLVMUsed();
4427
4428 return Builder.saveIP();
4429}
4430
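// Create a fresh, empty reduction function with signature void(void *, void *)
// whose body is filled in later by populateReductionFunction.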
4431static Function *getFreshReductionFunc(Module &M) {
4432 Type *VoidTy = Type::getVoidTy(C&: M.getContext());
4433 Type *Int8PtrTy = PointerType::getUnqual(C&: M.getContext());
4434 auto *FuncTy =
4435 FunctionType::get(Result: VoidTy, Params: {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ isVarArg: false);
4436 return Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4437 N: ".omp.reduction.func", M: &M);
4438}
4439
4440static Error populateReductionFunction(
4441 Function *ReductionFunc,
4442 ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4443 IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
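  // Conceptually, the function populated here performs, for each reduction I:
  //   *(Ty *)LHSArray[I] = *(Ty *)LHSArray[I] op *(Ty *)RHSArray[I]
  // where LHSArray/RHSArray are the type-erased pointer arrays passed in as
  // the two arguments (for by-ref reductions, the combine and the final store
  // happen inside the reduction region itself).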
4444 Module *Module = ReductionFunc->getParent();
4445 BasicBlock *ReductionFuncBlock =
4446 BasicBlock::Create(Context&: Module->getContext(), Name: "", Parent: ReductionFunc);
4447 Builder.SetInsertPoint(ReductionFuncBlock);
4448 Value *LHSArrayPtr = nullptr;
4449 Value *RHSArrayPtr = nullptr;
4450 if (IsGPU) {
4451 // Need to alloca memory here and deal with the pointers before getting
4452 // LHS/RHS pointers out
4453 //
4454 Argument *Arg0 = ReductionFunc->getArg(i: 0);
4455 Argument *Arg1 = ReductionFunc->getArg(i: 1);
4456 Type *Arg0Type = Arg0->getType();
4457 Type *Arg1Type = Arg1->getType();
4458
4459 Value *LHSAlloca =
4460 Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
4461 Value *RHSAlloca =
4462 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
4463 Value *LHSAddrCast =
4464 Builder.CreatePointerBitCastOrAddrSpaceCast(V: LHSAlloca, DestTy: Arg0Type);
4465 Value *RHSAddrCast =
4466 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHSAlloca, DestTy: Arg1Type);
4467 Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
4468 Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
4469 LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
4470 RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
4471 } else {
4472 LHSArrayPtr = ReductionFunc->getArg(i: 0);
4473 RHSArrayPtr = ReductionFunc->getArg(i: 1);
4474 }
4475
4476 unsigned NumReductions = ReductionInfos.size();
4477 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
4478
4479 for (auto En : enumerate(First&: ReductionInfos)) {
4480 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
4481 Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
4482 Ty: RedArrayTy, Ptr: LHSArrayPtr, Idx0: 0, Idx1: En.index());
4483 Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
4484 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4485 V: LHSI8Ptr, DestTy: RI.Variable->getType());
4486 Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
4487 Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
4488 Ty: RedArrayTy, Ptr: RHSArrayPtr, Idx0: 0, Idx1: En.index());
4489 Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
4490 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4491 V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType());
4492 Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
4493 Value *Reduced;
4494 OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4495 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
4496 if (!AfterIP)
4497 return AfterIP.takeError();
4498
4499 Builder.restoreIP(IP: *AfterIP);
4500 // TODO: Consider flagging an error.
4501 if (!Builder.GetInsertBlock())
4502 return Error::success();
4503
4504 // The store is inside the reduction region when using by-ref.
4505 if (!IsByRef[En.index()])
4506 Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
4507 }
4508 Builder.CreateRetVoid();
4509 return Error::success();
4510}
4511
4512OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
4513 const LocationDescription &Loc, InsertPointTy AllocaIP,
4514 ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
4515 bool IsNoWait, bool IsTeamsReduction) {
4516 assert(ReductionInfos.size() == IsByRef.size());
4517 if (Config.isGPU())
4518 return createReductionsGPU(Loc, AllocaIP, CodeGenIP: Builder.saveIP(), ReductionInfos,
4519 IsByRef, IsNoWait, IsTeamsReduction);
4520
4521 checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
4522
4523 if (!updateToLocation(Loc))
4524 return InsertPointTy();
4525
4526 if (ReductionInfos.size() == 0)
4527 return Builder.saveIP();
4528
4529 BasicBlock *InsertBlock = Loc.IP.getBlock();
4530 BasicBlock *ContinuationBlock =
4531 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4532 InsertBlock->getTerminator()->eraseFromParent();
4533
4534 // Create and populate array of type-erased pointers to private reduction
4535 // values.
4536 unsigned NumReductions = ReductionInfos.size();
4537 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
4538 Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
4539 Value *RedArray = Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: "red.array");
4540
4541 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4542
4543 for (auto En : enumerate(First&: ReductionInfos)) {
4544 unsigned Index = En.index();
4545 const ReductionInfo &RI = En.value();
4546 Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
4547 Ty: RedArrayTy, Ptr: RedArray, Idx0: 0, Idx1: Index, Name: "red.array.elem." + Twine(Index));
4548 Builder.CreateStore(Val: RI.PrivateVariable, Ptr: RedArrayElemPtr);
4549 }
4550
4551 // Emit a call to the runtime function that orchestrates the reduction.
4552 // Declare the reduction function in the process.
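  // Conceptually, the emitted code is (a sketch; some arguments abbreviated):
  //
  //   res = __kmpc_reduce{_nowait}(ident, tid, <n>, sizeof(RedList), RedList,
  //                                red_func, &lock);
  //   switch (res) {
  //   case 1: <non-atomic reduction>; __kmpc_end_reduce{_nowait}(...); break;
  //   case 2: <atomic reduction>; break;
  //   default: break; // nothing to do
  //   }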
4553 Type *IndexTy = Builder.getIndexTy(
4554 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4555 Function *Func = Builder.GetInsertBlock()->getParent();
4556 Module *Module = Func->getParent();
4557 uint32_t SrcLocStrSize;
4558 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4559 bool CanGenerateAtomic = all_of(Range&: ReductionInfos, P: [](const ReductionInfo &RI) {
4560 return RI.AtomicReductionGen;
4561 });
4562 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
4563 LocFlags: CanGenerateAtomic
4564 ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
4565 : IdentFlag(0));
4566 Value *ThreadId = getOrCreateThreadID(Ident);
4567 Constant *NumVariables = Builder.getInt32(C: NumReductions);
4568 const DataLayout &DL = Module->getDataLayout();
4569 unsigned RedArrayByteSize = DL.getTypeStoreSize(Ty: RedArrayTy);
4570 Constant *RedArraySize = ConstantInt::get(Ty: IndexTy, V: RedArrayByteSize);
4571 Function *ReductionFunc = getFreshReductionFunc(M&: *Module);
4572 Value *Lock = getOMPCriticalRegionLock(CriticalName: ".reduction");
4573 Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
4574 FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
4575 : RuntimeFunction::OMPRTL___kmpc_reduce);
4576 CallInst *ReduceCall =
4577 createRuntimeFunctionCall(Callee: ReduceFunc,
4578 Args: {Ident, ThreadId, NumVariables, RedArraySize,
4579 RedArray, ReductionFunc, Lock},
4580 Name: "reduce");
4581
4582 // Create final reduction entry blocks for the atomic and non-atomic case.
4583 // Emit IR that dispatches control flow to one of the blocks based on the
4584 // reduction supporting the atomic mode.
4585 BasicBlock *NonAtomicRedBlock =
4586 BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.nonatomic", Parent: Func);
4587 BasicBlock *AtomicRedBlock =
4588 BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.atomic", Parent: Func);
4589 SwitchInst *Switch =
4590 Builder.CreateSwitch(V: ReduceCall, Dest: ContinuationBlock, /* NumCases */ 2);
4591 Switch->addCase(OnVal: Builder.getInt32(C: 1), Dest: NonAtomicRedBlock);
4592 Switch->addCase(OnVal: Builder.getInt32(C: 2), Dest: AtomicRedBlock);
4593
4594 // Populate the non-atomic reduction using the elementwise reduction function.
4595 // This loads the elements from the global and private variables and reduces
4596 // them before storing back the result to the global variable.
4597 Builder.SetInsertPoint(NonAtomicRedBlock);
4598 for (auto En : enumerate(First&: ReductionInfos)) {
4599 const ReductionInfo &RI = En.value();
4600 Type *ValueType = RI.ElementType;
4601 // We have one less load for the by-ref case because that load is now
4602 // inside the reduction region.
4603 Value *RedValue = RI.Variable;
4604 if (!IsByRef[En.index()]) {
4605 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
4606 Name: "red.value." + Twine(En.index()));
4607 }
4608 Value *PrivateRedValue =
4609 Builder.CreateLoad(Ty: ValueType, Ptr: RI.PrivateVariable,
4610 Name: "red.private.value." + Twine(En.index()));
4611 Value *Reduced;
4612 InsertPointOrErrorTy AfterIP =
4613 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
4614 if (!AfterIP)
4615 return AfterIP.takeError();
4616 Builder.restoreIP(IP: *AfterIP);
4617
4618 if (!Builder.GetInsertBlock())
4619 return InsertPointTy();
4620 // For the by-ref case, the store is inside the reduction region.
4621 if (!IsByRef[En.index()])
4622 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
4623 }
4624 Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
4625 FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
4626 : RuntimeFunction::OMPRTL___kmpc_end_reduce);
4627 createRuntimeFunctionCall(Callee: EndReduceFunc, Args: {Ident, ThreadId, Lock});
4628 Builder.CreateBr(Dest: ContinuationBlock);
4629
4630 // Populate the atomic reduction using the atomic elementwise reduction
4631 // function. There are no loads/stores here because they will be happening
4632 // inside the atomic elementwise reduction.
4633 Builder.SetInsertPoint(AtomicRedBlock);
4634 if (CanGenerateAtomic && llvm::none_of(Range&: IsByRef, P: [](bool P) { return P; })) {
4635 for (const ReductionInfo &RI : ReductionInfos) {
4636 InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
4637 Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
4638 if (!AfterIP)
4639 return AfterIP.takeError();
4640 Builder.restoreIP(IP: *AfterIP);
4641 if (!Builder.GetInsertBlock())
4642 return InsertPointTy();
4643 }
4644 Builder.CreateBr(Dest: ContinuationBlock);
4645 } else {
4646 Builder.CreateUnreachable();
4647 }
4648
4649 // Populate the outlined reduction function using the elementwise reduction
4650 // function. Partial values are extracted from the type-erased array of
4651 // pointers to private variables.
4652 Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
4653 IsByRef, /*isGPU=*/IsGPU: false);
4654 if (Err)
4655 return Err;
4656
4657 if (!Builder.GetInsertBlock())
4658 return InsertPointTy();
4659
4660 Builder.SetInsertPoint(ContinuationBlock);
4661 return Builder.saveIP();
4662}
4663
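// Conceptually, createMaster emits:
//
//   if (__kmpc_master(ident, tid)) {
//     <body>
//     __kmpc_end_master(ident, tid);
//   }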
4664OpenMPIRBuilder::InsertPointOrErrorTy
4665OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
4666 BodyGenCallbackTy BodyGenCB,
4667 FinalizeCallbackTy FiniCB) {
4668 if (!updateToLocation(Loc))
4669 return Loc.IP;
4670
4671 Directive OMPD = Directive::OMPD_master;
4672 uint32_t SrcLocStrSize;
4673 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4674 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4675 Value *ThreadId = getOrCreateThreadID(Ident);
4676 Value *Args[] = {Ident, ThreadId};
4677
4678 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_master);
4679 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4680
4681 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_master);
4682 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
4683
4684 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4685 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4686}
4687
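// Conceptually, createMasked emits the same structure as createMaster, with
// entry guarded by the filtered thread id:
//
//   if (__kmpc_masked(ident, tid, filter)) {
//     <body>
//     __kmpc_end_masked(ident, tid);
//   }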
4688OpenMPIRBuilder::InsertPointOrErrorTy
4689OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
4690 BodyGenCallbackTy BodyGenCB,
4691 FinalizeCallbackTy FiniCB, Value *Filter) {
4692 if (!updateToLocation(Loc))
4693 return Loc.IP;
4694
4695 Directive OMPD = Directive::OMPD_masked;
4696 uint32_t SrcLocStrSize;
4697 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4698 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4699 Value *ThreadId = getOrCreateThreadID(Ident);
4700 Value *Args[] = {Ident, ThreadId, Filter};
4701 Value *ArgsEnd[] = {Ident, ThreadId};
4702
4703 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_masked);
4704 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4705
4706 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_masked);
4707 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args: ArgsEnd);
4708
4709 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4710 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4711}
4712
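// Helper to emit a call to a runtime function and mark it as not unwinding,
// mirroring what frontends typically do for OpenMP runtime calls.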
4713static llvm::CallInst *emitNoUnwindRuntimeCall(IRBuilder<> &Builder,
4714 llvm::FunctionCallee Callee,
4715 ArrayRef<llvm::Value *> Args,
4716 const llvm::Twine &Name) {
4717 llvm::CallInst *Call = Builder.CreateCall(
4718 Callee, Args, OpBundles: SmallVector<llvm::OperandBundleDef, 1>(), Name);
4719 Call->setDoesNotThrow();
4720 return Call;
4721}
4722
4723// Expects the input basic block to be dominated by BeforeScanBB. Once the
4724// scan directive is encountered, the code after it must be dominated by
4725// AfterScanBB. The scan directive splits the code sequence into an input
4726// phase and a scan phase. Depending on whether the inclusive or exclusive
4727// clause is used on the scan directive, and on whether the input loop or the
4728// scan loop is being lowered, it adds jumps to the input and scan phases.
4729// The first scan loop is the input loop and the second is the scan loop. The
4730// code generated currently handles only inclusive scans.
4731OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
4732 const LocationDescription &Loc, InsertPointTy AllocaIP,
4733 ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
4734 bool IsInclusive, ScanInfo *ScanRedInfo) {
4735 if (ScanRedInfo->OMPFirstScanLoop) {
4736 llvm::Error Err = emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars,
4737 ScanVarsType, ScanRedInfo);
4738 if (Err)
4739 return Err;
4740 }
4741 if (!updateToLocation(Loc))
4742 return Loc.IP;
4743
4744 llvm::Value *IV = ScanRedInfo->IV;
4745
4746 if (ScanRedInfo->OMPFirstScanLoop) {
4747 // Emit buffer[i] = red; at the end of the input phase.
4748 for (size_t i = 0; i < ScanVars.size(); i++) {
4749 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
4750 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
4751 Type *DestTy = ScanVarsType[i];
4752 Value *Val = Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
4753 Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: ScanVars[i]);
4754
4755 Builder.CreateStore(Val: Src, Ptr: Val);
4756 }
4757 }
4758 Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
4759 emitBlock(BB: ScanRedInfo->OMPScanDispatch,
4760 CurFn: Builder.GetInsertBlock()->getParent());
4761
4762 if (!ScanRedInfo->OMPFirstScanLoop) {
4763 IV = ScanRedInfo->IV;
4764 // Emit red = buffer[i]; at the entrance to the scan phase.
4765 // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
4766 for (size_t i = 0; i < ScanVars.size(); i++) {
4767 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
4768 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
4769 Type *DestTy = ScanVarsType[i];
4770 Value *SrcPtr =
4771 Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
4772 Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: SrcPtr);
4773 Builder.CreateStore(Val: Src, Ptr: ScanVars[i]);
4774 }
4775 }
4776
4777 // TODO: Update it to CreateBr and remove dead blocks
4778 llvm::Value *CmpI = Builder.getInt1(V: true);
4779 if (ScanRedInfo->OMPFirstScanLoop == IsInclusive) {
4780 Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPBeforeScanBlock,
4781 False: ScanRedInfo->OMPAfterScanBlock);
4782 } else {
4783 Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPAfterScanBlock,
4784 False: ScanRedInfo->OMPBeforeScanBlock);
4785 }
4786 emitBlock(BB: ScanRedInfo->OMPAfterScanBlock,
4787 CurFn: Builder.GetInsertBlock()->getParent());
4788 Builder.SetInsertPoint(ScanRedInfo->OMPAfterScanBlock);
4789 return Builder.saveIP();
4790}
4791
4792Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
4793 InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
4794 ArrayRef<Type *> ScanVarsType, ScanInfo *ScanRedInfo) {
4795
4796 Builder.restoreIP(IP: AllocaIP);
4797 // Create the shared pointer at alloca IP.
4798 for (size_t i = 0; i < ScanVars.size(); i++) {
4799 llvm::Value *BuffPtr =
4800 Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: "vla");
4801 (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]] = BuffPtr;
4802 }
4803
4804 // Allocate the temporary buffer on the master thread.
4805 auto BodyGenCB = [&](InsertPointTy AllocaIP,
4806 InsertPointTy CodeGenIP) -> Error {
4807 Builder.restoreIP(IP: CodeGenIP);
4808 Value *AllocSpan =
4809 Builder.CreateAdd(LHS: ScanRedInfo->Span, RHS: Builder.getInt32(C: 1));
4810 for (size_t i = 0; i < ScanVars.size(); i++) {
4811 Type *IntPtrTy = Builder.getInt32Ty();
4812 Constant *Allocsize = ConstantExpr::getSizeOf(Ty: ScanVarsType[i]);
4813 Allocsize = ConstantExpr::getTruncOrBitCast(C: Allocsize, Ty: IntPtrTy);
4814 Value *Buff = Builder.CreateMalloc(IntPtrTy, AllocTy: ScanVarsType[i], AllocSize: Allocsize,
4815 ArraySize: AllocSpan, MallocF: nullptr, Name: "arr");
4816 Builder.CreateStore(Val: Buff, Ptr: (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]]);
4817 }
4818 return Error::success();
4819 };
4820 // TODO: Perform finalization actions for variables. This has to be
4821 // called for variables which have destructors/finalizers.
4822 auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
4823
4824 Builder.SetInsertPoint(ScanRedInfo->OMPScanInit->getTerminator());
4825 llvm::Value *FilterVal = Builder.getInt32(C: 0);
4826 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4827 createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);
4828
4829 if (!AfterIP)
4830 return AfterIP.takeError();
4831 Builder.restoreIP(IP: *AfterIP);
4832 BasicBlock *InputBB = Builder.GetInsertBlock();
4833 if (InputBB->getTerminator())
4834 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
4835 AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
4836 if (!AfterIP)
4837 return AfterIP.takeError();
4838 Builder.restoreIP(IP: *AfterIP);
4839
4840 return Error::success();
4841}
4842
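// Emits code that copies the final reduction values out of the shared scan
// buffer back into the original variables and frees the buffer; this runs
// masked on thread 0 and is followed by a barrier.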
4843Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
4844 ArrayRef<ReductionInfo> ReductionInfos, ScanInfo *ScanRedInfo) {
4845 auto BodyGenCB = [&](InsertPointTy AllocaIP,
4846 InsertPointTy CodeGenIP) -> Error {
4847 Builder.restoreIP(IP: CodeGenIP);
4848 for (ReductionInfo RedInfo : ReductionInfos) {
4849 Value *PrivateVar = RedInfo.PrivateVariable;
4850 Value *OrigVar = RedInfo.Variable;
4851 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[PrivateVar];
4852 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
4853
4854 Type *SrcTy = RedInfo.ElementType;
4855 Value *Val = Builder.CreateInBoundsGEP(Ty: SrcTy, Ptr: Buff, IdxList: ScanRedInfo->Span,
4856 Name: "arrayOffset");
4857 Value *Src = Builder.CreateLoad(Ty: SrcTy, Ptr: Val);
4858
4859 Builder.CreateStore(Val: Src, Ptr: OrigVar);
4860 Builder.CreateFree(Source: Buff);
4861 }
4862 return Error::success();
4863 };
4864 // TODO: Perform finalization actions for variables. This has to be
4865 // called for variables which have destructors/finalizers.
4866 auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
4867
4868 if (ScanRedInfo->OMPScanFinish->getTerminator())
4869 Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish->getTerminator());
4870 else
4871 Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish);
4872
4873 llvm::Value *FilterVal = Builder.getInt32(C: 0);
4874 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4875 createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);
4876
4877 if (!AfterIP)
4878 return AfterIP.takeError();
4879 Builder.restoreIP(IP: *AfterIP);
4880 BasicBlock *InputBB = Builder.GetInsertBlock();
4881 if (InputBB->getTerminator())
4882 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
4883 AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
4884 if (!AfterIP)
4885 return AfterIP.takeError();
4886 Builder.restoreIP(IP: *AfterIP);
4887 return Error::success();
4888}
4889
4890OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
4891 const LocationDescription &Loc,
4892 ArrayRef<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4893 ScanInfo *ScanRedInfo) {
4894
4895 if (!updateToLocation(Loc))
4896 return Loc.IP;
4897 auto BodyGenCB = [&](InsertPointTy AllocaIP,
4898 InsertPointTy CodeGenIP) -> Error {
4899 Builder.restoreIP(IP: CodeGenIP);
4900 Function *CurFn = Builder.GetInsertBlock()->getParent();
4901 // for (int k = 0; k <= ceil(log2(n)); ++k)
4902 llvm::BasicBlock *LoopBB =
4903 BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.outer.log.scan.body");
4904 llvm::BasicBlock *ExitBB =
4905 splitBB(Builder, CreateBranch: false, Name: "omp.outer.log.scan.exit");
4906 llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
4907 M: Builder.GetInsertBlock()->getModule(),
4908 id: (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Tys: Builder.getDoubleTy());
4909 llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
4910 llvm::Value *Arg =
4911 Builder.CreateUIToFP(V: ScanRedInfo->Span, DestTy: Builder.getDoubleTy());
4912 llvm::Value *LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: Arg, Name: "");
4913 F = llvm::Intrinsic::getOrInsertDeclaration(
4914 M: Builder.GetInsertBlock()->getModule(),
4915 id: (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Tys: Builder.getDoubleTy());
4916 LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: LogVal, Name: "");
4917 LogVal = Builder.CreateFPToUI(V: LogVal, DestTy: Builder.getInt32Ty());
4918 llvm::Value *NMin1 = Builder.CreateNUWSub(
4919 LHS: ScanRedInfo->Span,
4920 RHS: llvm::ConstantInt::get(Ty: ScanRedInfo->Span->getType(), V: 1));
4921 Builder.SetInsertPoint(InputBB);
4922 Builder.CreateBr(Dest: LoopBB);
4923 emitBlock(BB: LoopBB, CurFn);
4924 Builder.SetInsertPoint(LoopBB);
4925
4926 PHINode *Counter = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
4927 // size pow2k = 1;
4928 PHINode *Pow2K = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
4929 Counter->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
4930 BB: InputBB);
4931 Pow2K->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1),
4932 BB: InputBB);
4933 // for (size i = n - 1; i >= 2 ^ k; --i)
4934 // tmp[i] op= tmp[i-pow2k];
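    // For example, with op = + and a buffer {a, b, c, d}, the passes yield
    //   k=0 (pow2k=1): {a, a+b, b+c, c+d}
    //   k=1 (pow2k=2): {a, a+b, a+b+c, a+b+c+d}
    // i.e. the classic log2(n)-pass inclusive prefix sum.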
4935 llvm::BasicBlock *InnerLoopBB =
4936 BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.body");
4937 llvm::BasicBlock *InnerExitBB =
4938 BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.exit");
4939 llvm::Value *CmpI = Builder.CreateICmpUGE(LHS: NMin1, RHS: Pow2K);
4940 Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
4941 emitBlock(BB: InnerLoopBB, CurFn);
4942 Builder.SetInsertPoint(InnerLoopBB);
4943 PHINode *IVal = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
4944 IVal->addIncoming(V: NMin1, BB: LoopBB);
4945 for (ReductionInfo RedInfo : ReductionInfos) {
4946 Value *ReductionVal = RedInfo.PrivateVariable;
4947 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ReductionVal];
4948 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
4949 Type *DestTy = RedInfo.ElementType;
4950 Value *IV = Builder.CreateAdd(LHS: IVal, RHS: Builder.getInt32(C: 1));
4951 Value *LHSPtr =
4952 Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
4953 Value *OffsetIval = Builder.CreateNUWSub(LHS: IV, RHS: Pow2K);
4954 Value *RHSPtr =
4955 Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: OffsetIval, Name: "arrayOffset");
4956 Value *LHS = Builder.CreateLoad(Ty: DestTy, Ptr: LHSPtr);
4957 Value *RHS = Builder.CreateLoad(Ty: DestTy, Ptr: RHSPtr);
4958 llvm::Value *Result;
4959 InsertPointOrErrorTy AfterIP =
4960 RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
4961 if (!AfterIP)
4962 return AfterIP.takeError();
4963 Builder.CreateStore(Val: Result, Ptr: LHSPtr);
4964 }
4965 llvm::Value *NextIVal = Builder.CreateNUWSub(
4966 LHS: IVal, RHS: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1));
4967 IVal->addIncoming(V: NextIVal, BB: Builder.GetInsertBlock());
4968 CmpI = Builder.CreateICmpUGE(LHS: NextIVal, RHS: Pow2K);
4969 Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
4970 emitBlock(BB: InnerExitBB, CurFn);
4971 llvm::Value *Next = Builder.CreateNUWAdd(
4972 LHS: Counter, RHS: llvm::ConstantInt::get(Ty: Counter->getType(), V: 1));
4973 Counter->addIncoming(V: Next, BB: Builder.GetInsertBlock());
4974 // pow2k <<= 1;
4975 llvm::Value *NextPow2K = Builder.CreateShl(LHS: Pow2K, RHS: 1, Name: "", /*HasNUW=*/true);
4976 Pow2K->addIncoming(V: NextPow2K, BB: Builder.GetInsertBlock());
4977 llvm::Value *Cmp = Builder.CreateICmpNE(LHS: Next, RHS: LogVal);
4978 Builder.CreateCondBr(Cond: Cmp, True: LoopBB, False: ExitBB);
4979 Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
4980 return Error::success();
4981 };
4982
4983 // TODO: Perform finalization actions for variables. This has to be
4984 // called for variables which have destructors/finalizers.
4985 auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
4986
4987 llvm::Value *FilterVal = Builder.getInt32(C: 0);
4988 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4989 createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);
4990
4991 if (!AfterIP)
4992 return AfterIP.takeError();
4993 Builder.restoreIP(IP: *AfterIP);
4994 AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
4995
4996 if (!AfterIP)
4997 return AfterIP.takeError();
4998 Builder.restoreIP(IP: *AfterIP);
4999 Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos, ScanRedInfo);
5000 if (Err)
5001 return Err;
5002
5003 return AfterIP;
5004}
5005
5006Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
5007 llvm::function_ref<Error()> InputLoopGen,
5008 llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen,
5009 ScanInfo *ScanRedInfo) {
5010
5011 {
5012 // Emit loop with input phase:
5013 // for (i: 0..<num_iters>) {
5014 // <input phase>;
5015 // buffer[i] = red;
5016 // }
5017 ScanRedInfo->OMPFirstScanLoop = true;
5018 Error Err = InputLoopGen();
5019 if (Err)
5020 return Err;
5021 }
5022 {
5023 // Emit loop with scan phase:
5024 // for (i: 0..<num_iters>) {
5025 // red = buffer[i];
5026 // <scan phase>;
5027 // }
5028 ScanRedInfo->OMPFirstScanLoop = false;
5029 Error Err = ScanLoopGen(Builder.saveIP());
5030 if (Err)
5031 return Err;
5032 }
5033 return Error::success();
5034}
5035
5036void OpenMPIRBuilder::createScanBBs(ScanInfo *ScanRedInfo) {
5037 Function *Fun = Builder.GetInsertBlock()->getParent();
5038 ScanRedInfo->OMPScanDispatch =
5039 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.inscan.dispatch");
5040 ScanRedInfo->OMPAfterScanBlock =
5041 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.after.scan.bb");
5042 ScanRedInfo->OMPBeforeScanBlock =
5043 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.before.scan.bb");
5044 ScanRedInfo->OMPScanLoopExit =
5045 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.scan.loop.exit");
5046}

5047CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
5048 DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
5049 BasicBlock *PostInsertBefore, const Twine &Name) {
5050 Module *M = F->getParent();
5051 LLVMContext &Ctx = M->getContext();
5052 Type *IndVarTy = TripCount->getType();
5053
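  // Sketch of the control flow created below:
  //
  //   preheader -> header -> cond --(true)--> body -> latch -> header
  //                            \---(false)--> exit -> after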
5054 // Create the basic block structure.
5055 BasicBlock *Preheader =
5056 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".preheader", Parent: F, InsertBefore: PreInsertBefore);
5057 BasicBlock *Header =
5058 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".header", Parent: F, InsertBefore: PreInsertBefore);
5059 BasicBlock *Cond =
5060 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".cond", Parent: F, InsertBefore: PreInsertBefore);
5061 BasicBlock *Body =
5062 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".body", Parent: F, InsertBefore: PreInsertBefore);
5063 BasicBlock *Latch =
5064 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".inc", Parent: F, InsertBefore: PostInsertBefore);
5065 BasicBlock *Exit =
5066 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".exit", Parent: F, InsertBefore: PostInsertBefore);
5067 BasicBlock *After =
5068 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".after", Parent: F, InsertBefore: PostInsertBefore);
5069
5070 // Use specified DebugLoc for new instructions.
5071 Builder.SetCurrentDebugLocation(DL);
5072
5073 Builder.SetInsertPoint(Preheader);
5074 Builder.CreateBr(Dest: Header);
5075
5076 Builder.SetInsertPoint(Header);
5077 PHINode *IndVarPHI = Builder.CreatePHI(Ty: IndVarTy, NumReservedValues: 2, Name: "omp_" + Name + ".iv");
5078 IndVarPHI->addIncoming(V: ConstantInt::get(Ty: IndVarTy, V: 0), BB: Preheader);
5079 Builder.CreateBr(Dest: Cond);
5080
5081 Builder.SetInsertPoint(Cond);
5082 Value *Cmp =
5083 Builder.CreateICmpULT(LHS: IndVarPHI, RHS: TripCount, Name: "omp_" + Name + ".cmp");
5084 Builder.CreateCondBr(Cond: Cmp, True: Body, False: Exit);
5085
5086 Builder.SetInsertPoint(Body);
5087 Builder.CreateBr(Dest: Latch);
5088
5089 Builder.SetInsertPoint(Latch);
5090 Value *Next = Builder.CreateAdd(LHS: IndVarPHI, RHS: ConstantInt::get(Ty: IndVarTy, V: 1),
5091 Name: "omp_" + Name + ".next", /*HasNUW=*/true);
5092 Builder.CreateBr(Dest: Header);
5093 IndVarPHI->addIncoming(V: Next, BB: Latch);
5094
5095 Builder.SetInsertPoint(Exit);
5096 Builder.CreateBr(Dest: After);
5097
5098 // Remember and return the canonical control flow.
5099 LoopInfos.emplace_front();
5100 CanonicalLoopInfo *CL = &LoopInfos.front();
5101
5102 CL->Header = Header;
5103 CL->Cond = Cond;
5104 CL->Latch = Latch;
5105 CL->Exit = Exit;
5106
5107#ifndef NDEBUG
5108 CL->assertOK();
5109#endif
5110 return CL;
5111}
5112
5113Expected<CanonicalLoopInfo *>
5114OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
5115 LoopBodyGenCallbackTy BodyGenCB,
5116 Value *TripCount, const Twine &Name) {
5117 BasicBlock *BB = Loc.IP.getBlock();
5118 BasicBlock *NextBB = BB->getNextNode();
5119
5120 CanonicalLoopInfo *CL = createLoopSkeleton(DL: Loc.DL, TripCount, F: BB->getParent(),
5121 PreInsertBefore: NextBB, PostInsertBefore: NextBB, Name);
5122 BasicBlock *After = CL->getAfter();
5123
5124 // If location is not set, don't connect the loop.
5125 if (updateToLocation(Loc)) {
5126 // Split the loop at the insertion point: Branch to the preheader and move
5127 // every following instruction to after the loop (the After BB). Also, the
5128 // new successor is the loop's after block.
5129 spliceBB(Builder, New: After, /*CreateBranch=*/false);
5130 Builder.CreateBr(Dest: CL->getPreheader());
5131 }
5132
5133 // Emit the body content. We do it after connecting the loop to the CFG to
5134 // avoid the callback encountering degenerate BBs.
5135 if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
5136 return Err;
5137
5138#ifndef NDEBUG
5139 CL->assertOK();
5140#endif
5141 return CL;
5142}
5143
5144Expected<ScanInfo *> OpenMPIRBuilder::scanInfoInitialize() {
5145 ScanInfos.emplace_front();
5146 ScanInfo *Result = &ScanInfos.front();
5147 return Result;
5148}
5149
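// Builds the two canonical loops needed for scan lowering: the input loop,
// which also stores each iteration's reduction value into the scan buffer,
// and the scan loop, which reads it back; see emitScanBasedDirectiveIR for
// the overall structure.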
5150Expected<SmallVector<llvm::CanonicalLoopInfo *>>
5151OpenMPIRBuilder::createCanonicalScanLoops(
5152 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
5153 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
5154 InsertPointTy ComputeIP, const Twine &Name, ScanInfo *ScanRedInfo) {
5155 LocationDescription ComputeLoc =
5156 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
5157 updateToLocation(Loc: ComputeLoc);
5158
5159 SmallVector<CanonicalLoopInfo *> Result;
5160
5161 Value *TripCount = calculateCanonicalLoopTripCount(
5162 Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
5163 ScanRedInfo->Span = TripCount;
5164 ScanRedInfo->OMPScanInit = splitBB(Builder, CreateBranch: true, Name: "scan.init");
5165 Builder.SetInsertPoint(ScanRedInfo->OMPScanInit);
5166
5167 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
5168 Builder.restoreIP(IP: CodeGenIP);
5169 ScanRedInfo->IV = IV;
5170 createScanBBs(ScanRedInfo);
5171 BasicBlock *InputBlock = Builder.GetInsertBlock();
5172 Instruction *Terminator = InputBlock->getTerminator();
5173 assert(Terminator->getNumSuccessors() == 1);
5174 BasicBlock *ContinueBlock = Terminator->getSuccessor(Idx: 0);
5175 Terminator->setSuccessor(Idx: 0, BB: ScanRedInfo->OMPScanDispatch);
5176 emitBlock(BB: ScanRedInfo->OMPBeforeScanBlock,
5177 CurFn: Builder.GetInsertBlock()->getParent());
5178 Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
5179 emitBlock(BB: ScanRedInfo->OMPScanLoopExit,
5180 CurFn: Builder.GetInsertBlock()->getParent());
5181 Builder.CreateBr(Dest: ContinueBlock);
5182 Builder.SetInsertPoint(
5183 ScanRedInfo->OMPBeforeScanBlock->getFirstInsertionPt());
5184 return BodyGenCB(Builder.saveIP(), IV);
5185 };
5186
5187 const auto &&InputLoopGen = [&]() -> Error {
5188 Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
5189 Loc: Builder.saveIP(), BodyGenCB: BodyGen, Start, Stop, Step, IsSigned, InclusiveStop,
5190 ComputeIP, Name, InScan: true, ScanRedInfo);
5191 if (!LoopInfo)
5192 return LoopInfo.takeError();
5193 Result.push_back(Elt: *LoopInfo);
5194 Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
5195 return Error::success();
5196 };
5197 const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
5198 Expected<CanonicalLoopInfo *> LoopInfo =
5199 createCanonicalLoop(Loc, BodyGenCB: BodyGen, Start, Stop, Step, IsSigned,
5200 InclusiveStop, ComputeIP, Name, InScan: true, ScanRedInfo);
5201 if (!LoopInfo)
5202 return LoopInfo.takeError();
5203 Result.push_back(Elt: *LoopInfo);
5204 Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
5205 ScanRedInfo->OMPScanFinish = Builder.GetInsertBlock();
5206 return Error::success();
5207 };
5208 Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen, ScanRedInfo);
5209 if (Err)
5210 return Err;
5211 return Result;
5212}
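
// Conceptually, the two loops generated above implement the two-pass lowering
// of an OpenMP inscan reduction (a rough sketch; buffer handling is elided):
//
//   for (iv = 0; iv < tripcount; ++iv)   // input loop
//     { input phase; stash the reduction value for iteration iv }
//   { perform the scan across the stashed values }
//   for (iv = 0; iv < tripcount; ++iv)   // scan loop
//     { load the scanned value for iv; scan phase }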
5213
5214Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
5215 const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
5216 bool IsSigned, bool InclusiveStop, const Twine &Name) {
5217
5218 // Consider the following difficulties (assuming 8-bit signed integers):
5219 // * Adding \p Step to the loop counter which passes \p Stop may overflow:
5220 // DO I = 1, 100, 50
5221 // * A \p Step of INT_MIN cannot be normalized to a positive direction:
5222 // DO I = 100, 0, -128
5223
5224 // Start, Stop and Step must be of the same integer type.
5225 auto *IndVarTy = cast<IntegerType>(Val: Start->getType());
5226 assert(IndVarTy == Stop->getType() && "Stop type mismatch");
5227 assert(IndVarTy == Step->getType() && "Step type mismatch");
5228
5229 updateToLocation(Loc);
5230
5231 ConstantInt *Zero = ConstantInt::get(Ty: IndVarTy, V: 0);
5232 ConstantInt *One = ConstantInt::get(Ty: IndVarTy, V: 1);
5233
5234 // Like Step, but always positive.
5235 Value *Incr = Step;
5236
5237 // Distance between Start and Stop; always positive.
5238 Value *Span;
5239
5240 // Condition checking whether no iterations are executed at all, e.g. because
5241 // UB < LB.
5242 Value *ZeroCmp;
5243
5244 if (IsSigned) {
5245 // Ensure that increment is positive. If not, negate and invert LB and UB.
5246 Value *IsNeg = Builder.CreateICmpSLT(LHS: Step, RHS: Zero);
5247 Incr = Builder.CreateSelect(C: IsNeg, True: Builder.CreateNeg(V: Step), False: Step);
5248 Value *LB = Builder.CreateSelect(C: IsNeg, True: Stop, False: Start);
5249 Value *UB = Builder.CreateSelect(C: IsNeg, True: Start, False: Stop);
5250 Span = Builder.CreateSub(LHS: UB, RHS: LB, Name: "", HasNUW: false, HasNSW: true);
5251 ZeroCmp = Builder.CreateICmp(
5252 P: InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, LHS: UB, RHS: LB);
5253 } else {
5254 Span = Builder.CreateSub(LHS: Stop, RHS: Start, Name: "", HasNUW: true);
5255 ZeroCmp = Builder.CreateICmp(
5256 P: InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, LHS: Stop, RHS: Start);
5257 }
5258
5259 Value *CountIfLooping;
5260 if (InclusiveStop) {
5261 CountIfLooping = Builder.CreateAdd(LHS: Builder.CreateUDiv(LHS: Span, RHS: Incr), RHS: One);
5262 } else {
5263 // Avoid incrementing past stop since it could overflow.
5264 Value *CountIfTwo = Builder.CreateAdd(
5265 LHS: Builder.CreateUDiv(LHS: Builder.CreateSub(LHS: Span, RHS: One), RHS: Incr), RHS: One);
5266 Value *OneCmp = Builder.CreateICmp(P: CmpInst::ICMP_ULE, LHS: Span, RHS: Incr);
5267 CountIfLooping = Builder.CreateSelect(C: OneCmp, True: One, False: CountIfTwo);
5268 }
5269
5270 return Builder.CreateSelect(C: ZeroCmp, True: Zero, False: CountIfLooping,
5271 Name: "omp_" + Name + ".tripcount");
5272}
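
// For illustration, consider the inclusive loop DO I = 1, 100, 50 (signed,
// InclusiveStop): Incr = 50, LB = 1, UB = 100, Span = 99, and ZeroCmp is
// (100 < 1) = false, so the selects above reduce to
//   TripCount = Span / Incr + 1 = 99 / 50 + 1 = 2,
// matching the two executed iterations I = 1 and I = 51. For an exclusive
// unsigned loop with Start = 0, Stop = 10, Step = 3: Span = 10 and
// Span <= Incr is false, so TripCount = (10 - 1) / 3 + 1 = 4
// (iterations 0, 3, 6, 9).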
5273
5274Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
5275 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
5276 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
5277 InsertPointTy ComputeIP, const Twine &Name, bool InScan,
5278 ScanInfo *ScanRedInfo) {
5279 LocationDescription ComputeLoc =
5280 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
5281
5282 Value *TripCount = calculateCanonicalLoopTripCount(
5283 Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
5284
5285 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
5286 Builder.restoreIP(IP: CodeGenIP);
5287 Value *Span = Builder.CreateMul(LHS: IV, RHS: Step);
5288 Value *IndVar = Builder.CreateAdd(LHS: Span, RHS: Start);
5289 if (InScan)
5290 ScanRedInfo->IV = IndVar;
5291 return BodyGenCB(Builder.saveIP(), IndVar);
5292 };
5293 LocationDescription LoopLoc =
5294 ComputeIP.isSet()
5295 ? Loc
5296 : LocationDescription(Builder.saveIP(),
5297 Builder.getCurrentDebugLocation());
5298 return createCanonicalLoop(Loc: LoopLoc, BodyGenCB: BodyGen, TripCount, Name);
5299}
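
// A minimal usage sketch (assuming an existing OpenMPIRBuilder `OMPBuilder`,
// a location `Loc`, and Start/Stop/Step values; the scan-related parameters
// are assumed to take their defaults):
//
//   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy CodeGenIP,
//                        Value *IV) -> Error {
//     OMPBuilder.Builder.restoreIP(CodeGenIP);
//     // ... emit the loop body, using IV as the logical iteration number ...
//     return Error::success();
//   };
//   Expected<CanonicalLoopInfo *> CLI = OMPBuilder.createCanonicalLoop(
//       Loc, BodyGenCB, Start, Stop, Step, /*IsSigned=*/true,
//       /*InclusiveStop=*/false, /*ComputeIP=*/{}, "loop");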
5300
5301// Returns an LLVM function to call for initializing loop bounds using OpenMP
5302// static scheduling for composite `distribute parallel for` depending on
5303// `type`. Only i32 and i64 are supported by the runtime. Always interpret
5304// integers as unsigned similarly to CanonicalLoopInfo.
5305static FunctionCallee
5306getKmpcDistForStaticInitForType(Type *Ty, Module &M,
5307 OpenMPIRBuilder &OMPBuilder) {
5308 unsigned Bitwidth = Ty->getIntegerBitWidth();
5309 if (Bitwidth == 32)
5310 return OMPBuilder.getOrCreateRuntimeFunction(
5311 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
5312 if (Bitwidth == 64)
5313 return OMPBuilder.getOrCreateRuntimeFunction(
5314 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
5315 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5316}
5317
5318// Returns an LLVM function to call for initializing loop bounds using OpenMP
5319// static scheduling depending on `type`. Only i32 and i64 are supported by the
5320// runtime. Always interpret integers as unsigned similarly to
5321// CanonicalLoopInfo.
5322static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
5323 OpenMPIRBuilder &OMPBuilder) {
5324 unsigned Bitwidth = Ty->getIntegerBitWidth();
5325 if (Bitwidth == 32)
5326 return OMPBuilder.getOrCreateRuntimeFunction(
5327 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
5328 if (Bitwidth == 64)
5329 return OMPBuilder.getOrCreateRuntimeFunction(
5330 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
5331 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5332}
5333
5334OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
5335 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5336 WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
5337 OMPScheduleType DistScheduleSchedType) {
5338 assert(CLI->isValid() && "Requires a valid canonical loop");
5339 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
5340 "Require dedicated allocate IP");
5341
5342 // Set up the source location value for OpenMP runtime.
5343 Builder.restoreIP(IP: CLI->getPreheaderIP());
5344 Builder.SetCurrentDebugLocation(DL);
5345
5346 uint32_t SrcLocStrSize;
5347 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
5348 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5349
5350 // Declare useful OpenMP runtime functions.
5351 Value *IV = CLI->getIndVar();
5352 Type *IVTy = IV->getType();
5353 FunctionCallee StaticInit =
5354 LoopType == WorksharingLoopType::DistributeForStaticLoop
5355 ? getKmpcDistForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this)
5356 : getKmpcForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this);
5357 FunctionCallee StaticFini =
5358 getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);
5359
5360 // Allocate space for computed loop bounds as expected by the "init" function.
5361 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
5362
5363 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
5364 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
5365 Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
5366 Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
5367 Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
5368 CLI->setLastIter(PLastIter);
5369
5370 // At the end of the preheader, prepare for calling the "init" function by
5371 // storing the current loop bounds into the allocated space. A canonical loop
5372 // always iterates from 0 to trip-count with step 1. Note that "init" expects
5373 // and produces an inclusive upper bound.
5374 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
5375 Constant *Zero = ConstantInt::get(Ty: IVTy, V: 0);
5376 Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
5377 Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
5378 Value *UpperBound = Builder.CreateSub(LHS: CLI->getTripCount(), RHS: One);
5379 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
5380 Builder.CreateStore(Val: One, Ptr: PStride);
5381
5382 Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
5383
5384 OMPScheduleType SchedType =
5385 (LoopType == WorksharingLoopType::DistributeStaticLoop)
5386 ? OMPScheduleType::OrderedDistribute
5387 : OMPScheduleType::UnorderedStatic;
5388 Constant *SchedulingType =
5389 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
5390
5391 // Call the "init" function and update the trip count of the loop with the
5392 // value it produced.
5393 auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
5394 PUpperBound, IVTy, PStride, One, Zero, StaticInit,
5395 this](Value *SchedulingType, auto &Builder) {
5396 SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
5397 PLowerBound, PUpperBound});
5398 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
5399 Value *PDistUpperBound =
5400 Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
5401 Args.push_back(Elt: PDistUpperBound);
5402 }
5403 Args.append(IL: {PStride, One, Zero});
5404 createRuntimeFunctionCall(Callee: StaticInit, Args);
5405 };
5406 BuildInitCall(SchedulingType, Builder);
5407 if (HasDistSchedule &&
5408 LoopType != WorksharingLoopType::DistributeStaticLoop) {
5409 Constant *DistScheduleSchedType = ConstantInt::get(
5410 Ty: I32Type, V: static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
5411 // We want to emit a second init function call for the dist_schedule clause
5412 // of the Distribute construct. However, this should only be done if a
5413 // workshare loop is nested within a Distribute construct.
5414 BuildInitCall(DistScheduleSchedType, Builder);
5415 }
5416 Value *LowerBound = Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound);
5417 Value *InclusiveUpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound);
5418 Value *TripCountMinusOne = Builder.CreateSub(LHS: InclusiveUpperBound, RHS: LowerBound);
5419 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One);
5420 CLI->setTripCount(TripCount);
5421
5422 // Update all uses of the induction variable except the one in the condition
5423 // block that compares it with the actual upper bound, and the increment in
5424 // the latch block.
5425
5426 CLI->mapIndVar(Updater: [&](Instruction *OldIV) -> Value * {
5427 Builder.SetInsertPoint(TheBB: CLI->getBody(),
5428 IP: CLI->getBody()->getFirstInsertionPt());
5429 Builder.SetCurrentDebugLocation(DL);
5430 return Builder.CreateAdd(LHS: OldIV, RHS: LowerBound);
5431 });
5432
5433 // In the "exit" block, call the "fini" function.
5434 Builder.SetInsertPoint(TheBB: CLI->getExit(),
5435 IP: CLI->getExit()->getTerminator()->getIterator());
5436 createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});
5437
5438 // Add the barrier if requested.
5439 if (NeedsBarrier) {
5440 InsertPointOrErrorTy BarrierIP =
5441 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
5442 Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
5443 /* CheckCancelFlag */ false);
5444 if (!BarrierIP)
5445 return BarrierIP.takeError();
5446 }
5447
5448 InsertPointTy AfterIP = CLI->getAfterIP();
5449 CLI->invalidate();
5450
5451 return AfterIP;
5452}
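
// For a 32-bit IV, the code emitted above corresponds roughly to the
// following pseudocode (a sketch; `schedtype` is the SchedType value computed
// above, other names are illustrative):
//
//   lb = 0; ub = tripcount - 1; stride = 1;
//   __kmpc_for_static_init_4u(&loc, tid, schedtype, &lastiter,
//                             &lb, &ub, &stride, /*incr=*/1, /*chunk=*/0);
//   for (iv = 0; iv < ub - lb + 1; ++iv)   // trip count updated from bounds
//     body(iv + lb);                       // mapIndVar adds the lower bound
//   __kmpc_for_static_fini(&loc, tid);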
5453
5454static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
5455 LoopInfo &LI);
5456static void addLoopMetadata(CanonicalLoopInfo *Loop,
5457 ArrayRef<Metadata *> Properties);
5458
5459static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
5460 LLVMContext &Ctx, Loop *Loop,
5461 LoopInfo &LoopInfo,
5462 SmallVector<Metadata *> &LoopMDList) {
5463 SmallSet<BasicBlock *, 8> Reachable;
5464
5465 // Get the basic blocks from the loop in which memref instructions
5466 // can be found.
5467 // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
5468 // preferably without running any passes.
5469 for (BasicBlock *Block : Loop->getBlocks()) {
5470 if (Block == CLI->getCond() || Block == CLI->getHeader())
5471 continue;
5472 Reachable.insert(Ptr: Block);
5473 }
5474
5475 // Add access group metadata to memory-access instructions.
5476 MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
5477 for (BasicBlock *BB : Reachable)
5478 addAccessGroupMetadata(Block: BB, AccessGroup, LI&: LoopInfo);
5479 // TODO: If the loop has existing parallel access metadata, have
5480 // to combine two lists.
5481 LoopMDList.push_back(Elt: MDNode::get(
5482 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
5483}
5484
5485OpenMPIRBuilder::InsertPointOrErrorTy
5486OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
5487 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5488 bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
5489 Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
5490 assert(CLI->isValid() && "Requires a valid canonical loop");
5491 assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");
5492
5493 LLVMContext &Ctx = CLI->getFunction()->getContext();
5494 Value *IV = CLI->getIndVar();
5495 Value *OrigTripCount = CLI->getTripCount();
5496 Type *IVTy = IV->getType();
5497 assert(IVTy->getIntegerBitWidth() <= 64 &&
5498 "Max supported tripcount bitwidth is 64 bits");
5499 Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(C&: Ctx)
5500 : Type::getInt64Ty(C&: Ctx);
5501 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
5502 Constant *Zero = ConstantInt::get(Ty: InternalIVTy, V: 0);
5503 Constant *One = ConstantInt::get(Ty: InternalIVTy, V: 1);
5504
5505 Function *F = CLI->getFunction();
5506 FunctionAnalysisManager FAM;
5507 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5508 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5509 LoopAnalysis LIA;
5510 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5511 Loop *L = LI.getLoopFor(BB: CLI->getHeader());
5512 SmallVector<Metadata *> LoopMDList;
5513 if (ChunkSize || DistScheduleChunkSize)
5514 applyParallelAccessesMetadata(CLI, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
5515 addLoopMetadata(Loop: CLI, Properties: LoopMDList);
5516
5517 // Declare useful OpenMP runtime functions.
5518 FunctionCallee StaticInit =
5519 getKmpcForStaticInitForType(Ty: InternalIVTy, M, OMPBuilder&: *this);
5520 FunctionCallee StaticFini =
5521 getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);
5522
5523 // Allocate space for computed loop bounds as expected by the "init" function.
5524 Builder.restoreIP(IP: AllocaIP);
5525 Builder.SetCurrentDebugLocation(DL);
5526 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
5527 Value *PLowerBound =
5528 Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.lowerbound");
5529 Value *PUpperBound =
5530 Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.upperbound");
5531 Value *PStride = Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.stride");
5532 CLI->setLastIter(PLastIter);
5533
5534 // Set up the source location value for the OpenMP runtime.
5535 Builder.restoreIP(IP: CLI->getPreheaderIP());
5536 Builder.SetCurrentDebugLocation(DL);
5537
5538 // TODO: Detect overflow in ubsan or max-out with current tripcount.
5539 Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
5540 V: ChunkSize ? ChunkSize : Zero, DestTy: InternalIVTy, Name: "chunksize");
5541 Value *CastedDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
5542 V: DistScheduleChunkSize ? DistScheduleChunkSize : Zero, DestTy: InternalIVTy,
5543 Name: "distschedulechunksize");
5544 Value *CastedTripCount =
5545 Builder.CreateZExt(V: OrigTripCount, DestTy: InternalIVTy, Name: "tripcount");
5546
5547 Constant *SchedulingType =
5548 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
5549 Constant *DistSchedulingType =
5550 ConstantInt::get(Ty: I32Type, V: static_cast<int>(DistScheduleSchedType));
5551 Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
5552 Value *OrigUpperBound = Builder.CreateSub(LHS: CastedTripCount, RHS: One);
5553 Value *IsTripCountZero = Builder.CreateICmpEQ(LHS: CastedTripCount, RHS: Zero);
5554 Value *UpperBound =
5555 Builder.CreateSelect(C: IsTripCountZero, True: Zero, False: OrigUpperBound);
5556 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
5557 Builder.CreateStore(Val: One, Ptr: PStride);
5558
5559 // Call the "init" function and update the trip count of the loop with the
5560 // value it produced.
5561 uint32_t SrcLocStrSize;
5562 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
5563 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5564 Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
5565 auto BuildInitCall = [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound,
5566 PUpperBound, PStride, One,
5567 this](Value *SchedulingType, Value *ChunkSize,
5568 auto &Builder) {
5569 createRuntimeFunctionCall(
5570 Callee: StaticInit, Args: {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
5571 /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
5572 /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
5573 /*pstride=*/PStride, /*incr=*/One,
5574 /*chunk=*/ChunkSize});
5575 };
5576 BuildInitCall(SchedulingType, CastedChunkSize, Builder);
5577 if (DistScheduleSchedType != OMPScheduleType::None &&
5578 SchedType != OMPScheduleType::OrderedDistributeChunked &&
5579 SchedType != OMPScheduleType::OrderedDistribute) {
5580 // We want to emit a second init function call for the dist_schedule clause
5581 // of the Distribute construct. However, this should only be done if a
5582 // workshare loop is nested within a Distribute construct.
5583 BuildInitCall(DistSchedulingType, CastedDistScheduleChunkSize, Builder);
5584 }
5585
5586 // Load values written by the "init" function.
5587 Value *FirstChunkStart =
5588 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PLowerBound, Name: "omp_firstchunk.lb");
5589 Value *FirstChunkStop =
5590 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PUpperBound, Name: "omp_firstchunk.ub");
5591 Value *FirstChunkEnd = Builder.CreateAdd(LHS: FirstChunkStop, RHS: One);
5592 Value *ChunkRange =
5593 Builder.CreateSub(LHS: FirstChunkEnd, RHS: FirstChunkStart, Name: "omp_chunk.range");
5594 Value *NextChunkStride =
5595 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PStride, Name: "omp_dispatch.stride");
5596
5597 // Create outer "dispatch" loop for enumerating the chunks.
5598 BasicBlock *DispatchEnter = splitBB(Builder, CreateBranch: true);
5599 Value *DispatchCounter;
5600
5601 // It is safe to assume this didn't return an error because the callback
5602 // passed into createCanonicalLoop is the only possible error source, and it
5603 // always returns success.
5604 CanonicalLoopInfo *DispatchCLI = cantFail(ValOrErr: createCanonicalLoop(
5605 Loc: {Builder.saveIP(), DL},
5606 BodyGenCB: [&](InsertPointTy BodyIP, Value *Counter) {
5607 DispatchCounter = Counter;
5608 return Error::success();
5609 },
5610 Start: FirstChunkStart, Stop: CastedTripCount, Step: NextChunkStride,
5611 /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
5612 Name: "dispatch"));
5613
5614 // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
5615 // not have to preserve the canonical invariant.
5616 BasicBlock *DispatchBody = DispatchCLI->getBody();
5617 BasicBlock *DispatchLatch = DispatchCLI->getLatch();
5618 BasicBlock *DispatchExit = DispatchCLI->getExit();
5619 BasicBlock *DispatchAfter = DispatchCLI->getAfter();
5620 DispatchCLI->invalidate();
5621
5622 // Rewire the original loop to become the chunk loop inside the dispatch loop.
5623 redirectTo(Source: DispatchAfter, Target: CLI->getAfter(), DL);
5624 redirectTo(Source: CLI->getExit(), Target: DispatchLatch, DL);
5625 redirectTo(Source: DispatchBody, Target: DispatchEnter, DL);
5626
5627 // Prepare the prolog of the chunk loop.
5628 Builder.restoreIP(IP: CLI->getPreheaderIP());
5629 Builder.SetCurrentDebugLocation(DL);
5630
5631 // Compute the number of iterations of the chunk loop.
5632 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
5633 Value *ChunkEnd = Builder.CreateAdd(LHS: DispatchCounter, RHS: ChunkRange);
5634 Value *IsLastChunk =
5635 Builder.CreateICmpUGE(LHS: ChunkEnd, RHS: CastedTripCount, Name: "omp_chunk.is_last");
5636 Value *CountUntilOrigTripCount =
5637 Builder.CreateSub(LHS: CastedTripCount, RHS: DispatchCounter);
5638 Value *ChunkTripCount = Builder.CreateSelect(
5639 C: IsLastChunk, True: CountUntilOrigTripCount, False: ChunkRange, Name: "omp_chunk.tripcount");
5640 Value *BackcastedChunkTC =
5641 Builder.CreateTrunc(V: ChunkTripCount, DestTy: IVTy, Name: "omp_chunk.tripcount.trunc");
5642 CLI->setTripCount(BackcastedChunkTC);
5643
5644 // Update all uses of the induction variable except the one in the condition
5645 // block that compares it with the actual upper bound, and the increment in
5646 // the latch block.
5647 Value *BackcastedDispatchCounter =
5648 Builder.CreateTrunc(V: DispatchCounter, DestTy: IVTy, Name: "omp_dispatch.iv.trunc");
5649 CLI->mapIndVar(Updater: [&](Instruction *) -> Value * {
5650 Builder.restoreIP(IP: CLI->getBodyIP());
5651 return Builder.CreateAdd(LHS: IV, RHS: BackcastedDispatchCounter);
5652 });
5653
5654 // In the "exit" block, call the "fini" function.
5655 Builder.SetInsertPoint(TheBB: DispatchExit, IP: DispatchExit->getFirstInsertionPt());
5656 createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});
5657
5658 // Add the barrier if requested.
5659 if (NeedsBarrier) {
5660 InsertPointOrErrorTy AfterIP =
5661 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL), Kind: OMPD_for,
5662 /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
5663 if (!AfterIP)
5664 return AfterIP.takeError();
5665 }
5666
5667#ifndef NDEBUG
5668 // Even though we currently do not support applying additional methods to it,
5669 // the chunk loop should remain a canonical loop.
5670 CLI->assertOK();
5671#endif
5672
5673 return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
5674}
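
// The control structure built above is, conceptually (C-like pseudocode
// sketch; names are illustrative):
//
//   __kmpc_for_static_init(...);                 // yields first chunk [lb, ub]
//   range = ub + 1 - lb; stride = *p.stride;
//   for (cnt = lb; cnt < tripcount; cnt += stride)          // dispatch loop
//     for (iv = 0; iv < min(range, tripcount - cnt); ++iv)  // chunk loop
//       body(cnt + iv);
//   __kmpc_for_static_fini(...);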
5675
5676// Returns an LLVM function to call for executing an OpenMP static worksharing
5677// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
5678// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
5679static FunctionCallee
5680getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
5681 WorksharingLoopType LoopType) {
5682 unsigned Bitwidth = Ty->getIntegerBitWidth();
5683 Module &M = OMPBuilder->M;
5684 switch (LoopType) {
5685 case WorksharingLoopType::ForStaticLoop:
5686 if (Bitwidth == 32)
5687 return OMPBuilder->getOrCreateRuntimeFunction(
5688 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
5689 if (Bitwidth == 64)
5690 return OMPBuilder->getOrCreateRuntimeFunction(
5691 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
5692 break;
5693 case WorksharingLoopType::DistributeStaticLoop:
5694 if (Bitwidth == 32)
5695 return OMPBuilder->getOrCreateRuntimeFunction(
5696 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
5697 if (Bitwidth == 64)
5698 return OMPBuilder->getOrCreateRuntimeFunction(
5699 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
5700 break;
5701 case WorksharingLoopType::DistributeForStaticLoop:
5702 if (Bitwidth == 32)
5703 return OMPBuilder->getOrCreateRuntimeFunction(
5704 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
5705 if (Bitwidth == 64)
5706 return OMPBuilder->getOrCreateRuntimeFunction(
5707 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
5708 break;
5709 }
5710 if (Bitwidth != 32 && Bitwidth != 64) {
5711 llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
5712 }
5713 llvm_unreachable("Unknown type of OpenMP worksharing loop");
5714}
5715
5716// Inserts a call to the proper OpenMP device RTL function which handles
5717// loop worksharing.
5718static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
5719 WorksharingLoopType LoopType,
5720 BasicBlock *InsertBlock, Value *Ident,
5721 Value *LoopBodyArg, Value *TripCount,
5722 Function &LoopBodyFn, bool NoLoop) {
5723 Type *TripCountTy = TripCount->getType();
5724 Module &M = OMPBuilder->M;
5725 IRBuilder<> &Builder = OMPBuilder->Builder;
5726 FunctionCallee RTLFn =
5727 getKmpcForStaticLoopForType(Ty: TripCountTy, OMPBuilder, LoopType);
5728 SmallVector<Value *, 8> RealArgs;
5729 RealArgs.push_back(Elt: Ident);
5730 RealArgs.push_back(Elt: &LoopBodyFn);
5731 RealArgs.push_back(Elt: LoopBodyArg);
5732 RealArgs.push_back(Elt: TripCount);
5733 if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
5734 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
5735 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
5736 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
5737 OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
5738 return;
5739 }
5740 FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
5741 M, FnID: omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
5742 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
5743 Value *NumThreads = OMPBuilder->createRuntimeFunctionCall(Callee: RTLNumThreads, Args: {});
5744
5745 RealArgs.push_back(
5746 Elt: Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: TripCountTy, Name: "num.threads.cast"));
5747 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
5748 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
5749 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
5750 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: NoLoop));
5751 } else {
5752 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
5753 }
5754
5755 OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
5756}
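
// For the plain ForStaticLoop case with a 32-bit trip count, the emitted IR
// looks roughly like (a sketch; value names are illustrative):
//
//   %nt = call i32 @omp_get_num_threads()
//   %nt.cast = ... ; zext/trunc of %nt to the trip count type
//   call void @__kmpc_for_static_loop_4u(ptr %ident, ptr @outlined_body,
//       ptr %body_arg, i32 %tripcount, i32 %nt.cast, i32 0, i8 0)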
5757
5758static void workshareLoopTargetCallback(
5759 OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
5760 Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
5761 WorksharingLoopType LoopType, bool NoLoop) {
5762 IRBuilder<> &Builder = OMPIRBuilder->Builder;
5763 BasicBlock *Preheader = CLI->getPreheader();
5764 Value *TripCount = CLI->getTripCount();
5765
5766 // After loop body outlining, the loop body contains only the setup of the
5767 // loop body argument structure and the call to the outlined
5768 // loop body function. First, we need to move the setup of the loop body
5769 // arguments into the loop preheader.
5770 Preheader->splice(ToIt: std::prev(x: Preheader->end()), FromBB: CLI->getBody(),
5771 FromBeginIt: CLI->getBody()->begin(), FromEndIt: std::prev(x: CLI->getBody()->end()));
5772
5773 // The next step is to remove the whole loop. We do not need it anymore.
5774 // That's why we make an unconditional branch from the loop preheader to the
5775 // loop exit block.
5776 Builder.restoreIP(IP: {Preheader, Preheader->end()});
5777 Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
5778 Preheader->getTerminator()->eraseFromParent();
5779 Builder.CreateBr(Dest: CLI->getExit());
5780
5781 // Delete dead loop blocks
5782 OpenMPIRBuilder::OutlineInfo CleanUpInfo;
5783 SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
5784 SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
5785 CleanUpInfo.EntryBB = CLI->getHeader();
5786 CleanUpInfo.ExitBB = CLI->getExit();
5787 CleanUpInfo.collectBlocks(BlockSet&: RegionBlockSet, BlockVector&: BlocksToBeRemoved);
5788 DeleteDeadBlocks(BBs: BlocksToBeRemoved);
5789
5790 // Find the instruction which corresponds to the loop body argument structure
5791 // and remove the call instruction to the outlined loop body function.
5792 Value *LoopBodyArg;
5793 User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
5794 assert(OutlinedFnUser &&
5795 "Expected unique undroppable user of outlined function");
5796 CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(Val: OutlinedFnUser);
5797 assert(OutlinedFnCallInstruction && "Expected outlined function call");
5798 assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
5799 "Expected outlined function call to be located in loop preheader");
5800 // Check in case no argument structure has been passed.
5801 if (OutlinedFnCallInstruction->arg_size() > 1)
5802 LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(i: 1);
5803 else
5804 LoopBodyArg = Constant::getNullValue(Ty: Builder.getPtrTy());
5805 OutlinedFnCallInstruction->eraseFromParent();
5806
5807 createTargetLoopWorkshareCall(OMPBuilder: OMPIRBuilder, LoopType, InsertBlock: Preheader, Ident,
5808 LoopBodyArg, TripCount, LoopBodyFn&: OutlinedFn, NoLoop);
5809
5810 for (auto &ToBeDeletedItem : ToBeDeleted)
5811 ToBeDeletedItem->eraseFromParent();
5812 CLI->invalidate();
5813}
5814
5815OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
5816 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5817 WorksharingLoopType LoopType, bool NoLoop) {
5818 uint32_t SrcLocStrSize;
5819 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
5820 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5821
5822 OutlineInfo OI;
5823 OI.OuterAllocaBB = CLI->getPreheader();
5824 Function *OuterFn = CLI->getPreheader()->getParent();
5825
5826 // Instructions which need to be deleted at the end of code generation
5827 SmallVector<Instruction *, 4> ToBeDeleted;
5828
5829 OI.OuterAllocaBB = AllocaIP.getBlock();
5830
5831 // Mark the loop body as the region which needs to be extracted.
5832 OI.EntryBB = CLI->getBody();
5833 OI.ExitBB = CLI->getLatch()->splitBasicBlockBefore(I: CLI->getLatch()->begin(),
5834 BBName: "omp.prelatch");
5835
5836 // Prepare loop body for extraction
5837 Builder.restoreIP(IP: {CLI->getPreheader(), CLI->getPreheader()->begin()});
5838
5839 // Insert a new loop counter variable which will be used only in the loop
5840 // body.
5841 AllocaInst *NewLoopCnt = Builder.CreateAlloca(Ty: CLI->getIndVarType(), ArraySize: 0, Name: "");
5842 Instruction *NewLoopCntLoad =
5843 Builder.CreateLoad(Ty: CLI->getIndVarType(), Ptr: NewLoopCnt);
5844 // The new loop counter instructions are redundant in the loop preheader once
5845 // code generation for the workshare loop is finished. That's why we mark them
5846 // as ready for deletion.
5847 ToBeDeleted.push_back(Elt: NewLoopCntLoad);
5848 ToBeDeleted.push_back(Elt: NewLoopCnt);
5849
5850 // Analyse the loop body region: find all input variables which are used
5851 // inside the loop body region.
5852 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
5853 SmallVector<BasicBlock *, 32> Blocks;
5854 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
5855
5856 CodeExtractorAnalysisCache CEAC(*OuterFn);
5857 CodeExtractor Extractor(Blocks,
5858 /* DominatorTree */ nullptr,
5859 /* AggregateArgs */ true,
5860 /* BlockFrequencyInfo */ nullptr,
5861 /* BranchProbabilityInfo */ nullptr,
5862 /* AssumptionCache */ nullptr,
5863 /* AllowVarArgs */ true,
5864 /* AllowAlloca */ true,
5865 /* AllocationBlock */ CLI->getPreheader(),
5866 /* Suffix */ ".omp_wsloop",
5867 /* AggrArgsIn0AddrSpace */ true);
5868
5869 BasicBlock *CommonExit = nullptr;
5870 SetVector<Value *> SinkingCands, HoistingCands;
5871
5872 // Find allocas outside the loop body region which are used inside the loop
5873 // body.
5874 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
5875
5876 // We need to model the loop body region as the function f(cnt, loop_arg).
5877 // That's why we replace the loop induction variable with the new counter,
5878 // which will be one of the loop body function's arguments.
5879 SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
5880 CLI->getIndVar()->user_end());
5881 for (auto Use : Users) {
5882 if (Instruction *Inst = dyn_cast<Instruction>(Val: Use)) {
5883 if (ParallelRegionBlockSet.count(Ptr: Inst->getParent())) {
5884 Inst->replaceUsesOfWith(From: CLI->getIndVar(), To: NewLoopCntLoad);
5885 }
5886 }
5887 }
5888 // Make sure that the loop counter variable is not merged into the loop body
5889 // function's argument structure and that it is passed as a separate variable.
5890 OI.ExcludeArgsFromAggregate.push_back(Elt: NewLoopCntLoad);
5891
5892 // The PostOutline callback is invoked when the loop body function has been
5893 // outlined and the loop body has been replaced by a call to the outlined
5894 // function. We need to add a call to the OpenMP device RTL inside the loop
5895 // preheader. The OpenMP device RTL function will handle the loop control
5896 // logic.
5897 OI.PostOutlineCB = [=, ToBeDeletedVec =
5898 std::move(ToBeDeleted)](Function &OutlinedFn) {
5899 workshareLoopTargetCallback(OMPIRBuilder: this, CLI, Ident, OutlinedFn, ToBeDeleted: ToBeDeletedVec,
5900 LoopType, NoLoop);
5901 };
5902 addOutlineInfo(OI: std::move(OI));
5903 return CLI->getAfterIP();
5904}
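
// Net effect on the device: the canonical loop
//   for (iv = 0; iv < tripcount; ++iv) body(iv);
// is replaced by a single device RTL call
//   __kmpc_*_loop(ident, outlined_body, body_arg, tripcount, ...)
// where the runtime distributes the iterations among threads/teams and invokes
// outlined_body(iv, body_arg) for each assigned iteration (conceptual sketch).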
5905
5906OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
5907 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5908 bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
5909 bool HasSimdModifier, bool HasMonotonicModifier,
5910 bool HasNonmonotonicModifier, bool HasOrderedClause,
5911 WorksharingLoopType LoopType, bool NoLoop, bool HasDistSchedule,
5912 Value *DistScheduleChunkSize) {
5913 if (Config.isTargetDevice())
5914 return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop);
5915 OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
5916 ClauseKind: SchedKind, HasChunks: ChunkSize, HasSimdModifier, HasMonotonicModifier,
5917 HasNonmonotonicModifier, HasOrderedClause, HasDistScheduleChunks: DistScheduleChunkSize);
5918
5919 bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
5920 OMPScheduleType::ModifierOrdered;
5921 OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
5922 if (HasDistSchedule) {
5923 DistScheduleSchedType = DistScheduleChunkSize
5924 ? OMPScheduleType::OrderedDistributeChunked
5925 : OMPScheduleType::OrderedDistribute;
5926 }
5927 switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
5928 case OMPScheduleType::BaseStatic:
5929 case OMPScheduleType::BaseDistribute:
5930 assert((!ChunkSize || !DistScheduleChunkSize) &&
5931 "No chunk size with static-chunked schedule");
5932 if (IsOrdered && !HasDistSchedule)
5933 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
5934 NeedsBarrier, Chunk: ChunkSize);
5935 // FIXME: Monotonicity ignored?
5936 if (DistScheduleChunkSize)
5937 return applyStaticChunkedWorkshareLoop(
5938 DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
5939 DistScheduleChunkSize, DistScheduleSchedType);
5940 return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
5941 HasDistSchedule);
5942
5943 case OMPScheduleType::BaseStaticChunked:
5944 case OMPScheduleType::BaseDistributeChunked:
5945 if (IsOrdered && !HasDistSchedule)
5946 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
5947 NeedsBarrier, Chunk: ChunkSize);
5948 // FIXME: Monotonicity ignored?
5949 return applyStaticChunkedWorkshareLoop(
5950 DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
5951 DistScheduleChunkSize, DistScheduleSchedType);
5952
5953 case OMPScheduleType::BaseRuntime:
5954 case OMPScheduleType::BaseAuto:
5955 case OMPScheduleType::BaseGreedy:
5956 case OMPScheduleType::BaseBalanced:
5957 case OMPScheduleType::BaseSteal:
5958 case OMPScheduleType::BaseRuntimeSimd:
5959 assert(!ChunkSize &&
5960 "schedule type does not support user-defined chunk sizes");
5961 [[fallthrough]];
5962 case OMPScheduleType::BaseGuidedSimd:
5963 case OMPScheduleType::BaseDynamicChunked:
5964 case OMPScheduleType::BaseGuidedChunked:
5965 case OMPScheduleType::BaseGuidedIterativeChunked:
5966 case OMPScheduleType::BaseGuidedAnalyticalChunked:
5967 case OMPScheduleType::BaseStaticBalancedChunked:
5968 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
5969 NeedsBarrier, Chunk: ChunkSize);
5970
5971 default:
5972 llvm_unreachable("Unknown/unimplemented schedule kind");
5973 }
5974}
5975
5976/// Returns an LLVM function to call for initializing loop bounds using OpenMP
5977/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
5978/// the runtime. Always interpret integers as unsigned similarly to
5979/// CanonicalLoopInfo.
5980static FunctionCallee
5981getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
5982 unsigned Bitwidth = Ty->getIntegerBitWidth();
5983 if (Bitwidth == 32)
5984 return OMPBuilder.getOrCreateRuntimeFunction(
5985 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
5986 if (Bitwidth == 64)
5987 return OMPBuilder.getOrCreateRuntimeFunction(
5988 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
5989 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5990}
5991
5992/// Returns an LLVM function to call for updating the next loop using OpenMP
5993/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
5994/// the runtime. Always interpret integers as unsigned similarly to
5995/// CanonicalLoopInfo.
5996static FunctionCallee
5997getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
5998 unsigned Bitwidth = Ty->getIntegerBitWidth();
5999 if (Bitwidth == 32)
6000 return OMPBuilder.getOrCreateRuntimeFunction(
6001 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
6002 if (Bitwidth == 64)
6003 return OMPBuilder.getOrCreateRuntimeFunction(
6004 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
6005 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6006}
6007
6008/// Returns an LLVM function to call for finalizing the dynamic loop,
6009/// depending on `type`. Only i32 and i64 are supported by the runtime. Always
6010/// interpret integers as unsigned similarly to CanonicalLoopInfo.
6011static FunctionCallee
6012getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6013 unsigned Bitwidth = Ty->getIntegerBitWidth();
6014 if (Bitwidth == 32)
6015 return OMPBuilder.getOrCreateRuntimeFunction(
6016 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
6017 if (Bitwidth == 64)
6018 return OMPBuilder.getOrCreateRuntimeFunction(
6019 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
6020 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6021}
6022
6023OpenMPIRBuilder::InsertPointOrErrorTy
6024OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
6025 InsertPointTy AllocaIP,
6026 OMPScheduleType SchedType,
6027 bool NeedsBarrier, Value *Chunk) {
6028 assert(CLI->isValid() && "Requires a valid canonical loop");
6029 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
6030 "Require dedicated allocate IP");
6031 assert(isValidWorkshareLoopScheduleType(SchedType) &&
6032 "Require valid schedule type");
6033
6034 bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
6035 OMPScheduleType::ModifierOrdered;
6036
6037 // Set up the source location value for OpenMP runtime.
6038 Builder.SetCurrentDebugLocation(DL);
6039
6040 uint32_t SrcLocStrSize;
6041 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
6042 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6043
6044 // Declare useful OpenMP runtime functions.
6045 Value *IV = CLI->getIndVar();
6046 Type *IVTy = IV->getType();
6047 FunctionCallee DynamicInit = getKmpcForDynamicInitForType(Ty: IVTy, M, OMPBuilder&: *this);
6048 FunctionCallee DynamicNext = getKmpcForDynamicNextForType(Ty: IVTy, M, OMPBuilder&: *this);
6049
6050 // Allocate space for computed loop bounds as expected by the "init" function.
6051 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
6052 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
6053 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
6054 Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
6055 Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
6056 Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
6057 CLI->setLastIter(PLastIter);
6058
6059 // At the end of the preheader, prepare for calling the "init" function by
6060 // storing the current loop bounds into the allocated space. A canonical loop
6061 // always iterates from 0 to trip-count with step 1. Note that "init" expects
6062 // and produces an inclusive upper bound.
6063 BasicBlock *PreHeader = CLI->getPreheader();
6064 Builder.SetInsertPoint(PreHeader->getTerminator());
6065 Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
6066 Builder.CreateStore(Val: One, Ptr: PLowerBound);
6067 Value *UpperBound = CLI->getTripCount();
6068 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
6069 Builder.CreateStore(Val: One, Ptr: PStride);
6070
6071 BasicBlock *Header = CLI->getHeader();
6072 BasicBlock *Exit = CLI->getExit();
6073 BasicBlock *Cond = CLI->getCond();
6074 BasicBlock *Latch = CLI->getLatch();
6075 InsertPointTy AfterIP = CLI->getAfterIP();
6076
6077 // The CLI will be "broken" in the code below, as the loop is no longer
6078 // a valid canonical loop.
6079
6080 if (!Chunk)
6081 Chunk = One;
6082
6083 Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
6084
6085 Constant *SchedulingType =
6086 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
6087
6088 // Call the "init" function.
6089 createRuntimeFunctionCall(Callee: DynamicInit, Args: {SrcLoc, ThreadNum, SchedulingType,
6090 /* LowerBound */ One, UpperBound,
6091 /* step */ One, Chunk});
6092
6093 // An outer loop around the existing one.
6094 BasicBlock *OuterCond = BasicBlock::Create(
6095 Context&: PreHeader->getContext(), Name: Twine(PreHeader->getName()) + ".outer.cond",
6096 Parent: PreHeader->getParent());
6097 // This needs to be 32-bit always, so we can't reuse the IVTy constants above.
6098 Builder.SetInsertPoint(TheBB: OuterCond, IP: OuterCond->getFirstInsertionPt());
6099 Value *Res = createRuntimeFunctionCall(
6100 Callee: DynamicNext,
6101 Args: {SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound, PStride});
6102 Constant *Zero32 = ConstantInt::get(Ty: I32Type, V: 0);
6103 Value *MoreWork = Builder.CreateCmp(Pred: CmpInst::ICMP_NE, LHS: Res, RHS: Zero32);
6104 Value *LowerBound =
6105 Builder.CreateSub(LHS: Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound), RHS: One, Name: "lb");
6106 Builder.CreateCondBr(Cond: MoreWork, True: Header, False: Exit);
6107
6108 // Change PHI-node in loop header to use outer cond rather than preheader,
6109 // and set IV to the LowerBound.
6110 Instruction *Phi = &Header->front();
6111 auto *PI = cast<PHINode>(Val: Phi);
6112 PI->setIncomingBlock(i: 0, BB: OuterCond);
6113 PI->setIncomingValue(i: 0, V: LowerBound);
6114
6115 // Then set the pre-header to jump to the OuterCond
6116 Instruction *Term = PreHeader->getTerminator();
6117 auto *Br = cast<BranchInst>(Val: Term);
6118 Br->setSuccessor(idx: 0, NewSucc: OuterCond);
6119
6120 // Modify the inner condition:
6121 // * Use the UpperBound returned from the DynamicNext call.
6122 // * Jump to the outer loop when done with one of the inner loops.
6123 Builder.SetInsertPoint(TheBB: Cond, IP: Cond->getFirstInsertionPt());
6124 UpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound, Name: "ub");
6125 Instruction *Comp = &*Builder.GetInsertPoint();
6126 auto *CI = cast<CmpInst>(Val: Comp);
6127 CI->setOperand(i_nocapture: 1, Val_nocapture: UpperBound);
6128 // Redirect the inner exit to branch to outer condition.
6129 Instruction *Branch = &Cond->back();
6130 auto *BI = cast<BranchInst>(Val: Branch);
6131 assert(BI->getSuccessor(1) == Exit);
6132 BI->setSuccessor(idx: 1, NewSucc: OuterCond);
6133
6134 // Call the "fini" function if "ordered" is present in wsloop directive.
6135 if (Ordered) {
6136 Builder.SetInsertPoint(&Latch->back());
6137 FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(Ty: IVTy, M, OMPBuilder&: *this);
6138 createRuntimeFunctionCall(Callee: DynamicFini, Args: {SrcLoc, ThreadNum});
6139 }
6140
6141 // Add the barrier if requested.
6142 if (NeedsBarrier) {
6143 Builder.SetInsertPoint(&Exit->back());
6144 InsertPointOrErrorTy BarrierIP =
6145 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
6146 Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
6147 /* CheckCancelFlag */ false);
6148 if (!BarrierIP)
6149 return BarrierIP.takeError();
6150 }
6151
6152 CLI->invalidate();
6153 return AfterIP;
6154}
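
// The structure emitted above corresponds roughly to (a sketch; the dispatch
// runtime works with 1-based inclusive bounds while the canonical IV stays
// 0-based, hence the lb - 1 adjustment):
//
//   __kmpc_dispatch_init_4u(&loc, tid, schedtype, /*lb=*/1, /*ub=*/tripcount,
//                           /*st=*/1, chunk);
//   while (__kmpc_dispatch_next_4u(&loc, tid, &last, &lb, &ub, &st)) {
//     for (iv = lb - 1; iv < ub; ++iv)
//       body(iv);
//   }
//   // With the "ordered" modifier, __kmpc_dispatch_fini_4u(&loc, tid) is
//   // additionally called from the latch, once per iteration.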
6155
6156/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
6157/// after this \p OldTarget will be orphaned.
6158static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
6159 BasicBlock *NewTarget, DebugLoc DL) {
6160 for (BasicBlock *Pred : make_early_inc_range(Range: predecessors(BB: OldTarget)))
6161 redirectTo(Source: Pred, Target: NewTarget, DL);
6162}
6163
6164/// Determine which blocks in \p BBs are reachable from outside and remove the
6165/// ones that are not reachable from the function.
6166static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
6167 SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
6168 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
6169 for (Use &U : BB->uses()) {
6170 auto *UseInst = dyn_cast<Instruction>(Val: U.getUser());
6171 if (!UseInst)
6172 continue;
6173 if (BBsToErase.count(Ptr: UseInst->getParent()))
6174 continue;
6175 return true;
6176 }
6177 return false;
6178 };
6179
6180 while (BBsToErase.remove_if(P: HasRemainingUses)) {
6181 // Try again if anything was removed.
6182 }
6183
6184 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
6185 DeleteDeadBlocks(BBs: BBVec);
6186}
6187
6188CanonicalLoopInfo *
6189OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6190 InsertPointTy ComputeIP) {
6191 assert(Loops.size() >= 1 && "At least one loop required");
6192 size_t NumLoops = Loops.size();
6193
6194 // Nothing to do if there is already just one loop.
6195 if (NumLoops == 1)
6196 return Loops.front();
6197
6198 CanonicalLoopInfo *Outermost = Loops.front();
6199 CanonicalLoopInfo *Innermost = Loops.back();
6200 BasicBlock *OrigPreheader = Outermost->getPreheader();
6201 BasicBlock *OrigAfter = Outermost->getAfter();
6202 Function *F = OrigPreheader->getParent();
6203
6204 // Loop control blocks that may become orphaned later.
6205 SmallVector<BasicBlock *, 12> OldControlBBs;
6206 OldControlBBs.reserve(N: 6 * Loops.size());
6207 for (CanonicalLoopInfo *Loop : Loops)
6208 Loop->collectControlBlocks(BBs&: OldControlBBs);
6209
6210 // Setup the IRBuilder for inserting the trip count computation.
6211 Builder.SetCurrentDebugLocation(DL);
6212 if (ComputeIP.isSet())
6213 Builder.restoreIP(IP: ComputeIP);
6214 else
6215 Builder.restoreIP(IP: Outermost->getPreheaderIP());
6216
6217 // Derive the collapsed loop's trip count.
6218 // TODO: Find common/largest indvar type.
6219 Value *CollapsedTripCount = nullptr;
6220 for (CanonicalLoopInfo *L : Loops) {
6221 assert(L->isValid() &&
6222 "All loops to collapse must be valid canonical loops");
6223 Value *OrigTripCount = L->getTripCount();
6224 if (!CollapsedTripCount) {
6225 CollapsedTripCount = OrigTripCount;
6226 continue;
6227 }
6228
6229 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
6230 CollapsedTripCount =
6231 Builder.CreateNUWMul(LHS: CollapsedTripCount, RHS: OrigTripCount);
6232 }
6233
6234 // Create the collapsed loop control flow.
6235 CanonicalLoopInfo *Result =
6236 createLoopSkeleton(DL, TripCount: CollapsedTripCount, F,
6237 PreInsertBefore: OrigPreheader->getNextNode(), PostInsertBefore: OrigAfter, Name: "collapsed");
6238
6239 // Build the collapsed loop body code.
6240 // Start with deriving the input loop induction variables from the collapsed
6241 // one, using a divmod scheme. To preserve the original loops' order, the
6242 // innermost loop uses the least significant bits.
6243 Builder.restoreIP(IP: Result->getBodyIP());
6244
6245 Value *Leftover = Result->getIndVar();
6246 SmallVector<Value *> NewIndVars;
6247 NewIndVars.resize(N: NumLoops);
6248 for (int i = NumLoops - 1; i >= 1; --i) {
6249 Value *OrigTripCount = Loops[i]->getTripCount();
6250
6251 Value *NewIndVar = Builder.CreateURem(LHS: Leftover, RHS: OrigTripCount);
6252 NewIndVars[i] = NewIndVar;
6253
6254 Leftover = Builder.CreateUDiv(LHS: Leftover, RHS: OrigTripCount);
6255 }
6256 // Outermost loop gets all the remaining bits.
6257 NewIndVars[0] = Leftover;
6258
6259 // Construct the loop body control flow.
6260 // We progressively construct the branch structure following the direction of
6261 // the control flow: the leading in-between code, then the loop nest body, then
6262 // the trailing in-between code, finally rejoining the collapsed loop's latch.
6263 // ContinueBlock and ContinuePred keep track of the source(s) of the next edge. If
6264 // the ContinueBlock is set, continue with that block. If ContinuePred, use
6265 // its predecessors as sources.
6266 BasicBlock *ContinueBlock = Result->getBody();
6267 BasicBlock *ContinuePred = nullptr;
6268 auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
6269 BasicBlock *NextSrc) {
6270 if (ContinueBlock)
6271 redirectTo(Source: ContinueBlock, Target: Dest, DL);
6272 else
6273 redirectAllPredecessorsTo(OldTarget: ContinuePred, NewTarget: Dest, DL);
6274
6275 ContinueBlock = nullptr;
6276 ContinuePred = NextSrc;
6277 };
6278
6279 // The code before the nested loop of each level.
6280 // Because we are sinking it into the nest, it will be executed more often
6281 // than in the original loop. More sophisticated schemes could keep track of what
6282 // the in-between code is and instantiate it only once per thread.
6283 for (size_t i = 0; i < NumLoops - 1; ++i)
6284 ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
6285
6286 // Connect the loop nest body.
6287 ContinueWith(Innermost->getBody(), Innermost->getLatch());
6288
6289 // The code after the nested loop at each level.
6290 for (size_t i = NumLoops - 1; i > 0; --i)
6291 ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
6292
6293 // Connect the finished loop to the collapsed loop latch.
6294 ContinueWith(Result->getLatch(), nullptr);
6295
6296 // Replace the input loops with the new collapsed loop.
6297 redirectTo(Source: Outermost->getPreheader(), Target: Result->getPreheader(), DL);
6298 redirectTo(Source: Result->getAfter(), Target: Outermost->getAfter(), DL);
6299
6300 // Replace the input loop indvars with the derived ones.
6301 for (size_t i = 0; i < NumLoops; ++i)
6302 Loops[i]->getIndVar()->replaceAllUsesWith(V: NewIndVars[i]);
6303
6304 // Remove unused parts of the input loops.
6305 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6306
6307 for (CanonicalLoopInfo *L : Loops)
6308 L->invalidate();
6309
6310#ifndef NDEBUG
6311 Result->assertOK();
6312#endif
6313 return Result;
6314}
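
// For example, collapsing a two-loop nest with trip counts T0 (outer) and T1
// (inner) yields one loop with trip count T0 * T1 whose induction variable iv
// is decomposed by the divmod scheme above as
//   i1 = iv % T1;   // innermost, least significant part
//   i0 = iv / T1;   // outermost, remaining part
// which enumerates the pairs (i0, i1) in the original lexicographic order.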
6315
6316std::vector<CanonicalLoopInfo *>
6317OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6318 ArrayRef<Value *> TileSizes) {
6319 assert(TileSizes.size() == Loops.size() &&
6320 "Must pass as many tile sizes as there are loops");
6321 int NumLoops = Loops.size();
6322 assert(NumLoops >= 1 && "At least one loop to tile required");
6323
6324 CanonicalLoopInfo *OutermostLoop = Loops.front();
6325 CanonicalLoopInfo *InnermostLoop = Loops.back();
6326 Function *F = OutermostLoop->getBody()->getParent();
6327 BasicBlock *InnerEnter = InnermostLoop->getBody();
6328 BasicBlock *InnerLatch = InnermostLoop->getLatch();
6329
6330 // Loop control blocks that may become orphaned later.
6331 SmallVector<BasicBlock *, 12> OldControlBBs;
6332 OldControlBBs.reserve(N: 6 * Loops.size());
6333 for (CanonicalLoopInfo *Loop : Loops)
6334 Loop->collectControlBlocks(BBs&: OldControlBBs);
6335
6336 // Collect original trip counts and induction variable to be accessible by
6337 // index. Also, the structure of the original loops is not preserved during
6338 // the construction of the tiled loops, so do it before we scavenge the BBs of
6339 // any original CanonicalLoopInfo.
6340 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
6341 for (CanonicalLoopInfo *L : Loops) {
6342 assert(L->isValid() && "All input loops must be valid canonical loops");
6343 OrigTripCounts.push_back(Elt: L->getTripCount());
6344 OrigIndVars.push_back(Elt: L->getIndVar());
6345 }
6346
6347 // Collect the code between loop headers. These may contain SSA definitions
6348 // that are used in the loop nest body. To be usable within the innermost
6349 // body, these BasicBlocks will be sunk into the loop nest body. That is,
6350 // these instructions may be executed more often than before the tiling.
6351 // TODO: It would be sufficient to only sink them into body of the
6352 // corresponding tile loop.
6353 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
6354 for (int i = 0; i < NumLoops - 1; ++i) {
6355 CanonicalLoopInfo *Surrounding = Loops[i];
6356 CanonicalLoopInfo *Nested = Loops[i + 1];
6357
6358 BasicBlock *EnterBB = Surrounding->getBody();
6359 BasicBlock *ExitBB = Nested->getHeader();
6360 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
6361 }
6362
6363 // Compute the trip counts of the floor loops.
6364 Builder.SetCurrentDebugLocation(DL);
6365 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
6366 SmallVector<Value *, 4> FloorCompleteCount, FloorCount, FloorRems;
6367 for (int i = 0; i < NumLoops; ++i) {
6368 Value *TileSize = TileSizes[i];
6369 Value *OrigTripCount = OrigTripCounts[i];
6370 Type *IVType = OrigTripCount->getType();
6371
6372 Value *FloorCompleteTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
6373 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
6374
6375 // 0 if the tilesize divides the tripcount, 1 otherwise.
6376 // 1 means we need an additional iteration for a partial tile.
6377 //
6378 // Unfortunately we cannot just use the roundup formula
6379 // (tripcount + tilesize - 1) / tilesize
6380 // because the summation might overflow. We do not want to introduce undefined
6381 // behavior when the untiled loop nest did not.
6382 Value *FloorTripOverflow =
6383 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
6384
6385 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
6386 Value *FloorTripCount =
6387 Builder.CreateAdd(LHS: FloorCompleteTripCount, RHS: FloorTripOverflow,
6388 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
6389
6390 // Remember some values for later use.
6391 FloorCompleteCount.push_back(Elt: FloorCompleteTripCount);
6392 FloorCount.push_back(Elt: FloorTripCount);
6393 FloorRems.push_back(Elt: FloorTripRem);
6394 }
6395
6396 // Generate the new loop nest, from the outermost to the innermost.
6397 std::vector<CanonicalLoopInfo *> Result;
6398 Result.reserve(n: NumLoops * 2);
6399
6400 // The basic block of the surrounding loop that enters the generated loop
6401 // nest.
6402 BasicBlock *Enter = OutermostLoop->getPreheader();
6403
6404 // The basic block of the surrounding loop where the inner code should
6405 // continue.
6406 BasicBlock *Continue = OutermostLoop->getAfter();
6407
6408 // Where the next loop basic block should be inserted.
6409 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
6410
6411 auto EmbeddNewLoop =
6412 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
6413 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
6414 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
6415 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
6416 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
6417 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
6418
6419 // Setup the position where the next embedded loop connects to this loop.
6420 Enter = EmbeddedLoop->getBody();
6421 Continue = EmbeddedLoop->getLatch();
6422 OutroInsertBefore = EmbeddedLoop->getLatch();
6423 return EmbeddedLoop;
6424 };
6425
6426 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
6427 const Twine &NameBase) {
6428 for (auto P : enumerate(First&: TripCounts)) {
6429 CanonicalLoopInfo *EmbeddedLoop =
6430 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
6431 Result.push_back(x: EmbeddedLoop);
6432 }
6433 };
6434
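// For a depth-2 nest this creates, from outermost to innermost, roughly:
//   floor0 { floor1 { tile0 { tile1 { <original body> } } } }
// so that Result = [floor0, floor1, tile0, tile1] once the tile loops have
// been added below.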
6435 EmbeddNewLoops(FloorCount, "floor");
6436
6437 // Within the innermost floor loop, emit the code that computes the trip
6438 // counts of the tile loops.
6439 Builder.SetInsertPoint(Enter->getTerminator());
6440 SmallVector<Value *, 4> TileCounts;
6441 for (int i = 0; i < NumLoops; ++i) {
6442 CanonicalLoopInfo *FloorLoop = Result[i];
6443 Value *TileSize = TileSizes[i];
6444
6445 Value *FloorIsEpilogue =
6446 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorCompleteCount[i]);
6447 Value *TileTripCount =
6448 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
6449
6450 TileCounts.push_back(Elt: TileTripCount);
6451 }
6452
6453 // Create the tile loops.
6454 EmbeddNewLoops(TileCounts, "tile");
6455
6456 // Insert the in-between code into the body.
6457 BasicBlock *BodyEnter = Enter;
6458 BasicBlock *BodyEntered = nullptr;
6459 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
6460 BasicBlock *EnterBB = P.first;
6461 BasicBlock *ExitBB = P.second;
6462
6463 if (BodyEnter)
6464 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
6465 else
6466 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
6467
6468 BodyEnter = nullptr;
6469 BodyEntered = ExitBB;
6470 }
6471
6472 // Append the original loop nest body into the generated loop nest body.
6473 if (BodyEnter)
6474 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
6475 else
6476 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
6477 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
6478
6479 // Replace the original induction variable with an induction variable computed
6480 // from the tile and floor induction variables.
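// i.e. the original logical iteration number is recovered as
//   OrigIndVar = FloorIndVar * TileSize + TileIndVar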
6481 Builder.restoreIP(IP: Result.back()->getBodyIP());
6482 for (int i = 0; i < NumLoops; ++i) {
6483 CanonicalLoopInfo *FloorLoop = Result[i];
6484 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
6485 Value *OrigIndVar = OrigIndVars[i];
6486 Value *Size = TileSizes[i];
6487
6488 Value *Scale =
6489 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6490 Value *Shift =
6491 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6492 OrigIndVar->replaceAllUsesWith(V: Shift);
6493 }
6494
6495 // Remove unused parts of the original loops.
6496 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6497
6498 for (CanonicalLoopInfo *L : Loops)
6499 L->invalidate();
6500
6501#ifndef NDEBUG
6502 for (CanonicalLoopInfo *GenL : Result)
6503 GenL->assertOK();
6504#endif
6505 return Result;
6506}
6507
6508/// Attach metadata \p Properties to the basic block described by \p BB. If the
6509/// basic block already has metadata, the basic block properties are appended.
6510static void addBasicBlockMetadata(BasicBlock *BB,
6511 ArrayRef<Metadata *> Properties) {
6512 // Nothing to do if no property to attach.
6513 if (Properties.empty())
6514 return;
6515
6516 LLVMContext &Ctx = BB->getContext();
6517 SmallVector<Metadata *> NewProperties;
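// Reserve the first operand for the self-referencing loop id; it is set once
// the distinct node has been created below.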
6518 NewProperties.push_back(Elt: nullptr);
6519
6520 // If the basic block already has metadata, prepend it to the new metadata.
6521 MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
6522 if (Existing)
6523 append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));
6524
6525 append_range(C&: NewProperties, R&: Properties);
6526 MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
6527 BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);
6528
6529 BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
6530}
6531
6532/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
6533/// loop already has metadata, the loop properties are appended.
6534static void addLoopMetadata(CanonicalLoopInfo *Loop,
6535 ArrayRef<Metadata *> Properties) {
6536 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
6537
6538 // Attach metadata to the loop's latch
6539 BasicBlock *Latch = Loop->getLatch();
6540 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
6541 addBasicBlockMetadata(BB: Latch, Properties);
6542}
6543
6544/// Attach llvm.access.group metadata to the memref instructions of \p Block
6545static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
6546 LoopInfo &LI) {
6547 for (Instruction &I : *Block) {
6548 if (I.mayReadOrWriteMemory()) {
6549 // TODO: This instruction may already have an access group from
6550 // other pragmas, e.g. #pragma clang loop vectorize. Append to it
6551 // so that the existing metadata is not overwritten.
6552 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
6553 }
6554 }
6555}
6556
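// Via addLoopMetadata/addBasicBlockMetadata above, the latch terminator ends
// up with loop metadata of roughly this shape (illustrative):
//   br ... !llvm.loop !0
//   !0 = distinct !{!0, !1, !2}
//   !1 = !{!"llvm.loop.unroll.enable"}
//   !2 = !{!"llvm.loop.unroll.full"}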
6557void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
6558 LLVMContext &Ctx = Builder.getContext();
6559 addLoopMetadata(
6560 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6561 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
6562}
6563
6564void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
6565 LLVMContext &Ctx = Builder.getContext();
6566 addLoopMetadata(
6567 Loop, Properties: {
6568 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6569 });
6570}
6571
6572void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
6573 Value *IfCond, ValueToValueMapTy &VMap,
6574 LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
6575 const Twine &NamePrefix) {
6576 Function *F = CanonicalLoop->getFunction();
6577
6578 // We can't do
6579 // if (cond) {
6580 // simd_loop;
6581 // } else {
6582 // non_simd_loop;
6583 // }
6584 // because then the CanonicalLoopInfo would only point to one of the loops,
6585 // causing other constructs operating on the same loop to malfunction.
6586 // Instead generate
6587 // while (...) {
6588 // if (cond) {
6589 // simd_body;
6590 // } else {
6591 // not_simd_body;
6592 // }
6593 // }
6594 // At least for simple loops, LLVM seems to be able to hoist the if out of
6595 // the loop body at -O3.
6596
6597 // Define where the if branch should be inserted.
6598 auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();
6599
6600 // Create additional blocks for the if statement
6601 BasicBlock *Cond = SplitBeforeIt->getParent();
6602 llvm::LLVMContext &C = Cond->getContext();
6603 llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
6604 Context&: C, Name: NamePrefix + ".if.then", Parent: Cond->getParent(), InsertBefore: Cond->getNextNode());
6605 llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
6606 Context&: C, Name: NamePrefix + ".if.else", Parent: Cond->getParent(), InsertBefore: CanonicalLoop->getExit());
6607
6608 // Create if condition branch.
6609 Builder.SetInsertPoint(SplitBeforeIt);
6610 Instruction *BrInstr =
6611 Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
6612 InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
6613 // The then block contains the branch to the OpenMP loop body, which needs to be vectorized.
6614 spliceBB(IP, New: ThenBlock, CreateBranch: false, DL: Builder.getCurrentDebugLocation());
6615 ThenBlock->replaceSuccessorsPhiUsesWith(Old: Cond, New: ThenBlock);
6616
6617 Builder.SetInsertPoint(ElseBlock);
6618
6619 // Clone loop for the else branch
6620 SmallVector<BasicBlock *, 8> NewBlocks;
6621
6622 SmallVector<BasicBlock *, 8> ExistingBlocks;
6623 ExistingBlocks.reserve(N: L->getNumBlocks() + 1);
6624 ExistingBlocks.push_back(Elt: ThenBlock);
6625 ExistingBlocks.append(in_start: L->block_begin(), in_end: L->block_end());
6626 // Cond is the block that has the if clause condition
6627 // LoopCond is omp_loop.cond
6628 // LoopHeader is omp_loop.header
6629 BasicBlock *LoopCond = Cond->getUniquePredecessor();
6630 BasicBlock *LoopHeader = LoopCond ? LoopCond->getUniquePredecessor() : nullptr;
6631 assert(LoopCond && LoopHeader && "Invalid loop structure");
6632 for (BasicBlock *Block : ExistingBlocks) {
6633 if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
6634 Block == LoopHeader || Block == LoopCond || Block == Cond) {
6635 continue;
6636 }
6637 BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);
6638
6639 // Fix the name so that it is not omp.if.then.
6640 if (Block == ThenBlock)
6641 NewBB->setName(NamePrefix + ".if.else");
6642
6643 NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
6644 VMap[Block] = NewBB;
6645 NewBlocks.push_back(Elt: NewBB);
6646 }
6647 remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
6648 Builder.CreateBr(Dest: NewBlocks.front());
6649
6650 // The loop latch must have only one predecessor. Currently it is branched to
6651 // from both the 'then' and 'else' branches.
6652 L->getLoopLatch()->splitBasicBlockBefore(I: L->getLoopLatch()->begin(),
6653 BBName: NamePrefix + ".pre_latch");
6654
6655 // Ensure that the then block is added to the loop so that the attributes
6656 // are added in the next step.
6657 L->addBasicBlockToLoop(NewBB: ThenBlock, LI);
6658}
6659
6660unsigned
6661OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
6662 const StringMap<bool> &Features) {
6663 if (TargetTriple.isX86()) {
6664 if (Features.lookup(Key: "avx512f"))
6665 return 512;
6666 else if (Features.lookup(Key: "avx"))
6667 return 256;
6668 return 128;
6669 }
6670 if (TargetTriple.isPPC())
6671 return 128;
6672 if (TargetTriple.isWasm())
6673 return 128;
6674 return 0;
6675}
6676
6677void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
6678 MapVector<Value *, Value *> AlignedVars,
6679 Value *IfCond, OrderKind Order,
6680 ConstantInt *Simdlen, ConstantInt *Safelen) {
6681 LLVMContext &Ctx = Builder.getContext();
6682
6683 Function *F = CanonicalLoop->getFunction();
6684
6685 // TODO: We should not rely on pass manager. Currently we use pass manager
6686 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
6687 // object. We should have a method which returns all blocks between
6688 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
6689 FunctionAnalysisManager FAM;
6690 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
6691 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
6692 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
6693
6694 LoopAnalysis LIA;
6695 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
6696
6697 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
6698 if (!AlignedVars.empty()) {
6699 InsertPointTy IP = Builder.saveIP();
6700 for (auto &AlignedItem : AlignedVars) {
6701 Value *AlignedPtr = AlignedItem.first;
6702 Value *Alignment = AlignedItem.second;
6703 Instruction *AlignedInst = cast<Instruction>(Val: AlignedPtr);
6704 Builder.SetInsertPoint(AlignedInst->getNextNode());
6705 Builder.CreateAlignmentAssumption(DL: F->getDataLayout(), PtrValue: AlignedPtr,
6706 Alignment);
6707 }
6708 Builder.restoreIP(IP);
6709 }
6710
6711 if (IfCond) {
6712 ValueToValueMapTy VMap;
6713 createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, NamePrefix: "simd");
6714 }
6715
6716 SmallPtrSet<BasicBlock *, 8> Reachable;
6717
6718 // Get the basic blocks from the loop in which memref instructions
6719 // can be found.
6720 // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
6721 // preferably without running any passes.
6722 for (BasicBlock *Block : L->getBlocks()) {
6723 if (Block == CanonicalLoop->getCond() ||
6724 Block == CanonicalLoop->getHeader())
6725 continue;
6726 Reachable.insert(Ptr: Block);
6727 }
6728
6729 SmallVector<Metadata *> LoopMDList;
6730
6731 // In the presence of a finite 'safelen', it may be unsafe to mark all
6732 // the memory instructions parallel, because loop-carried
6733 // dependences across 'safelen' iterations are possible.
6734 // If the order(concurrent) clause is specified, then the memory instructions
6735 // are marked parallel even if 'safelen' is finite.
6736 if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
6737 applyParallelAccessesMetadata(CLI: CanonicalLoop, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
6738
6739 // FIXME: the if clause shares a loop backedge between the SIMD and non-SIMD
6740 // versions, so we can't add the loop attributes in that case.
6741 if (IfCond) {
6742 // We can still add llvm.loop.parallel_accesses.
6743 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
6744 return;
6745 }
6746
6747 // Use the above access group metadata to create loop level
6748 // metadata, which should be distinct for each loop.
6749 ConstantAsMetadata *BoolConst =
6750 ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
6751 LoopMDList.push_back(Elt: MDNode::get(
6752 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));
6753
6754 if (Simdlen || Safelen) {
6755 // If both simdlen and safelen clauses are specified, the value of the
6756 // simdlen parameter must be less than or equal to the value of the safelen
6757 // parameter. Therefore, use safelen only in the absence of simdlen.
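// For example, 'simdlen(4) safelen(8)' yields
//   !{!"llvm.loop.vectorize.width", i32 4}
// in the loop metadata.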
6758 ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
6759 LoopMDList.push_back(
6760 Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
6761 ConstantAsMetadata::get(C: VectorizeWidth)}));
6762 }
6763
6764 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
6765}
6766
6767/// Create the TargetMachine object to query the backend for optimization
6768/// preferences.
6769///
6770/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
6771/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
6772 /// needed for the LLVM pass pipeline. We use some default options to avoid
6773/// having to pass too many settings from the frontend that probably do not
6774/// matter.
6775///
6776/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
6777/// method. If we are going to use TargetMachine for more purposes, especially
6778/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
6779 /// might be worth requiring front-ends to pass on their TargetMachine,
6780 /// or at least cache it between methods. Note that while frontends such as
6781 /// Clang have just a single main TargetMachine per translation unit, the
6782 /// "target-cpu" and "target-features" that determine the TargetMachine are
6783 /// per-function and can be overridden using __attribute__((target("OPTIONS"))).
6784static std::unique_ptr<TargetMachine>
6785createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
6786 Module *M = F->getParent();
6787
6788 StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
6789 StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
6790 const llvm::Triple &Triple = M->getTargetTriple();
6791
6792 std::string Error;
6793 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(TheTriple: Triple, Error);
6794 if (!TheTarget)
6795 return {};
6796
6797 llvm::TargetOptions Options;
6798 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
6799 TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
6800 /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
6801}
6802
6803 /// Heuristically determine the best-performing unroll factor for \p CLI. This
6804 /// depends on the target processor. We are reusing the same heuristics as the
6805/// LoopUnrollPass.
6806static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
6807 Function *F = CLI->getFunction();
6808
6809 // Assume the user requests the most aggressive unrolling, even if the rest of
6810 // the code is optimized using a lower setting.
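// (CodeGenOptLevel::Aggressive corresponds to clang's -O3.)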
6811 CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
6812 std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
6813
6814 FunctionAnalysisManager FAM;
6815 FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
6816 FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
6817 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
6818 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
6819 FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
6820 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
6821 TargetIRAnalysis TIRA;
6822 if (TM)
6823 TIRA = TargetIRAnalysis(
6824 [&](const Function &F) { return TM->getTargetTransformInfo(F); });
6825 FAM.registerPass(PassBuilder: [&]() { return TIRA; });
6826
6827 TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
6828 ScalarEvolutionAnalysis SEA;
6829 ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
6830 DominatorTreeAnalysis DTA;
6831 DominatorTree &&DT = DTA.run(F&: *F, FAM);
6832 LoopAnalysis LIA;
6833 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
6834 AssumptionAnalysis ACT;
6835 AssumptionCache &&AC = ACT.run(F&: *F, FAM);
6836 OptimizationRemarkEmitter ORE{F};
6837
6838 Loop *L = LI.getLoopFor(BB: CLI->getHeader());
6839 assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
6840
6841 TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
6842 L, SE, TTI,
6843 /*BlockFrequencyInfo=*/BFI: nullptr,
6844 /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
6845 /*UserThreshold=*/std::nullopt,
6846 /*UserCount=*/std::nullopt,
6847 /*UserAllowPartial=*/true,
6848 /*UserAllowRuntime=*/UserRuntime: true,
6849 /*UserUpperBound=*/std::nullopt,
6850 /*UserFullUnrollMaxCount=*/std::nullopt);
6851
6852 UP.Force = true;
6853
6854 // Account for additional optimizations taking place before the LoopUnrollPass
6855 // would unroll the loop.
6856 UP.Threshold *= UnrollThresholdFactor;
6857 UP.PartialThreshold *= UnrollThresholdFactor;
6858
6859 // Use normal unroll factors even if the rest of the code is optimized for
6860 // size.
6861 UP.OptSizeThreshold = UP.Threshold;
6862 UP.PartialOptSizeThreshold = UP.PartialThreshold;
6863
6864 LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
6865 << " Threshold=" << UP.Threshold << "\n"
6866 << " PartialThreshold=" << UP.PartialThreshold << "\n"
6867 << " OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
6868 << " PartialOptSizeThreshold="
6869 << UP.PartialOptSizeThreshold << "\n");
6870
6871 // Disable peeling.
6872 TargetTransformInfo::PeelingPreferences PP =
6873 gatherPeelingPreferences(L, SE, TTI,
6874 /*UserAllowPeeling=*/false,
6875 /*UserAllowProfileBasedPeeling=*/false,
6876 /*UnrollingSpecficValues=*/false);
6877
6878 SmallPtrSet<const Value *, 32> EphValues;
6879 CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);
6880
6881 // Assume that reads and writes to stack variables can be eliminated by
6882 // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
6883 // size.
6884 for (BasicBlock *BB : L->blocks()) {
6885 for (Instruction &I : *BB) {
6886 Value *Ptr;
6887 if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
6888 Ptr = Load->getPointerOperand();
6889 } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
6890 Ptr = Store->getPointerOperand();
6891 } else
6892 continue;
6893
6894 Ptr = Ptr->stripPointerCasts();
6895
6896 if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
6897 if (Alloca->getParent() == &F->getEntryBlock())
6898 EphValues.insert(Ptr: &I);
6899 }
6900 }
6901 }
6902
6903 UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
6904
6905 // The loop is not unrollable if it contains certain instructions.
6906 if (!UCE.canUnroll()) {
6907 LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
6908 return 1;
6909 }
6910
6911 LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
6912 << "\n");
6913
6914 // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
6915 // be able to use it.
6916 int TripCount = 0;
6917 int MaxTripCount = 0;
6918 bool MaxOrZero = false;
6919 unsigned TripMultiple = 0;
6920
6921 bool UseUpperBound = false;
6922 computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
6923 MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
6924 UseUpperBound);
6925 unsigned Factor = UP.Count;
6926 LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
6927
6928 // This function returns 1 to signal that the loop should not be unrolled.
6929 if (Factor == 0)
6930 return 1;
6931 return Factor;
6932}
6933
6934void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
6935 int32_t Factor,
6936 CanonicalLoopInfo **UnrolledCLI) {
6937 assert(Factor >= 0 && "Unroll factor must not be negative");
6938
6939 Function *F = Loop->getFunction();
6940 LLVMContext &Ctx = F->getContext();
6941
6942 // If the unrolled loop is not used for another loop-associated directive, it
6943 // is sufficient to add metadata for the LoopUnrollPass.
6944 if (!UnrolledCLI) {
6945 SmallVector<Metadata *, 2> LoopMetadata;
6946 LoopMetadata.push_back(
6947 Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));
6948
6949 if (Factor >= 1) {
6950 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
6951 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
6952 LoopMetadata.push_back(Elt: MDNode::get(
6953 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
6954 }
6955
6956 addLoopMetadata(Loop, Properties: LoopMetadata);
6957 return;
6958 }
6959
6960 // Heuristically determine the unroll factor.
6961 if (Factor == 0)
6962 Factor = computeHeuristicUnrollFactor(CLI: Loop);
6963
6964 // No change required with unroll factor 1.
6965 if (Factor == 1) {
6966 *UnrolledCLI = Loop;
6967 return;
6968 }
6969
6970 assert(Factor >= 2 &&
6971 "unrolling only makes sense with a factor of 2 or larger");
6972
6973 Type *IndVarTy = Loop->getIndVarType();
6974
6975 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
6976 // unroll the inner loop.
6977 Value *FactorVal =
6978 ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
6979 /*isSigned=*/false));
6980 std::vector<CanonicalLoopInfo *> LoopNest =
6981 tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
6982 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
6983 *UnrolledCLI = LoopNest[0];
6984 CanonicalLoopInfo *InnerLoop = LoopNest[1];
6985
6986 // LoopUnrollPass can only fully unroll loops with constant trip count.
6987 // Unroll by the unroll factor with a fallback epilog for the remainder
6988 // iterations if necessary.
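// The inner tile loop then carries metadata of roughly this shape:
//   !{!"llvm.loop.unroll.enable"}
//   !{!"llvm.loop.unroll.count", i32 Factor}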
6989 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
6990 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
6991 addLoopMetadata(
6992 Loop: InnerLoop,
6993 Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6994 MDNode::get(
6995 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});
6996
6997#ifndef NDEBUG
6998 (*UnrolledCLI)->assertOK();
6999#endif
7000}
7001
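// Emits, roughly (modulo the exact types chosen by the frontend):
//   %did_it.val = load i32, ptr %DidIt
//   call void @__kmpc_copyprivate(ptr %ident, i32 %gtid, i64 %BufSize,
//                                 ptr %CpyBuf, ptr %CpyFn, i32 %did_it.val)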
7002OpenMPIRBuilder::InsertPointTy
7003OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
7004 llvm::Value *BufSize, llvm::Value *CpyBuf,
7005 llvm::Value *CpyFn, llvm::Value *DidIt) {
7006 if (!updateToLocation(Loc))
7007 return Loc.IP;
7008
7009 uint32_t SrcLocStrSize;
7010 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7011 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7012 Value *ThreadId = getOrCreateThreadID(Ident);
7013
7014 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
7015
7016 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
7017
7018 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
7019 createRuntimeFunctionCall(Callee: Fn, Args);
7020
7021 return Builder.saveIP();
7022}
7023
7024OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
7025 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7026 FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
7027 ArrayRef<llvm::Function *> CPFuncs) {
7028
7029 if (!updateToLocation(Loc))
7030 return Loc.IP;
7031
7032 // If needed, allocate and initialize `DidIt` with 0.
7033 // DidIt is a flag variable: 1 = single thread; 0 = not the single thread.
7034 llvm::Value *DidIt = nullptr;
7035 if (!CPVars.empty()) {
7036 DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
7037 Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
7038 }
7039
7040 Directive OMPD = Directive::OMPD_single;
7041 uint32_t SrcLocStrSize;
7042 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7043 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7044 Value *ThreadId = getOrCreateThreadID(Ident);
7045 Value *Args[] = {Ident, ThreadId};
7046
7047 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
7048 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7049
7050 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
7051 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7052
7053 auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
7054 if (Error Err = FiniCB(IP))
7055 return Err;
7056
7057 // The thread that executes the single region must set `DidIt` to 1.
7058 // This is used by __kmpc_copyprivate, to know if the caller is the
7059 // single thread or not.
7060 if (DidIt)
7061 Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);
7062
7063 return Error::success();
7064 };
7065
7066 // Generates the following:
7067 // if (__kmpc_single()) {
7068 // .... single region ...
7069 // __kmpc_end_single
7070 // }
7071 // __kmpc_copyprivate
7072 // __kmpc_barrier
7073
7074 InsertPointOrErrorTy AfterIP =
7075 EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
7076 /*Conditional*/ true,
7077 /*hasFinalize*/ HasFinalize: true);
7078 if (!AfterIP)
7079 return AfterIP.takeError();
7080
7081 if (DidIt) {
7082 for (size_t I = 0, E = CPVars.size(); I < E; ++I)
7083 // NOTE BufSize is currently unused, so just pass 0.
7084 createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
7085 /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
7086 CpyFn: CPFuncs[I], DidIt);
7087 // NOTE __kmpc_copyprivate already inserts a barrier
7088 } else if (!IsNowait) {
7089 InsertPointOrErrorTy AfterIP =
7090 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
7091 Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
7092 /* CheckCancelFlag */ false);
7093 if (!AfterIP)
7094 return AfterIP.takeError();
7095 }
7096 return Builder.saveIP();
7097}
7098
7099OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
7100 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7101 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
7102
7103 if (!updateToLocation(Loc))
7104 return Loc.IP;
7105
7106 Directive OMPD = Directive::OMPD_critical;
7107 uint32_t SrcLocStrSize;
7108 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7109 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7110 Value *ThreadId = getOrCreateThreadID(Ident);
7111 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
7112 Value *Args[] = {Ident, ThreadId, LockVar};
7113
7114 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
7115 Function *RTFn = nullptr;
7116 if (HintInst) {
7117 // Add the hint to the enter arguments and create the call.
7118 EnterArgs.push_back(Elt: HintInst);
7119 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
7120 } else {
7121 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
7122 }
7123 Instruction *EntryCall = createRuntimeFunctionCall(Callee: RTFn, Args: EnterArgs);
7124
7125 Function *ExitRTLFn =
7126 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
7127 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7128
7129 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7130 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7131}
7132
7133OpenMPIRBuilder::InsertPointTy
7134OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
7135 InsertPointTy AllocaIP, unsigned NumLoops,
7136 ArrayRef<llvm::Value *> StoreValues,
7137 const Twine &Name, bool IsDependSource) {
7138 assert(
7139 llvm::all_of(StoreValues,
7140 [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
7141 "OpenMP runtime requires depend vec with i64 type");
7142
7143 if (!updateToLocation(Loc))
7144 return Loc.IP;
7145
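// Illustrative output for NumLoops == 2 with IsDependSource == true:
//   %vec = alloca [2 x i64], align 8
//   store i64 %v0, ptr %vec, align 8    ; element 0, via an inbounds GEP
//   store i64 %v1, ptr %vec.1, align 8  ; element 1, via an inbounds GEP
//   call void @__kmpc_doacross_post(ptr %ident, i32 %gtid, ptr %vec)
// For the sink case, __kmpc_doacross_wait is called instead.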
7146 // Allocate space for the depend vector and generate the alloca instruction.
7147 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
7148 Builder.restoreIP(IP: AllocaIP);
7149 AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
7150 ArgsBase->setAlignment(Align(8));
7151 updateToLocation(Loc);
7152
7153 // Store the index value with offset in depend vector.
7154 for (unsigned I = 0; I < NumLoops; ++I) {
7155 Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
7156 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
7157 StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
7158 STInst->setAlignment(Align(8));
7159 }
7160
7161 Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
7162 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});
7163
7164 uint32_t SrcLocStrSize;
7165 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7166 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7167 Value *ThreadId = getOrCreateThreadID(Ident);
7168 Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
7169
7170 Function *RTLFn = nullptr;
7171 if (IsDependSource)
7172 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
7173 else
7174 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
7175 createRuntimeFunctionCall(Callee: RTLFn, Args);
7176
7177 return Builder.saveIP();
7178}
7179
7180OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
7181 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7182 FinalizeCallbackTy FiniCB, bool IsThreads) {
7183 if (!updateToLocation(Loc))
7184 return Loc.IP;
7185
7186 Directive OMPD = Directive::OMPD_ordered;
7187 Instruction *EntryCall = nullptr;
7188 Instruction *ExitCall = nullptr;
7189
7190 if (IsThreads) {
7191 uint32_t SrcLocStrSize;
7192 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7193 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7194 Value *ThreadId = getOrCreateThreadID(Ident);
7195 Value *Args[] = {Ident, ThreadId};
7196
7197 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
7198 EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7199
7200 Function *ExitRTLFn =
7201 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
7202 ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7203 }
7204
7205 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7206 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7207}
7208
7209OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
7210 Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
7211 BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
7212 bool HasFinalize, bool IsCancellable) {
7213
7214 if (HasFinalize)
7215 FinalizationStack.push_back(Elt: {FiniCB, OMPD, IsCancellable});
7216
7217 // Create the inlined region's entry and body blocks, in preparation
7218 // for conditional creation.
7219 BasicBlock *EntryBB = Builder.GetInsertBlock();
7220 Instruction *SplitPos = EntryBB->getTerminator();
7221 if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
7222 SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
7223 BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
7224 BasicBlock *FiniBB =
7225 EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");
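// At this point the CFG is EntryBB -> FiniBB -> ExitBB; the region body is
// emitted between the entry call and FiniBB.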
7226
7227 Builder.SetInsertPoint(EntryBB->getTerminator());
7228 emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
7229
7230 // Generate the body.
7231 if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
7232 /* CodeGenIP */ Builder.saveIP()))
7233 return Err;
7234
7235 // Emit the exit call and do any needed finalization.
7236 auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
7237 assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
7238 FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
7239 "Unexpected control flow graph state!!");
7240 InsertPointOrErrorTy AfterIP =
7241 emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
7242 if (!AfterIP)
7243 return AfterIP.takeError();
7244
7245 // If we are skipping the region of a non-conditional, remove the exit
7246 // block, and clear the builder's insertion point.
7247 assert(SplitPos->getParent() == ExitBB &&
7248 "Unexpected Insertion point location!");
7249 auto merged = MergeBlockIntoPredecessor(BB: ExitBB);
7250 BasicBlock *ExitPredBB = SplitPos->getParent();
7251 auto InsertBB = merged ? ExitPredBB : ExitBB;
7252 if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
7253 SplitPos->eraseFromParent();
7254 Builder.SetInsertPoint(InsertBB);
7255
7256 return Builder.saveIP();
7257}
7258
7259OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
7260 Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
7261 // If there is nothing to do, return the current insertion point.
7262 if (!Conditional || !EntryCall)
7263 return Builder.saveIP();
7264
7265 BasicBlock *EntryBB = Builder.GetInsertBlock();
7266 Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
7267 auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
7268 auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
7269
7270 // Emit thenBB and set the Builder's insertion point there for
7271 // body generation next. Place the block after the current block.
7272 Function *CurFn = EntryBB->getParent();
7273 CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);
7274
7275 // Move the entry branch to the end of ThenBB, and replace it with a
7276 // conditional branch (if-stmt).
7277 Instruction *EntryBBTI = EntryBB->getTerminator();
7278 Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
7279 EntryBBTI->removeFromParent();
7280 Builder.SetInsertPoint(UI);
7281 Builder.Insert(I: EntryBBTI);
7282 UI->eraseFromParent();
7283 Builder.SetInsertPoint(ThenBB->getTerminator());
7284
7285 // Return an insertion point to ExitBB.
7286 return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
7287}
7288
7289OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
7290 omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
7291 bool HasFinalize) {
7292
7293 Builder.restoreIP(IP: FinIP);
7294
7295 // If there is finalization to do, emit it before the exit call.
7296 if (HasFinalize) {
7297 assert(!FinalizationStack.empty() &&
7298 "Unexpected finalization stack state!");
7299
7300 FinalizationInfo Fi = FinalizationStack.pop_back_val();
7301 assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
7302
7303 if (Error Err = Fi.mergeFiniBB(Builder, OtherFiniBB: FinIP.getBlock()))
7304 return std::move(Err);
7305
7306 // Exit condition: the insertion point is before the terminator of the new
7307 // finalization block.
7308 Builder.SetInsertPoint(FinIP.getBlock()->getTerminator());
7309 }
7310
7311 if (!ExitCall)
7312 return Builder.saveIP();
7313
7314 // Place the exit call as the last instruction before the finalization block's terminator.
7315 ExitCall->removeFromParent();
7316 Builder.Insert(I: ExitCall);
7317
7318 return IRBuilder<>::InsertPoint(ExitCall->getParent(),
7319 ExitCall->getIterator());
7320}
7321
7322OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
7323 InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
7324 llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
7325 if (!IP.isSet())
7326 return IP;
7327
7328 IRBuilder<>::InsertPointGuard IPG(Builder);
7329
7330 // Creates the following CFG structure:
7331 // OMP_Entry : (MasterAddr != PrivateAddr)?
7332 // F T
7333 // | \
7334 // | copyin.not.master
7335 // | /
7336 // v /
7337 // copyin.not.master.end
7338 // |
7339 // v
7340 // OMP.Entry.Next
7341
7342 BasicBlock *OMP_Entry = IP.getBlock();
7343 Function *CurFn = OMP_Entry->getParent();
7344 BasicBlock *CopyBegin =
7345 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
7346 BasicBlock *CopyEnd = nullptr;
7347
7348 // If the entry block is terminated, split it to preserve the branch to the
7349 // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything as is.
7350 if (isa_and_nonnull<BranchInst>(Val: OMP_Entry->getTerminator())) {
7351 CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
7352 BBName: "copyin.not.master.end");
7353 OMP_Entry->getTerminator()->eraseFromParent();
7354 } else {
7355 CopyEnd =
7356 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
7357 }
7358
7359 Builder.SetInsertPoint(OMP_Entry);
7360 Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
7361 Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
7362 Value *cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
7363 Builder.CreateCondBr(Cond: cmp, True: CopyBegin, False: CopyEnd);
7364
7365 Builder.SetInsertPoint(CopyBegin);
7366 if (BranchtoEnd)
7367 Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));
7368
7369 return Builder.saveIP();
7370}
7371
7372CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
7373 Value *Size, Value *Allocator,
7374 std::string Name) {
7375 IRBuilder<>::InsertPointGuard IPG(Builder);
7376 updateToLocation(Loc);
7377
7378 uint32_t SrcLocStrSize;
7379 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7380 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7381 Value *ThreadId = getOrCreateThreadID(Ident);
7382 Value *Args[] = {ThreadId, Size, Allocator};
7383
7384 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
7385
7386 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7387}
7388
7389CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
7390 Value *Addr, Value *Allocator,
7391 std::string Name) {
7392 IRBuilder<>::InsertPointGuard IPG(Builder);
7393 updateToLocation(Loc);
7394
7395 uint32_t SrcLocStrSize;
7396 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7397 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7398 Value *ThreadId = getOrCreateThreadID(Ident);
7399 Value *Args[] = {ThreadId, Addr, Allocator};
7400 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
7401 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7402}
7403
7404CallInst *OpenMPIRBuilder::createOMPInteropInit(
7405 const LocationDescription &Loc, Value *InteropVar,
7406 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
7407 Value *DependenceAddress, bool HaveNowaitClause) {
7408 IRBuilder<>::InsertPointGuard IPG(Builder);
7409 updateToLocation(Loc);
7410
7411 uint32_t SrcLocStrSize;
7412 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7413 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7414 Value *ThreadId = getOrCreateThreadID(Ident);
7415 if (Device == nullptr)
7416 Device = Constant::getAllOnesValue(Ty: Int32);
7417 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
7418 if (NumDependences == nullptr) {
7419 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7420 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7421 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7422 }
7423 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7424 Value *Args[] = {
7425 Ident, ThreadId, InteropVar, InteropTypeVal,
7426 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
7427
7428 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
7429
7430 return createRuntimeFunctionCall(Callee: Fn, Args);
7431}
7432
7433CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
7434 const LocationDescription &Loc, Value *InteropVar, Value *Device,
7435 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
7436 IRBuilder<>::InsertPointGuard IPG(Builder);
7437 updateToLocation(Loc);
7438
7439 uint32_t SrcLocStrSize;
7440 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7441 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7442 Value *ThreadId = getOrCreateThreadID(Ident);
7443 if (Device == nullptr)
7444 Device = Constant::getAllOnesValue(Ty: Int32);
7445 if (NumDependences == nullptr) {
7446 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7447 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7448 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7449 }
7450 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7451 Value *Args[] = {
7452 Ident, ThreadId, InteropVar, Device,
7453 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7454
7455 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
7456
7457 return createRuntimeFunctionCall(Callee: Fn, Args);
7458}
7459
7460CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
7461 Value *InteropVar, Value *Device,
7462 Value *NumDependences,
7463 Value *DependenceAddress,
7464 bool HaveNowaitClause) {
7465 IRBuilder<>::InsertPointGuard IPG(Builder);
7466 updateToLocation(Loc);
7467 uint32_t SrcLocStrSize;
7468 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7469 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7470 Value *ThreadId = getOrCreateThreadID(Ident);
7471 if (Device == nullptr)
7472 Device = Constant::getAllOnesValue(Ty: Int32);
7473 if (NumDependences == nullptr) {
7474 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7475 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7476 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7477 }
7478 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7479 Value *Args[] = {
7480 Ident, ThreadId, InteropVar, Device,
7481 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7482
7483 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
7484
7485 return createRuntimeFunctionCall(Callee: Fn, Args);
7486}
7487
7488CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
7489 const LocationDescription &Loc, llvm::Value *Pointer,
7490 llvm::ConstantInt *Size, const llvm::Twine &Name) {
7491 IRBuilder<>::InsertPointGuard IPG(Builder);
7492 updateToLocation(Loc);
7493
7494 uint32_t SrcLocStrSize;
7495 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7496 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7497 Value *ThreadId = getOrCreateThreadID(Ident);
7498 Constant *ThreadPrivateCache =
7499 getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
7500 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
7501
7502 Function *Fn =
7503 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);
7504
7505 return createRuntimeFunctionCall(Callee: Fn, Args);
7506}
7507
7508OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
7509 const LocationDescription &Loc,
7510 const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
7511 assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
7512 "expected num_threads and num_teams to be specified");
7513
7514 if (!updateToLocation(Loc))
7515 return Loc.IP;
7516
7517 uint32_t SrcLocStrSize;
7518 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7519 Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7520 Constant *IsSPMDVal = ConstantInt::getSigned(Ty: Int8, V: Attrs.ExecFlags);
7521 Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
7522 Ty: Int8, V: Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
7523 Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
7524 Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);
7525
7526 Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
7527 Function *Kernel = DebugKernelWrapper;
7528
7529 // We need to strip the debug prefix to get the correct kernel name.
7530 StringRef KernelName = Kernel->getName();
7531 const std::string DebugPrefix = "_debug__";
7532 if (KernelName.ends_with(Suffix: DebugPrefix)) {
7533 KernelName = KernelName.drop_back(N: DebugPrefix.length());
7534 Kernel = M.getFunction(Name: KernelName);
7535 assert(Kernel && "Expected the real kernel to exist");
7536 }
7537
7538 // Manifest the launch configuration in the metadata matching the kernel
7539 // environment.
7540 if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
7541 writeTeamsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinTeams, UB: Attrs.MaxTeams.front());
7542
7543 // If MaxThreads is not set, select the maximum between the default workgroup
7544 // size and the MinThreads value.
7545 int32_t MaxThreadsVal = Attrs.MaxThreads.front();
7546 if (MaxThreadsVal < 0)
7547 MaxThreadsVal = std::max(
7548 a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), b: Attrs.MinThreads);
7549
7550 if (MaxThreadsVal > 0)
7551 writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinThreads, UB: MaxThreadsVal);
7552
7553 Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinThreads);
7554 Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
7555 Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinTeams);
7556 Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MaxTeams.front());
7557 Constant *ReductionDataSize =
7558 ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionDataSize);
7559 Constant *ReductionBufferLength =
7560 ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionBufferLength);
7561
7562 Function *Fn = getOrCreateRuntimeFunctionPtr(
7563 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
7564 const DataLayout &DL = Fn->getDataLayout();
7565
7566 Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
7567 Constant *DynamicEnvironmentInitializer =
7568 ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
7569 GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
7570 M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
7571 DynamicEnvironmentInitializer, DynamicEnvironmentName,
7572 /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
7573 DL.getDefaultGlobalsAddressSpace());
7574 DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
7575
7576 Constant *DynamicEnvironment =
7577 DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
7578 ? DynamicEnvironmentGV
7579 : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
7580 Ty: DynamicEnvironmentPtr);
7581
7582 Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
7583 T: ConfigurationEnvironment, V: {
7584 UseGenericStateMachineVal,
7585 MayUseNestedParallelismVal,
7586 IsSPMDVal,
7587 MinThreads,
7588 MaxThreads,
7589 MinTeams,
7590 MaxTeams,
7591 ReductionDataSize,
7592 ReductionBufferLength,
7593 });
7594 Constant *KernelEnvironmentInitializer = ConstantStruct::get(
7595 T: KernelEnvironment, V: {
7596 ConfigurationEnvironmentInitializer,
7597 Ident,
7598 DynamicEnvironment,
7599 });
7600 std::string KernelEnvironmentName =
7601 (KernelName + "_kernel_environment").str();
7602 GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
7603 M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
7604 KernelEnvironmentInitializer, KernelEnvironmentName,
7605 /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
7606 DL.getDefaultGlobalsAddressSpace());
7607 KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
7608
7609 Constant *KernelEnvironment =
7610 KernelEnvironmentGV->getType() == KernelEnvironmentPtr
7611 ? KernelEnvironmentGV
7612 : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
7613 Ty: KernelEnvironmentPtr);
7614 Value *KernelLaunchEnvironment = DebugKernelWrapper->getArg(i: 0);
7615 Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(i: 1);
7616 KernelLaunchEnvironment =
7617 KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
7618 ? KernelLaunchEnvironment
7619 : Builder.CreateAddrSpaceCast(V: KernelLaunchEnvironment,
7620 DestTy: KernelLaunchEnvParamTy);
7621 CallInst *ThreadKind = createRuntimeFunctionCall(
7622 Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});
7623
7624 Value *ExecUserCode = Builder.CreateICmpEQ(
7625 LHS: ThreadKind, RHS: Constant::getAllOnesValue(Ty: ThreadKind->getType()),
7626 Name: "exec_user_code");
7627
7628 // ThreadKind = __kmpc_target_init(...)
7629 // if (ThreadKind == -1)
7630 // user_code
7631 // else
7632 // return;
7633
7634 auto *UI = Builder.CreateUnreachable();
7635 BasicBlock *CheckBB = UI->getParent();
7636 BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");
7637
7638 BasicBlock *WorkerExitBB = BasicBlock::Create(
7639 Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
7640 Builder.SetInsertPoint(WorkerExitBB);
7641 Builder.CreateRetVoid();
7642
7643 auto *CheckBBTI = CheckBB->getTerminator();
7644 Builder.SetInsertPoint(CheckBBTI);
7645 Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);
7646
7647 CheckBBTI->eraseFromParent();
7648 UI->eraseFromParent();
7649
7650 // Continue in the "user_code" block; see the diagram above and in
7651 // openmp/libomptarget/deviceRTLs/common/include/target.h.
7652 return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
7653}
7654
7655void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
7656 int32_t TeamsReductionDataSize,
7657 int32_t TeamsReductionBufferLength) {
7658 if (!updateToLocation(Loc))
7659 return;
7660
7661 Function *Fn = getOrCreateRuntimeFunctionPtr(
7662 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
7663
7664 createRuntimeFunctionCall(Callee: Fn, Args: {});
7665
7666 if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
7667 return;
7668
7669 Function *Kernel = Builder.GetInsertBlock()->getParent();
7670 // We need to strip the debug prefix to get the correct kernel name.
7671 StringRef KernelName = Kernel->getName();
7672 const std::string DebugPrefix = "_debug__";
7673 if (KernelName.ends_with(Suffix: DebugPrefix))
7674 KernelName = KernelName.drop_back(N: DebugPrefix.length());
7675 auto *KernelEnvironmentGV =
7676 M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
7677 assert(KernelEnvironmentGV && "Expected kernel environment global\n");
7678 auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
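// The indices {0, N} below address member N of the ConfigurationEnvironment,
// which is member 0 of the kernel environment struct: 7 is ReductionDataSize
// and 8 is ReductionBufferLength (see the initializer in createTargetInit).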
7679 auto *NewInitializer = ConstantFoldInsertValueInstruction(
7680 Agg: KernelEnvironmentInitializer,
7681 Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
7682 NewInitializer = ConstantFoldInsertValueInstruction(
7683 Agg: NewInitializer, Val: ConstantInt::get(Ty: Int32, V: TeamsReductionBufferLength),
7684 Idxs: {0, 8});
7685 KernelEnvironmentGV->setInitializer(NewInitializer);
7686}
7687
7688static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
7689 bool Min) {
7690 if (Kernel.hasFnAttribute(Kind: Name)) {
7691 int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Kind: Name);
7692 Value = Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value);
7693 }
7694 Kernel.addFnAttr(Kind: Name, Val: llvm::utostr(X: Value));
7695}
7696
7697std::pair<int32_t, int32_t>
7698OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
7699 int32_t ThreadLimit =
7700 Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");
7701
7702 if (T.isAMDGPU()) {
7703 const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
7704 if (!Attr.isValid() || !Attr.isStringAttribute())
7705 return {0, ThreadLimit};
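// The attribute value has the form "<lower>,<upper>", e.g. "1,1024".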
7706 auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
7707 int32_t LB, UB;
7708 if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
7709 return {0, ThreadLimit};
7710 UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
7711 if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
7712 return {0, UB};
7713 return {LB, UB};
7714 }
7715
7716 if (Kernel.hasFnAttribute(Kind: "nvvm.maxntid")) {
7717 int32_t UB = Kernel.getFnAttributeAsParsedInteger(Kind: "nvvm.maxntid");
7718 return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
7719 }
7720 return {0, ThreadLimit};
7721}
7722
7723void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
7724 Function &Kernel, int32_t LB,
7725 int32_t UB) {
7726 Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));
7727
7728 if (T.isAMDGPU()) {
7729 Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
7730 Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
7731 return;
7732 }
7733
7734 updateNVPTXAttr(Kernel, Name: "nvvm.maxntid", Value: UB, Min: true);
7735}
7736
7737std::pair<int32_t, int32_t>
7738OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
7739 // TODO: Read from backend annotations if available.
7740 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
7741}
7742
7743void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
7744 int32_t LB, int32_t UB) {
7745 if (T.isNVPTX())
7746 if (UB > 0)
7747 Kernel.addFnAttr(Kind: "nvvm.maxclusterrank", Val: llvm::utostr(X: UB));
7748 if (T.isAMDGPU())
7749 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: LB) + ",1,1");
7750
7751 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
7752}
7753
7754void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
7755 Function *OutlinedFn) {
7756 if (Config.isTargetDevice()) {
7757 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
7758 // TODO: Determine if DSO local can be set to true.
7759 OutlinedFn->setDSOLocal(false);
7760 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
7761 if (T.isAMDGCN())
7762 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
7763 else if (T.isNVPTX())
7764 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
7765 else if (T.isSPIRV())
7766 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
7767 }
7768}
7769
7770Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
7771 StringRef EntryFnIDName) {
7772 if (Config.isTargetDevice()) {
7773 assert(OutlinedFn && "The outlined function must exist if embedded");
7774 return OutlinedFn;
7775 }
7776
7777 return new GlobalVariable(
7778 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
7779 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
7780}
7781
7782Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
7783 StringRef EntryFnName) {
7784 if (OutlinedFn)
7785 return OutlinedFn;
7786
7787 assert(!M.getGlobalVariable(EntryFnName, true) &&
7788 "Named kernel already exists?");
7789 return new GlobalVariable(
7790 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
7791 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
7792}
7793
7794Error OpenMPIRBuilder::emitTargetRegionFunction(
7795 TargetRegionEntryInfo &EntryInfo,
7796 FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
7797 Function *&OutlinedFn, Constant *&OutlinedFnID) {
7798
7799 SmallString<64> EntryFnName;
7800 OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);
7801
7802 if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
7803 Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
7804 if (!CBResult)
7805 return CBResult.takeError();
7806 OutlinedFn = *CBResult;
7807 } else {
7808 OutlinedFn = nullptr;
7809 }
7810
  // If this outlined target function is not an offload entry, we don't need
  // to register it. This may be the case for a false 'if' clause or when
  // there are no OpenMP targets.
7814 if (!IsOffloadEntry)
7815 return Error::success();
7816
7817 std::string EntryFnIDName =
7818 Config.isTargetDevice()
7819 ? std::string(EntryFnName)
7820 : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});
7821
7822 OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
7823 EntryFnName, EntryFnIDName);
7824 return Error::success();
7825}
7826
7827Constant *OpenMPIRBuilder::registerTargetRegionFunction(
7828 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
7829 StringRef EntryFnName, StringRef EntryFnIDName) {
7830 if (OutlinedFn)
7831 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
7832 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
7833 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
7834 OffloadInfoManager.registerTargetRegionEntryInfo(
7835 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
7836 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
7837 return OutlinedFnID;
7838}
7839
7840OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
7841 const LocationDescription &Loc, InsertPointTy AllocaIP,
7842 InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
7843 TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
7844 CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
7845 function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
7846 BodyGenTy BodyGenType)>
7847 BodyGenCB,
7848 function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
7849 if (!updateToLocation(Loc))
7850 return InsertPointTy();
7851
7852 Builder.restoreIP(IP: CodeGenIP);
7853
7854 bool IsStandAlone = !BodyGenCB;
7855 MapInfosTy *MapInfo;
7856 // Generate the code for the opening of the data environment. Capture all the
7857 // arguments of the runtime call by reference because they are used in the
7858 // closing of the region.
7859 auto BeginThenGen = [&](InsertPointTy AllocaIP,
7860 InsertPointTy CodeGenIP) -> Error {
7861 MapInfo = &GenMapInfoCB(Builder.saveIP());
7862 if (Error Err = emitOffloadingArrays(
7863 AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info, CustomMapperCB,
7864 /*IsNonContiguous=*/true, DeviceAddrCB))
7865 return Err;
7866
7867 TargetDataRTArgs RTArgs;
7868 emitOffloadingArraysArgument(Builder, RTArgs, Info);
7869
7870 // Emit the number of elements in the offloading arrays.
7871 Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);
7872
7873 // Source location for the ident struct
7874 if (!SrcLocInfo) {
7875 uint32_t SrcLocStrSize;
7876 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7877 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7878 }
7879
7880 SmallVector<llvm::Value *, 13> OffloadingArgs = {
7881 SrcLocInfo, DeviceID,
7882 PointerNum, RTArgs.BasePointersArray,
7883 RTArgs.PointersArray, RTArgs.SizesArray,
7884 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
7885 RTArgs.MappersArray};
7886
7887 if (IsStandAlone) {
7888 assert(MapperFunc && "MapperFunc missing for standalone target data");
7889
7890 auto TaskBodyCB = [&](Value *, Value *,
7891 IRBuilderBase::InsertPoint) -> Error {
7892 if (Info.HasNoWait) {
7893 OffloadingArgs.append(IL: {llvm::Constant::getNullValue(Ty: Int32),
7894 llvm::Constant::getNullValue(Ty: VoidPtr),
7895 llvm::Constant::getNullValue(Ty: Int32),
7896 llvm::Constant::getNullValue(Ty: VoidPtr)});
7897 }
7898
7899 createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
7900 Args: OffloadingArgs);
7901
7902 if (Info.HasNoWait) {
7903 BasicBlock *OffloadContBlock =
7904 BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
7905 Function *CurFn = Builder.GetInsertBlock()->getParent();
7906 emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
7907 Builder.restoreIP(IP: Builder.saveIP());
7908 }
7909 return Error::success();
7910 };
7911
7912 bool RequiresOuterTargetTask = Info.HasNoWait;
7913 if (!RequiresOuterTargetTask)
7914 cantFail(Err: TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
7915 /*TargetTaskAllocaIP=*/{}));
7916 else
7917 cantFail(ValOrErr: emitTargetTask(TaskBodyCB, DeviceID, RTLoc: SrcLocInfo, AllocaIP,
7918 /*Dependencies=*/{}, RTArgs, HasNoWait: Info.HasNoWait));
7919 } else {
7920 Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
7921 FnID: omp::OMPRTL___tgt_target_data_begin_mapper);
7922
7923 createRuntimeFunctionCall(Callee: BeginMapperFunc, Args: OffloadingArgs);
7924
7925 for (auto DeviceMap : Info.DevicePtrInfoMap) {
7926 if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
7927 auto *LI =
7928 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
7929 Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
7930 }
7931 }
7932
7933 // If device pointer privatization is required, emit the body of the
7934 // region here. It will have to be duplicated: with and without
7935 // privatization.
7936 InsertPointOrErrorTy AfterIP =
7937 BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
7938 if (!AfterIP)
7939 return AfterIP.takeError();
7940 Builder.restoreIP(IP: *AfterIP);
7941 }
7942 return Error::success();
7943 };
7944
7945 // If we need device pointer privatization, we need to emit the body of the
7946 // region with no privatization in the 'else' branch of the conditional.
7947 // Otherwise, we don't have to do anything.
7948 auto BeginElseGen = [&](InsertPointTy AllocaIP,
7949 InsertPointTy CodeGenIP) -> Error {
7950 InsertPointOrErrorTy AfterIP =
7951 BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
7952 if (!AfterIP)
7953 return AfterIP.takeError();
7954 Builder.restoreIP(IP: *AfterIP);
7955 return Error::success();
7956 };
7957
7958 // Generate code for the closing of the data region.
7959 auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
7960 TargetDataRTArgs RTArgs;
7961 Info.EmitDebug = !MapInfo->Names.empty();
7962 emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);
7963
7964 // Emit the number of elements in the offloading arrays.
7965 Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);
7966
7967 // Source location for the ident struct
7968 if (!SrcLocInfo) {
7969 uint32_t SrcLocStrSize;
7970 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7971 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7972 }
7973
7974 Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
7975 PointerNum, RTArgs.BasePointersArray,
7976 RTArgs.PointersArray, RTArgs.SizesArray,
7977 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
7978 RTArgs.MappersArray};
7979 Function *EndMapperFunc =
7980 getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);
7981
7982 createRuntimeFunctionCall(Callee: EndMapperFunc, Args: OffloadingArgs);
7983 return Error::success();
7984 };
7985
7986 // We don't have to do anything to close the region if the if clause evaluates
7987 // to false.
7988 auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
7989 return Error::success();
7990 };
7991
7992 Error Err = [&]() -> Error {
7993 if (BodyGenCB) {
7994 Error Err = [&]() {
7995 if (IfCond)
7996 return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
7997 return BeginThenGen(AllocaIP, Builder.saveIP());
7998 }();
7999
8000 if (Err)
8001 return Err;
8002
8003 // If we don't require privatization of device pointers, we emit the body
8004 // in between the runtime calls. This avoids duplicating the body code.
8005 InsertPointOrErrorTy AfterIP =
8006 BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
8007 if (!AfterIP)
8008 return AfterIP.takeError();
8009 restoreIPandDebugLoc(Builder, IP: *AfterIP);
8010
8011 if (IfCond)
8012 return emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
8013 return EndThenGen(AllocaIP, Builder.saveIP());
8014 }
8015 if (IfCond)
8016 return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
8017 return BeginThenGen(AllocaIP, Builder.saveIP());
8018 }();
8019
8020 if (Err)
8021 return Err;
8022
8023 return Builder.saveIP();
8024}
8025
8026FunctionCallee
8027OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
8028 bool IsGPUDistribute) {
8029 assert((IVSize == 32 || IVSize == 64) &&
8030 "IV size is not compatible with the omp runtime");
8031 RuntimeFunction Name;
8032 if (IsGPUDistribute)
8033 Name = IVSize == 32
8034 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
8035 : omp::OMPRTL___kmpc_distribute_static_init_4u)
8036 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
8037 : omp::OMPRTL___kmpc_distribute_static_init_8u);
8038 else
8039 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
8040 : omp::OMPRTL___kmpc_for_static_init_4u)
8041 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
8042 : omp::OMPRTL___kmpc_for_static_init_8u);
8043
8044 return getOrCreateRuntimeFunction(M, FnID: Name);
8045}
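
// For example, IVSize == 64 with an unsigned IV selects
// __kmpc_for_static_init_8u, while the same IV under GPU distribute selects
// __kmpc_distribute_static_init_8u.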
8046
8047FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
8048 bool IVSigned) {
8049 assert((IVSize == 32 || IVSize == 64) &&
8050 "IV size is not compatible with the omp runtime");
8051 RuntimeFunction Name = IVSize == 32
8052 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
8053 : omp::OMPRTL___kmpc_dispatch_init_4u)
8054 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
8055 : omp::OMPRTL___kmpc_dispatch_init_8u);
8056
8057 return getOrCreateRuntimeFunction(M, FnID: Name);
8058}
8059
8060FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
8061 bool IVSigned) {
8062 assert((IVSize == 32 || IVSize == 64) &&
8063 "IV size is not compatible with the omp runtime");
8064 RuntimeFunction Name = IVSize == 32
8065 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
8066 : omp::OMPRTL___kmpc_dispatch_next_4u)
8067 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
8068 : omp::OMPRTL___kmpc_dispatch_next_8u);
8069
8070 return getOrCreateRuntimeFunction(M, FnID: Name);
8071}
8072
8073FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
8074 bool IVSigned) {
8075 assert((IVSize == 32 || IVSize == 64) &&
8076 "IV size is not compatible with the omp runtime");
8077 RuntimeFunction Name = IVSize == 32
8078 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
8079 : omp::OMPRTL___kmpc_dispatch_fini_4u)
8080 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
8081 : omp::OMPRTL___kmpc_dispatch_fini_8u);
8082
8083 return getOrCreateRuntimeFunction(M, FnID: Name);
8084}
8085
8086FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
8087 return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
8088}
8089
8090static void FixupDebugInfoForOutlinedFunction(
8091 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
8092 DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {
8093
8094 DISubprogram *NewSP = Func->getSubprogram();
8095 if (!NewSP)
8096 return;
8097
8098 SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;
8099
8100 auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
8101 DILocalVariable *&NewVar = RemappedVariables[OldVar];
    // Only use the cached variable if the arg number matches. This is
    // important so that DIVariables created for privatized variables are not
    // discarded.
8104 if (NewVar && (arg == NewVar->getArg()))
8105 return NewVar;
8106
8107 NewVar = llvm::DILocalVariable::get(
8108 Context&: Builder.getContext(), Scope: OldVar->getScope(), Name: OldVar->getName(),
8109 File: OldVar->getFile(), Line: OldVar->getLine(), Type: OldVar->getType(), Arg: arg,
8110 Flags: OldVar->getFlags(), AlignInBits: OldVar->getAlignInBits(), Annotations: OldVar->getAnnotations());
8111 return NewVar;
8112 };
8113
8114 auto UpdateDebugRecord = [&](auto *DR) {
8115 DILocalVariable *OldVar = DR->getVariable();
8116 unsigned ArgNo = 0;
8117 for (auto Loc : DR->location_ops()) {
8118 auto Iter = ValueReplacementMap.find(Loc);
8119 if (Iter != ValueReplacementMap.end()) {
8120 DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
8121 ArgNo = std::get<1>(Iter->second) + 1;
8122 }
8123 }
8124 if (ArgNo != 0)
8125 DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
8126 };
8127
8128 // The location and scope of variable intrinsics and records still point to
8129 // the parent function of the target region. Update them.
8130 for (Instruction &I : instructions(F: Func)) {
8131 assert(!isa<llvm::DbgVariableIntrinsic>(&I) &&
8132 "Unexpected debug intrinsic");
8133 for (DbgVariableRecord &DVR : filterDbgVars(R: I.getDbgRecordRange()))
8134 UpdateDebugRecord(&DVR);
8135 }
8136 // An extra argument is passed to the device. Create the debug data for it.
8137 if (OMPBuilder.Config.isTargetDevice()) {
8138 DICompileUnit *CU = NewSP->getUnit();
8139 Module *M = Func->getParent();
8140 DIBuilder DB(*M, true, CU);
8141 DIType *VoidPtrTy =
8142 DB.createQualifiedType(Tag: dwarf::DW_TAG_pointer_type, FromTy: nullptr);
8143 DILocalVariable *Var = DB.createParameterVariable(
8144 Scope: NewSP, Name: "dyn_ptr", /*ArgNo*/ 1, File: NewSP->getFile(), /*LineNo=*/0,
8145 Ty: VoidPtrTy, /*AlwaysPreserve=*/false, Flags: DINode::DIFlags::FlagArtificial);
8146 auto Loc = DILocation::get(Context&: Func->getContext(), Line: 0, Column: 0, Scope: NewSP, InlinedAt: 0);
8147 DB.insertDeclare(Storage: &(*Func->arg_begin()), VarInfo: Var, Expr: DB.createExpression(), DL: Loc,
8148 InsertAtEnd: &(*Func->begin()));
8149 }
8150}
8151
8152static Value *removeASCastIfPresent(Value *V) {
8153 if (Operator::getOpcode(V) == Instruction::AddrSpaceCast)
8154 return cast<Operator>(Val: V)->getOperand(i: 0);
8155 return V;
8156}
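
// E.g. given the constant expression "addrspacecast (ptr addrspace(1) @g to
// ptr)", removeASCastIfPresent returns @g.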
8157
8158static Expected<Function *> createOutlinedFunction(
8159 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
8160 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
8161 StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
8162 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
8163 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
8164 SmallVector<Type *> ParameterTypes;
8165 if (OMPBuilder.Config.isTargetDevice()) {
8166 // Add the "implicit" runtime argument we use to provide launch specific
8167 // information for target devices.
8168 auto *Int8PtrTy = PointerType::getUnqual(C&: Builder.getContext());
8169 ParameterTypes.push_back(Elt: Int8PtrTy);
8170
8171 // All parameters to target devices are passed as pointers
8172 // or i64. This assumes 64-bit address spaces/pointers.
8173 for (auto &Arg : Inputs)
8174 ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
8175 ? Arg->getType()
8176 : Type::getInt64Ty(C&: Builder.getContext()));
8177 } else {
8178 for (auto &Arg : Inputs)
8179 ParameterTypes.push_back(Elt: Arg->getType());
8180 }
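
  // E.g. on the device an i32 input is passed as i64, while pointer inputs
  // keep their type; on the host every input keeps its original type.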
8181
8182 auto BB = Builder.GetInsertBlock();
8183 auto M = BB->getModule();
8184 auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
8185 /*isVarArg*/ false);
8186 auto Func =
8187 Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
8188
8189 // Forward target-cpu and target-features function attributes from the
8190 // original function to the new outlined function.
8191 Function *ParentFn = Builder.GetInsertBlock()->getParent();
8192
8193 auto TargetCpuAttr = ParentFn->getFnAttribute(Kind: "target-cpu");
8194 if (TargetCpuAttr.isStringAttribute())
8195 Func->addFnAttr(Attr: TargetCpuAttr);
8196
8197 auto TargetFeaturesAttr = ParentFn->getFnAttribute(Kind: "target-features");
8198 if (TargetFeaturesAttr.isStringAttribute())
8199 Func->addFnAttr(Attr: TargetFeaturesAttr);
8200
8201 if (OMPBuilder.Config.isTargetDevice()) {
8202 Value *ExecMode =
8203 OMPBuilder.emitKernelExecutionMode(KernelName: FuncName, Mode: DefaultAttrs.ExecFlags);
8204 OMPBuilder.emitUsed(Name: "llvm.compiler.used", List: {ExecMode});
8205 }
8206
8207 // Save insert point.
8208 IRBuilder<>::InsertPointGuard IPG(Builder);
  // We will generate code into the outlined function, but the debug location
  // may still be pointing to the parent function. Reset it now.
8211 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
8212
8213 // Generate the region into the function.
8214 BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
8215 Builder.SetInsertPoint(EntryBB);
8216
8217 // Insert target init call in the device compilation pass.
8218 if (OMPBuilder.Config.isTargetDevice())
8219 Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, Attrs: DefaultAttrs));
8220
8221 BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
8222
  // As we embed the user code in the middle of our target region, after we
  // generate entry code, we must move whatever allocas we can into the entry
  // block to avoid breaking possible optimizations for the device.
8226 if (OMPBuilder.Config.isTargetDevice())
8227 OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);
8228
8229 // Insert target deinit call in the device compilation pass.
8230 BasicBlock *OutlinedBodyBB =
8231 splitBB(Builder, /*CreateBranch=*/true, Name: "outlined.body");
8232 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
8233 Builder.saveIP(),
8234 OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
8235 if (!AfterIP)
8236 return AfterIP.takeError();
8237 Builder.restoreIP(IP: *AfterIP);
8238 if (OMPBuilder.Config.isTargetDevice())
8239 OMPBuilder.createTargetDeinit(Loc: Builder);
8240
8241 // Insert return instruction.
8242 Builder.CreateRetVoid();
8243
8244 // New Alloca IP at entry point of created device function.
8245 Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
8246 auto AllocaIP = Builder.saveIP();
8247
8248 Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
8249
8250 // Skip the artificial dyn_ptr on the device.
8251 const auto &ArgRange =
8252 OMPBuilder.Config.isTargetDevice()
8253 ? make_range(x: Func->arg_begin() + 1, y: Func->arg_end())
8254 : Func->args();
8255
8256 DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;
8257
8258 auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
    // Things like GEPs can come in the form of Constants. Constants and
    // ConstantExprs do not know what they are contained in, so we must dig a
    // little to find an instruction that tells us whether they are used
    // inside of the function we are outlining. We also replace the original
    // constant expression with a new, equivalent instruction, because an
    // instruction allows easy modification in the following loop: we then
    // know the (former) constant is owned by our target function and
    // replaceUsesOfWith can be invoked on it (this does not appear possible
    // with constants). A brand new instruction also lets us be cautious, as
    // it is perhaps possible that the old expression was used inside of the
    // function but also exists and is used externally (unlikely by the nature
    // of a Constant, but still possible).
    // NOTE: We cannot remove dead constants that have been rewritten to
    // instructions at this stage; we would run the risk of breaking later
    // lowering by doing so, as we could still be in the process of lowering
    // the module from MLIR to LLVM-IR and the MLIR lowering may still require
    // the original constants we have created rewritten versions of.
8276 if (auto *Const = dyn_cast<Constant>(Val: Input))
8277 convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);
8278
8279 // Collect users before iterating over them to avoid invalidating the
8280 // iteration in case a user uses Input more than once (e.g. a call
8281 // instruction).
8282 SetVector<User *> Users(Input->users().begin(), Input->users().end());
8283 // Collect all the instructions
8284 for (User *User : make_early_inc_range(Range&: Users))
8285 if (auto *Instr = dyn_cast<Instruction>(Val: User))
8286 if (Instr->getFunction() == Func)
8287 Instr->replaceUsesOfWith(From: Input, To: InputCopy);
8288 };
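
  // For instance (names assumed), if Func contains "store i32 0, ptr @g",
  // then ReplaceValue(@g, %g.copy, Func) first materializes any constant
  // users of @g as instructions and afterwards the store reads
  // "store i32 0, ptr %g.copy".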
8289
8290 SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
8291
  // Rewrite uses of input values to parameters.
8293 for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
8294 Value *Input = std::get<0>(t&: InArg);
8295 Argument &Arg = std::get<1>(t&: InArg);
8296 Value *InputCopy = nullptr;
8297
8298 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
8299 ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
8300 if (!AfterIP)
8301 return AfterIP.takeError();
8302 Builder.restoreIP(IP: *AfterIP);
8303 ValueReplacementMap[Input] = std::make_tuple(args&: InputCopy, args: Arg.getArgNo());
8304
    // In certain cases a Global may be set up for replacement. However, this
    // Global may be used in multiple arguments to the kernel, just segmented
    // apart. For example, if we have a global array that is sectioned into
    // multiple mappings (technically not legal in OpenMP, but necessary for
    // Fortran Common Blocks), we will end up with GEPs into this array inside
    // the kernel that refer to the Global but are, for all intents and
    // purposes, separate arguments to the kernel. If we have mapped a segment
    // that requires a GEP into the 0-th index, it will fold into a reference
    // to the Global; if we then encounter this folded GEP during replacement,
    // all of the references to the Global in the kernel will be replaced with
    // the argument we have generated that corresponds to it, including any
    // other GEPs that refer to the Global and that may be other arguments.
    // This would invalidate all of the preceding mapped arguments that refer
    // to the same Global as separate segments. To prevent this, we defer
    // Global processing until all other processing has been performed.
8321 if (llvm::isa<llvm::GlobalValue, llvm::GlobalObject, llvm::GlobalVariable>(
8322 Val: removeASCastIfPresent(V: Input))) {
8323 DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
8324 continue;
8325 }
8326
8327 if (isa<ConstantData>(Val: Input))
8328 continue;
8329
8330 ReplaceValue(Input, InputCopy, Func);
8331 }
8332
8333 // Replace all of our deferred Input values, currently just Globals.
8334 for (auto Deferred : DeferredReplacement)
8335 ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);
8336
8337 FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
8338 ValueReplacementMap);
8339 return Func;
8340}

/// Given a task descriptor, TaskWithPrivates, return the pointer to the block
8342/// of pointers containing shared data between the parent task and the created
8343/// task.
8344static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
8345 IRBuilderBase &Builder,
8346 Value *TaskWithPrivates,
8347 Type *TaskWithPrivatesTy) {
8348
8349 Type *TaskTy = OMPIRBuilder.Task;
8350 LLVMContext &Ctx = Builder.getContext();
8351 Value *TaskT =
8352 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 0);
8353 Value *Shareds = TaskT;
8354 // TaskWithPrivatesTy can be one of the following
8355 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8356 // %struct.privates }
8357 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
8358 //
8359 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
8360 // its first member has to be the task descriptor. TaskTy is the type of the
8361 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
8362 // first member of TaskT, gives us the pointer to shared data.
8363 if (TaskWithPrivatesTy != TaskTy)
8364 Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
8365 return Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
8366}
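
// A minimal sketch of the IR this emits when TaskWithPrivatesTy != TaskTy
// (value names assumed):
//   %task_t = getelementptr %struct.task_with_privates, ptr %task,
//             i32 0, i32 0
//   %shareds.gep = getelementptr %struct.kmp_task_ompbuilder_t, ptr %task_t,
//                  i32 0, i32 0
//   %shareds = load ptr, ptr %shareds.gep
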
/// Create the entry point function for a target task. It has the following
/// signature:
/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
8370/// This function is called from emitTargetTask once the
8371/// code to launch the target kernel has been outlined already.
8372/// NumOffloadingArrays is the number of offloading arrays that we need to copy
8373/// into the task structure so that the deferred target task can access this
8374/// data even after the stack frame of the generating task has been rolled
8375/// back. Offloading arrays contain base pointers, pointers, sizes etc
8376/// of the data that the target kernel will access. These in effect are the
8377/// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
8378static Function *emitTargetTaskProxyFunction(
8379 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
8380 StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
8381 const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {
8382
8383 // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
8384 // This is because PrivatesTy is the type of the structure in which
8385 // we pass the offloading arrays to the deferred target task.
  assert((!NumOffloadingArrays || PrivatesTy) &&
         "PrivatesTy cannot be nullptr when there are offloading arrays "
         "to privatize");
8389
8390 Module &M = OMPBuilder.M;
8391 // KernelLaunchFunction is the target launch function, i.e.
8392 // the function that sets up kernel arguments and calls
8393 // __tgt_target_kernel to launch the kernel on the device.
8394 //
8395 Function *KernelLaunchFunction = StaleCI->getCalledFunction();
8396
8397 // StaleCI is the CallInst which is the call to the outlined
8398 // target kernel launch function. If there are local live-in values
8399 // that the outlined function uses then these are aggregated into a structure
8400 // which is passed as the second argument. If there are no local live-in
8401 // values or if all values used by the outlined kernel are global variables,
8402 // then there's only one argument, the threadID. So, StaleCI can be
8403 //
8404 // %structArg = alloca { ptr, ptr }, align 8
8405 // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
8406 // store ptr %20, ptr %gep_, align 8
8407 // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
8408 // store ptr %21, ptr %gep_8, align 8
8409 // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
8410 //
8411 // OR
8412 //
8413 // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
8414 OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
8415 StaleCI->getIterator());
8416
8417 LLVMContext &Ctx = StaleCI->getParent()->getContext();
8418
8419 Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
8420 Type *TaskPtrTy = OMPBuilder.TaskPtr;
8421 [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;
8422
8423 auto ProxyFnTy =
8424 FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
8425 /* isVarArg */ false);
8426 auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
8427 N: ".omp_target_task_proxy_func",
8428 M: Builder.GetInsertBlock()->getModule());
8429 Value *ThreadId = ProxyFn->getArg(i: 0);
8430 Value *TaskWithPrivates = ProxyFn->getArg(i: 1);
8431 ThreadId->setName("thread.id");
8432 TaskWithPrivates->setName("task");
8433
8434 bool HasShareds = SharedArgsOperandNo > 0;
8435 bool HasOffloadingArrays = NumOffloadingArrays > 0;
8436 BasicBlock *EntryBB =
8437 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
8438 Builder.SetInsertPoint(EntryBB);
8439
8440 SmallVector<Value *> KernelLaunchArgs;
8441 KernelLaunchArgs.reserve(N: StaleCI->arg_size());
8442 KernelLaunchArgs.push_back(Elt: ThreadId);
8443
8444 if (HasOffloadingArrays) {
    assert(TaskTy != TaskWithPrivatesTy &&
           "If there are offloading arrays to pass to the target, "
           "TaskTy cannot be the same as TaskWithPrivatesTy");
8448 (void)TaskTy;
8449 Value *Privates =
8450 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 1);
8451 for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
8452 KernelLaunchArgs.push_back(
8453 Elt: Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i));
8454 }
8455
8456 if (HasShareds) {
8457 auto *ArgStructAlloca =
8458 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgsOperandNo));
8459 assert(ArgStructAlloca &&
8460 "Unable to find the alloca instruction corresponding to arguments "
8461 "for extracted function");
8462 auto *ArgStructType = cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
8463
8464 AllocaInst *NewArgStructAlloca =
8465 Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");
8466
8467 Value *SharedsSize =
8468 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
8469
8470 LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
8471 OMPIRBuilder&: OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);
8472
8473 Builder.CreateMemCpy(
8474 Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
8475 SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
8476 KernelLaunchArgs.push_back(Elt: NewArgStructAlloca);
8477 }
8478 OMPBuilder.createRuntimeFunctionCall(Callee: KernelLaunchFunction, Args: KernelLaunchArgs);
8479 Builder.CreateRetVoid();
8480 return ProxyFn;
8481}

static Type *getOffloadingArrayType(Value *V) {
8484 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V))
8485 return GEP->getSourceElementType();
8486 if (auto *Alloca = dyn_cast<AllocaInst>(Val: V))
8487 return Alloca->getAllocatedType();
8488
8489 llvm_unreachable("Unhandled Instruction type");
8490 return nullptr;
8491}
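
// E.g. for "%bp = alloca [3 x ptr]" this returns [3 x ptr]; for a GEP it
// returns the GEP's source element type.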
8492// This function returns a struct that has at most two members.
8493// The first member is always %struct.kmp_task_ompbuilder_t, that is the task
8494// descriptor. The second member, if needed, is a struct containing arrays
8495// that need to be passed to the offloaded target kernel. For example,
8496// if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
8497// the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
8498// respectively, then the types created by this function are
8499//
8500// %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
8501// %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8502// %struct.privates }
8503// %struct.task_with_privates is returned by this function.
8504// If there aren't any offloading arrays to pass to the target kernel,
8505// %struct.kmp_task_ompbuilder_t is returned.
8506static StructType *
8507createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
8508 ArrayRef<Value *> OffloadingArraysToPrivatize) {
8509
8510 if (OffloadingArraysToPrivatize.empty())
8511 return OMPIRBuilder.Task;
8512
8513 SmallVector<Type *, 4> StructFieldTypes;
8514 for (Value *V : OffloadingArraysToPrivatize) {
8515 assert(V->getType()->isPointerTy() &&
8516 "Expected pointer to array to privatize. Got a non-pointer value "
8517 "instead");
8518 Type *ArrayTy = getOffloadingArrayType(V);
8519 assert(ArrayTy && "ArrayType cannot be nullptr");
8520 StructFieldTypes.push_back(Elt: ArrayTy);
8521 }
8522 StructType *PrivatesStructTy =
8523 StructType::create(Elements: StructFieldTypes, Name: "struct.privates");
8524 return StructType::create(Elements: {OMPIRBuilder.Task, PrivatesStructTy},
8525 Name: "struct.task_with_privates");
8526}

static Error emitTargetOutlinedFunction(
8528 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
8529 TargetRegionEntryInfo &EntryInfo,
8530 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
8531 Function *&OutlinedFn, Constant *&OutlinedFnID,
8532 SmallVectorImpl<Value *> &Inputs,
8533 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
8534 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
8535
8536 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
8537 [&](StringRef EntryFnName) {
8538 return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
8539 FuncName: EntryFnName, Inputs, CBFunc,
8540 ArgAccessorFuncCB);
8541 };
8542
8543 return OMPBuilder.emitTargetRegionFunction(
8544 EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
8545 OutlinedFnID);
8546}
8547
8548OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
8549 TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
8550 OpenMPIRBuilder::InsertPointTy AllocaIP,
8551 const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
8552 const TargetDataRTArgs &RTArgs, bool HasNoWait) {
8553
  // The following explains the code-gen scenario for the `target` directive.
  // A similar scenario is followed for other device-related directives (e.g.
  // `target enter data`), since for those we only need to emit a task that
  // encapsulates the proper runtime call.
8558 //
8559 // When we arrive at this function, the target region itself has been
8560 // outlined into the function OutlinedFn.
  // So at this point, for
8562 // --------------------------------------------------------------
8563 // void user_code_that_offloads(...) {
8564 // omp target depend(..) map(from:a) map(to:b) private(i)
8565 // do i = 1, 10
8566 // a(i) = b(i) + n
8567 // }
8568 //
8569 // --------------------------------------------------------------
8570 //
8571 // we have
8572 //
8573 // --------------------------------------------------------------
8574 //
8575 // void user_code_that_offloads(...) {
8576 // %.offload_baseptrs = alloca [2 x ptr], align 8
8577 // %.offload_ptrs = alloca [2 x ptr], align 8
8578 // %.offload_mappers = alloca [2 x ptr], align 8
8579 // ;; target region has been outlined and now we need to
8580 // ;; offload to it via a target task.
8581 // }
8582 // void outlined_device_function(ptr a, ptr b, ptr n) {
8583 // n = *n_ptr;
8584 // do i = 1, 10
8585 // a(i) = b(i) + n
8586 // }
8587 //
8588 // We have to now do the following
8589 // (i) Make an offloading call to outlined_device_function using the OpenMP
8590 // RTL. See 'kernel_launch_function' in the pseudo code below. This is
8591 // emitted by emitKernelLaunch
8592 // (ii) Create a task entry point function that calls kernel_launch_function
8593 // and is the entry point for the target task. See
8594 // '@.omp_target_task_proxy_func in the pseudocode below.
8595 // (iii) Create a task with the task entry point created in (ii)
8596 //
8597 // That is we create the following
8598 // struct task_with_privates {
8599 // struct kmp_task_ompbuilder_t task_struct;
8600 // struct privates {
8601 // [2 x ptr] ; baseptrs
8602 // [2 x ptr] ; ptrs
8603 // [2 x i64] ; sizes
8604 // }
8605 // }
8606 // void user_code_that_offloads(...) {
8607 // %.offload_baseptrs = alloca [2 x ptr], align 8
8608 // %.offload_ptrs = alloca [2 x ptr], align 8
8609 // %.offload_sizes = alloca [2 x i64], align 8
8610 //
8611 // %structArg = alloca { ptr, ptr, ptr }, align 8
  //   %structArg[0] = a
  //   %structArg[1] = b
  //   %structArg[2] = &n
8615 //
8616 // target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
8617 // sizeof(kmp_task_ompbuilder_t),
8618 // sizeof(structArg),
8619 // @.omp_target_task_proxy_func,
8620 // ...)
8621 // memcpy(target_task_with_privates->task_struct->shareds, %structArg,
8622 // sizeof(structArg))
8623 // memcpy(target_task_with_privates->privates->baseptrs,
8624 // offload_baseptrs, sizeof(offload_baseptrs)
8625 // memcpy(target_task_with_privates->privates->ptrs,
8626 // offload_ptrs, sizeof(offload_ptrs)
8627 // memcpy(target_task_with_privates->privates->sizes,
8628 // offload_sizes, sizeof(offload_sizes)
8629 // dependencies_array = ...
8630 // ;; if nowait not present
8631 // call @__kmpc_omp_wait_deps(..., dependencies_array)
8632 // call @__kmpc_omp_task_begin_if0(...)
  //   call @.omp_target_task_proxy_func(i32 thread_id, ptr
  //        %target_task_with_privates)
8635 // call @__kmpc_omp_task_complete_if0(...)
8636 // }
8637 //
8638 // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
8639 // ptr %task) {
8640 // %structArg = alloca {ptr, ptr, ptr}
8641 // %task_ptr = getelementptr(%task, 0, 0)
8642 // %shared_data = load (getelementptr %task_ptr, 0, 0)
  //   memcpy(%structArg, %shared_data, sizeof(%structArg))
8644 //
8645 // %offloading_arrays = getelementptr(%task, 0, 1)
8646 // %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
8647 // %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
8648 // %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
8649 // kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
8650 // %offload_sizes, %structArg)
8651 // }
8652 //
8653 // We need the proxy function because the signature of the task entry point
8654 // expected by kmpc_omp_task is always the same and will be different from
8655 // that of the kernel_launch function.
8656 //
8657 // kernel_launch_function is generated by emitKernelLaunch and has the
8658 // always_inline attribute. For this example, it'll look like so:
8659 // void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
8660 // %offload_sizes, %structArg) alwaysinline {
8661 // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
8662 // ; load aggregated data from %structArg
8663 // ; setup kernel_args using offload_baseptrs, offload_ptrs and
8664 // ; offload_sizes
8665 // call i32 @__tgt_target_kernel(...,
8666 // outlined_device_function,
8667 // ptr %kernel_args)
8668 // }
8669 // void outlined_device_function(ptr a, ptr b, ptr n) {
8670 // n = *n_ptr;
8671 // do i = 1, 10
8672 // a(i) = b(i) + n
8673 // }
8674 //
8675 BasicBlock *TargetTaskBodyBB =
8676 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
8677 BasicBlock *TargetTaskAllocaBB =
8678 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");
8679
8680 InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
8681 TargetTaskAllocaBB->begin());
8682 InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
8683
8684 OutlineInfo OI;
8685 OI.EntryBB = TargetTaskAllocaBB;
8686 OI.OuterAllocaBB = AllocaIP.getBlock();
8687
8688 // Add the thread ID argument.
8689 SmallVector<Instruction *, 4> ToBeDeleted;
8690 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
8691 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));
8692
8693 // Generate the task body which will subsequently be outlined.
8694 Builder.restoreIP(IP: TargetTaskBodyIP);
8695 if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
8696 return Err;
8697
  // The outliner (CodeExtractor) extracts a sequence or vector of blocks that
  // it is given. These blocks are enumerated by
  // OpenMPIRBuilder::OutlineInfo::collectBlocks, which expects OI.ExitBlock
  // to be outside the region. In other words, OI.ExitBlock is expected to be
  // the start of the region after the outlining. We used to set OI.ExitBlock
  // to the InsertBlock after TaskBodyCB is done. This is fine in most cases,
  // except when the task body is a single basic block. In that case,
  // OI.ExitBlock is set to the single task body block and will get left out
  // of the outlining process. So, simply create a new empty block to which we
  // unconditionally branch from where TaskBodyCB left off.
8708 OI.ExitBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "target.task.cont");
8709 emitBlock(BB: OI.ExitBB, CurFn: Builder.GetInsertBlock()->getParent(),
8710 /*IsFinished=*/true);
8711
8712 SmallVector<Value *, 2> OffloadingArraysToPrivatize;
8713 bool NeedsTargetTask = HasNoWait && DeviceID;
8714 if (NeedsTargetTask) {
8715 for (auto *V :
8716 {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
8717 RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
8718 RTArgs.SizesArray}) {
8719 if (V && !isa<ConstantPointerNull, GlobalVariable>(Val: V)) {
8720 OffloadingArraysToPrivatize.push_back(Elt: V);
8721 OI.ExcludeArgsFromAggregate.push_back(Elt: V);
8722 }
8723 }
8724 }
8725 OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
8726 DeviceID, OffloadingArraysToPrivatize](
8727 Function &OutlinedFn) mutable {
8728 assert(OutlinedFn.hasOneUse() &&
8729 "there must be a single user for the outlined function");
8730
8731 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
8732
    // The first argument of StaleCI is always the thread id.
    // The next few arguments are the pointers to offloading arrays, if any
    // (see OffloadingArraysToPrivatize).
    // Finally, all other local values that are live-in into the outlined
    // region end up in a structure whose pointer is passed as the last
    // argument. This piece of data is passed in the "shared" field of the
    // task structure. So, we know we have to pass shareds to the task if the
    // number of arguments is greater than OffloadingArraysToPrivatize.size()
    // + 1, where the 1 is for the thread id. Further, for safety, we assert
    // that the number of arguments of StaleCI is exactly
    // OffloadingArraysToPrivatize.size() + 2.
8743 const unsigned int NumStaleCIArgs = StaleCI->arg_size();
8744 bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
8745 assert((!HasShareds ||
8746 NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
8747 "Wrong number of arguments for StaleCI when shareds are present");
8748 int SharedArgOperandNo =
8749 HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;
8750
8751 StructType *TaskWithPrivatesTy =
8752 createTaskWithPrivatesTy(OMPIRBuilder&: *this, OffloadingArraysToPrivatize);
8753 StructType *PrivatesTy = nullptr;
8754
8755 if (!OffloadingArraysToPrivatize.empty())
8756 PrivatesTy =
8757 static_cast<StructType *>(TaskWithPrivatesTy->getElementType(N: 1));
8758
8759 Function *ProxyFn = emitTargetTaskProxyFunction(
8760 OMPBuilder&: *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
8761 NumOffloadingArrays: OffloadingArraysToPrivatize.size(), SharedArgsOperandNo: SharedArgOperandNo);
8762
8763 LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
8764 << "\n");
8765
8766 Builder.SetInsertPoint(StaleCI);
8767
8768 // Gather the arguments for emitting the runtime call.
8769 uint32_t SrcLocStrSize;
8770 Constant *SrcLocStr =
8771 getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
8772 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8773
8774 // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
8775 //
    // If `NeedsTargetTask` is true, we call @__kmpc_omp_target_task_alloc to
    // provide the DeviceID to the deferred task, and also because
    // @__kmpc_omp_target_task_alloc creates an untied/async task.
8779 Function *TaskAllocFn =
8780 !NeedsTargetTask
8781 ? getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc)
8782 : getOrCreateRuntimeFunctionPtr(
8783 FnID: OMPRTL___kmpc_omp_target_task_alloc);
8784
    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
    // call.
8787 Value *ThreadID = getOrCreateThreadID(Ident);
8788
    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // TaskSize refers to the size in bytes of the kmp_task_t data structure,
    // plus any other data to be passed to the target task, if any, which is
    // packed into a struct. kmp_task_t and the struct so created are packed
    // into a wrapper struct whose type is TaskWithPrivatesTy.
8794 Value *TaskSize = Builder.getInt64(
8795 C: M.getDataLayout().getTypeStoreSize(Ty: TaskWithPrivatesTy));
8796
8797 // Argument - `sizeof_shareds` (SharedsSize)
8798 // SharedsSize refers to the shareds array size in the kmp_task_t data
8799 // structure.
8800 Value *SharedsSize = Builder.getInt64(C: 0);
8801 if (HasShareds) {
8802 auto *ArgStructAlloca =
8803 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgOperandNo));
8804 assert(ArgStructAlloca &&
8805 "Unable to find the alloca instruction corresponding to arguments "
8806 "for extracted function");
8807 auto *ArgStructType =
8808 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
8809 assert(ArgStructType && "Unable to find struct type corresponding to "
8810 "arguments for extracted function");
8811 SharedsSize =
8812 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
8813 }
8814
8815 // Argument - `flags`
8816 // Task is tied iff (Flags & 1) == 1.
8817 // Task is untied iff (Flags & 1) == 0.
8818 // Task is final iff (Flags & 2) == 2.
8819 // Task is not final iff (Flags & 2) == 0.
8820 // A target task is not final and is untied.
8821 Value *Flags = Builder.getInt32(C: 0);
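    // E.g. a tied, final task would use Flags = 3; a target task always uses
    // 0 here (untied, not final).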
8822
8823 // Emit the @__kmpc_omp_task_alloc runtime call
8824 // The runtime call returns a pointer to an area where the task captured
8825 // variables must be copied before the task is run (TaskData)
8826 CallInst *TaskData = nullptr;
8827
8828 SmallVector<llvm::Value *> TaskAllocArgs = {
8829 /*loc_ref=*/Ident, /*gtid=*/ThreadID,
8830 /*flags=*/Flags,
8831 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
8832 /*task_func=*/ProxyFn};
8833
8834 if (NeedsTargetTask) {
8835 assert(DeviceID && "Expected non-empty device ID.");
8836 TaskAllocArgs.push_back(Elt: DeviceID);
8837 }
8838
8839 TaskData = createRuntimeFunctionCall(Callee: TaskAllocFn, Args: TaskAllocArgs);
8840
8841 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
8842 if (HasShareds) {
8843 Value *Shareds = StaleCI->getArgOperand(i: SharedArgOperandNo);
8844 Value *TaskShareds = loadSharedDataFromTaskDescriptor(
8845 OMPIRBuilder&: *this, Builder, TaskWithPrivates: TaskData, TaskWithPrivatesTy);
8846 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
8847 Size: SharedsSize);
8848 }
8849 if (!OffloadingArraysToPrivatize.empty()) {
8850 Value *Privates =
8851 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskData, Idx: 1);
8852 for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
8853 Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
8854 [[maybe_unused]] Type *ArrayType =
8855 getOffloadingArrayType(V: PtrToPrivatize);
8856 assert(ArrayType && "ArrayType cannot be nullptr");
8857
8858 Type *ElementType = PrivatesTy->getElementType(N: i);
8859 assert(ElementType == ArrayType &&
8860 "ElementType should match ArrayType");
8861 (void)ArrayType;
8862
8863 Value *Dst = Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i);
8864 Builder.CreateMemCpy(
8865 Dst, DstAlign: Alignment, Src: PtrToPrivatize, SrcAlign: Alignment,
8866 Size: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ElementType)));
8867 }
8868 }
8869
8870 Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);
8871
8872 // ---------------------------------------------------------------
8873 // V5.2 13.8 target construct
8874 // If the nowait clause is present, execution of the target task
8875 // may be deferred. If the nowait clause is not present, the target task is
8876 // an included task.
8877 // ---------------------------------------------------------------
8878 // The above means that the lack of a nowait on the target construct
8879 // translates to '#pragma omp task if(0)'
8880 if (!NeedsTargetTask) {
8881 if (DepArray) {
8882 Function *TaskWaitFn =
8883 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
8884 createRuntimeFunctionCall(
8885 Callee: TaskWaitFn,
8886 Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
8887 /*ndeps=*/Builder.getInt32(C: Dependencies.size()),
8888 /*dep_list=*/DepArray,
8889 /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
8890 /*noalias_dep_list=*/
8891 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
8892 }
8893 // Included task.
8894 Function *TaskBeginFn =
8895 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
8896 Function *TaskCompleteFn =
8897 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
8898 createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
8899 CallInst *CI = createRuntimeFunctionCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
8900 CI->setDebugLoc(StaleCI->getDebugLoc());
8901 createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
8902 } else if (DepArray) {
      // NeedsTargetTask is true, meaning the task may be deferred. Call
      // __kmpc_omp_task_with_deps if there are dependencies, else call
      // __kmpc_omp_task.
8906 Function *TaskFn =
8907 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
8908 createRuntimeFunctionCall(
8909 Callee: TaskFn,
8910 Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
8911 DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
8912 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
8913 } else {
8914 // Emit the @__kmpc_omp_task runtime call to spawn the task
8915 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
8916 createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
8917 }
8918
8919 StaleCI->eraseFromParent();
8920 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
8921 I->eraseFromParent();
8922 };
8923 addOutlineInfo(OI: std::move(OI));
8924
8925 LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
8926 << *(Builder.GetInsertBlock()) << "\n");
8927 LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
8928 << *(Builder.GetInsertBlock()->getParent()->getParent())
8929 << "\n");
8930 return Builder.saveIP();
8931}
8932
8933Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
8934 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
8935 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
8936 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
8937 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
8938 if (Error Err =
8939 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
8940 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
8941 return Err;
8942 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
8943 return Error::success();
8944}
8945
8946static void emitTargetCall(
8947 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
8948 OpenMPIRBuilder::InsertPointTy AllocaIP,
8949 OpenMPIRBuilder::TargetDataInfo &Info,
8950 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
8951 const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
8952 Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
8953 SmallVectorImpl<Value *> &Args,
8954 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
8955 OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
8956 const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
8957 bool HasNoWait, Value *DynCGroupMem,
8958 OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
  // Generate a function call to the host fallback implementation of the
  // target region. This is called by the host when no offload entry was
  // generated for the target region or when the offloading call fails at
  // runtime.
8962 auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
8963 -> OpenMPIRBuilder::InsertPointOrErrorTy {
8964 Builder.restoreIP(IP);
8965 OMPBuilder.createRuntimeFunctionCall(Callee: OutlinedFn, Args);
8966 return Builder.saveIP();
8967 };
8968
8969 bool HasDependencies = Dependencies.size() > 0;
8970 bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
8971
8972 OpenMPIRBuilder::TargetKernelArgs KArgs;
8973
8974 auto TaskBodyCB =
8975 [&](Value *DeviceID, Value *RTLoc,
8976 IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
8977 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
8978 // produce any.
8979 llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
8980 // emitKernelLaunch makes the necessary runtime call to offload the
8981 // kernel. We then outline all that code into a separate function
8982 // ('kernel_launch_function' in the pseudo code above). This function is
8983 // then called by the target task proxy function (see
8984 // '@.omp_target_task_proxy_func' in the pseudo code above)
8985 // "@.omp_target_task_proxy_func' is generated by
8986 // emitTargetTaskProxyFunction.
8987 if (OutlinedFnID && DeviceID)
8988 return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
8989 EmitTargetCallFallbackCB, Args&: KArgs,
8990 DeviceID, RTLoc, AllocaIP: TargetTaskAllocaIP);
8991
8992 // We only need to do the outlining if `DeviceID` is set to avoid calling
8993 // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
8994 // generating the `else` branch of an `if` clause.
8995 //
8996 // When OutlinedFnID is set to nullptr, then it's not an offloading call.
8997 // In this case, we execute the host implementation directly.
8998 return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
8999 }());
9000
9001 OMPBuilder.Builder.restoreIP(IP: AfterIP);
9002 return Error::success();
9003 };
9004
9005 auto &&EmitTargetCallElse =
9006 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
9007 OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
9008 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
9009 // produce any.
9010 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
9011 if (RequiresOuterTargetTask) {
        // Arguments that are intended to be directly forwarded to an
        // emitKernelLaunch call are passed as nullptr, since
        // OutlinedFnID == nullptr results in that call not being done.
9015 OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
9016 return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
9017 /*RTLoc=*/nullptr, AllocaIP,
9018 Dependencies, RTArgs: EmptyRTArgs, HasNoWait);
9019 }
9020 return EmitTargetCallFallbackCB(Builder.saveIP());
9021 }());
9022
9023 Builder.restoreIP(IP: AfterIP);
9024 return Error::success();
9025 };
9026
9027 auto &&EmitTargetCallThen =
9028 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
9029 OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
9030 Info.HasNoWait = HasNoWait;
9031 OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
9032 OpenMPIRBuilder::TargetDataRTArgs RTArgs;
9033 if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
9034 AllocaIP, CodeGenIP: Builder.saveIP(), Info, RTArgs, CombinedInfo&: MapInfo, CustomMapperCB,
9035 /*IsNonContiguous=*/true,
9036 /*ForEndCall=*/false))
9037 return Err;
9038
9039 SmallVector<Value *, 3> NumTeamsC;
9040 for (auto [DefaultVal, RuntimeVal] :
9041 zip_equal(t: DefaultAttrs.MaxTeams, u: RuntimeAttrs.MaxTeams))
9042 NumTeamsC.push_back(Elt: RuntimeVal ? RuntimeVal
9043 : Builder.getInt32(C: DefaultVal));
9044
9045 // Calculate number of threads: 0 if no clauses specified, otherwise it is
9046 // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
9047 auto InitMaxThreadsClause = [&Builder](Value *Clause) {
9048 if (Clause)
9049 Clause = Builder.CreateIntCast(V: Clause, DestTy: Builder.getInt32Ty(),
9050 /*isSigned=*/false);
9051 return Clause;
9052 };
9053 auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
9054 if (Clause)
9055 Result =
9056 Result ? Builder.CreateSelect(C: Builder.CreateICmpULT(LHS: Result, RHS: Clause),
9057 True: Result, False: Clause)
9058 : Clause;
9059 };
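
    // A sketch of the IR CombineMaxThreadsClauses produces when both values
    // are present (value names assumed):
    //   %cmp = icmp ult i32 %cur.min, %clause
    //   %new.min = select i1 %cmp, i32 %cur.min, i32 %clause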
9060
    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case,
    // so the NUM_THREADS clause is overridden by THREAD_LIMIT.
9063 SmallVector<Value *, 3> NumThreadsC;
9064 Value *MaxThreadsClause =
9065 RuntimeAttrs.TeamsThreadLimit.size() == 1
9066 ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
9067 : nullptr;
9068
9069 for (auto [TeamsVal, TargetVal] : zip_equal(
9070 t: RuntimeAttrs.TeamsThreadLimit, u: RuntimeAttrs.TargetThreadLimit)) {
9071 Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
9072 Value *NumThreads = InitMaxThreadsClause(TargetVal);
9073
9074 CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
9075 CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
9076
9077 NumThreadsC.push_back(Elt: NumThreads ? NumThreads : Builder.getInt32(C: 0));
9078 }
9079
9080 unsigned NumTargetItems = Info.NumberOfPtrs;
9081 uint32_t SrcLocStrSize;
9082 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
9083 Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
9084 LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);
9085
9086 Value *TripCount = RuntimeAttrs.LoopTripCount
9087 ? Builder.CreateIntCast(V: RuntimeAttrs.LoopTripCount,
9088 DestTy: Builder.getInt64Ty(),
9089 /*isSigned=*/false)
9090 : Builder.getInt64(C: 0);
9091
9092 // Request zero groupprivate bytes by default.
9093 if (!DynCGroupMem)
9094 DynCGroupMem = Builder.getInt32(C: 0);
9095
9096 KArgs = OpenMPIRBuilder::TargetKernelArgs(
9097 NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
9098 HasNoWait, DynCGroupMemFallback);
9099
9100 // Assume no error was returned because TaskBodyCB and
9101 // EmitTargetCallFallbackCB don't produce any.
9102 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
9103 // The presence of certain clauses on the target directive requires the
9104 // explicit generation of the target task.
9105 if (RequiresOuterTargetTask)
9106 return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID: RuntimeAttrs.DeviceID,
9107 RTLoc, AllocaIP, Dependencies,
9108 RTArgs: KArgs.RTArgs, HasNoWait: Info.HasNoWait);
9109
9110 return OMPBuilder.emitKernelLaunch(
9111 Loc: Builder, OutlinedFnID, EmitTargetCallFallbackCB, Args&: KArgs,
9112 DeviceID: RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
9113 }());
9114
9115 Builder.restoreIP(IP: AfterIP);
9116 return Error::success();
9117 };
9118
9119 // If we don't have an ID for the target region, it means an offload entry
9120 // wasn't created. In this case we just run the host fallback directly and
9121 // ignore any potential 'if' clauses.
9122 if (!OutlinedFnID) {
9123 cantFail(Err: EmitTargetCallElse(AllocaIP, Builder.saveIP()));
9124 return;
9125 }
9126
9127 // If there's no 'if' clause, only generate the kernel launch code path.
9128 if (!IfCond) {
9129 cantFail(Err: EmitTargetCallThen(AllocaIP, Builder.saveIP()));
9130 return;
9131 }
9132
9133 cantFail(Err: OMPBuilder.emitIfClause(Cond: IfCond, ThenGen: EmitTargetCallThen,
9134 ElseGen: EmitTargetCallElse, AllocaIP));
9135}
9136
9137OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
9138 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
9139 InsertPointTy CodeGenIP, TargetDataInfo &Info,
9140 TargetRegionEntryInfo &EntryInfo,
9141 const TargetKernelDefaultAttrs &DefaultAttrs,
9142 const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
9143 SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
9144 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
9145 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
9146 CustomMapperCallbackTy CustomMapperCB,
9147 const SmallVector<DependData> &Dependencies, bool HasNowait,
9148 Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
9149
9150 if (!updateToLocation(Loc))
9151 return InsertPointTy();
9152
9153 Builder.restoreIP(IP: CodeGenIP);
9154
9155 Function *OutlinedFn;
9156 Constant *OutlinedFnID = nullptr;
9157 // The target region is outlined into its own function. The LLVM IR for
9158 // the target region itself is generated using the callbacks CBFunc
9159 // and ArgAccessorFuncCB.
9160 if (Error Err = emitTargetOutlinedFunction(
9161 OMPBuilder&: *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
9162 OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
9163 return Err;
9164
9165 // If we are not on the target device, then we need to generate code
9166 // to make a remote call (offload) to the previously outlined function
9167 // that represents the target region. Do that now.
9168 if (!Config.isTargetDevice())
9169 emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
9170 IfCond, OutlinedFn, OutlinedFnID, Args&: Inputs, GenMapInfoCB,
9171 CustomMapperCB, Dependencies, HasNoWait: HasNowait, DynCGroupMem,
9172 DynCGroupMemFallback);
9173 return Builder.saveIP();
9174}
9175
9176std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
9177 StringRef FirstSeparator,
9178 StringRef Separator) {
9179 SmallString<128> Buffer;
9180 llvm::raw_svector_ostream OS(Buffer);
9181 StringRef Sep = FirstSeparator;
9182 for (StringRef Part : Parts) {
9183 OS << Sep << Part;
9184 Sep = Separator;
9185 }
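 // For illustration, Parts = {"a", "b", "c"} with FirstSeparator "." and
 // Separator "_" produces ".a_b_c"; the first separator is only emitted
 // before the first part.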
9186 return OS.str().str();
9187}
9188
9189std::string
9190OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
9191 return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
9192 Separator: Config.separator());
9193}
9194
9195GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
9196 Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
9197 auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
9198 if (Elem.second) {
9199 assert(Elem.second->getValueType() == Ty &&
9200 "OMP internal variable has different type than requested");
9201 } else {
9202 // TODO: Investigate the appropriate linkage type for the global variable,
9203 // and whether it could be changed to internal or private, or maybe
9204 // create different versions of the function for different OMP internal
9205 // variables.
9206 const DataLayout &DL = M.getDataLayout();
9207 // TODO: Investigate why AMDGPU expects AS 0 for globals even though the
9208 // default global AS is 1.
9209 // See double-target-call-with-declare-target.f90 and
9210 // declare-target-vars-in-target-region.f90 libomptarget
9211 // tests.
9212 unsigned AddressSpaceVal = AddressSpace ? *AddressSpace
9213 : M.getTargetTriple().isAMDGPU()
9214 ? 0
9215 : DL.getDefaultGlobalsAddressSpace();
9216 auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
9217 ? GlobalValue::InternalLinkage
9218 : GlobalValue::CommonLinkage;
9219 auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
9220 Constant::getNullValue(Ty), Elem.first(),
9221 /*InsertBefore=*/nullptr,
9222 GlobalValue::NotThreadLocal, AddressSpaceVal);
9223 const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
9224 const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpaceVal);
9225 GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
9226 Elem.second = GV;
9227 }
9228
9229 return Elem.second;
9230}
9231
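// For illustration, a critical region named "foo" yields the lock variable
// ".gomp_critical_user_foo.var"; both separators are hardcoded to "." here.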
9232Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
9233 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
9234 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
9235 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
9236}
9237
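// This uses the classic "sizeof via GEP" idiom: indexing one element past a
// null pointer and converting the result to an integer yields the element's
// allocation size. The emitted IR is roughly (illustrative only):
//   %size.gep = getelementptr %T, ptr null, i32 1
//   %size = ptrtoint ptr %size.gep to i64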
9238Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
9239 LLVMContext &Ctx = Builder.getContext();
9240 Value *Null =
9241 Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
9242 Value *SizeGep =
9243 Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
9244 Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
9245 return SizePtrToInt;
9246}
9247
9248GlobalVariable *
9249OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
9250 std::string VarName) {
9251 llvm::Constant *MaptypesArrayInit =
9252 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
9253 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
9254 M, MaptypesArrayInit->getType(),
9255 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
9256 VarName);
9257 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
9258 return MaptypesArrayGlobal;
9259}
9260
9261void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
9262 InsertPointTy AllocaIP,
9263 unsigned NumOperands,
9264 struct MapperAllocas &MapperAllocas) {
9265 if (!updateToLocation(Loc))
9266 return;
9267
9268 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
9269 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
9270 Builder.restoreIP(IP: AllocaIP);
9271 AllocaInst *ArgsBase = Builder.CreateAlloca(
9272 Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
9273 AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
9274 Name: ".offload_ptrs");
9275 AllocaInst *ArgSizes = Builder.CreateAlloca(
9276 Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
9277 updateToLocation(Loc);
9278 MapperAllocas.ArgsBase = ArgsBase;
9279 MapperAllocas.Args = Args;
9280 MapperAllocas.ArgSizes = ArgSizes;
9281}
9282
9283void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
9284 Function *MapperFunc, Value *SrcLocInfo,
9285 Value *MaptypesArg, Value *MapnamesArg,
9286 struct MapperAllocas &MapperAllocas,
9287 int64_t DeviceID, unsigned NumOperands) {
9288 if (!updateToLocation(Loc))
9289 return;
9290
9291 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
9292 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
9293 Value *ArgsBaseGEP =
9294 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
9295 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
9296 Value *ArgsGEP =
9297 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
9298 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
9299 Value *ArgSizesGEP =
9300 Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
9301 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
9302 Value *NullPtr =
9303 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
9304 createRuntimeFunctionCall(Callee: MapperFunc, Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
9305 Builder.getInt32(C: NumOperands),
9306 ArgsBaseGEP, ArgsGEP, ArgSizesGEP,
9307 MaptypesArg, MapnamesArg, NullPtr});
9308}
9309
9310void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
9311 TargetDataRTArgs &RTArgs,
9312 TargetDataInfo &Info,
9313 bool ForEndCall) {
9314 assert((!ForEndCall || Info.separateBeginEndCalls()) &&
9315 "expected region end call to runtime only when end call is separate");
9316 auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
9317 auto VoidPtrTy = UnqualPtrTy;
9318 auto VoidPtrPtrTy = UnqualPtrTy;
9319 auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
9320 auto Int64PtrTy = UnqualPtrTy;
9321
9322 if (!Info.NumberOfPtrs) {
9323 RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
9324 RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
9325 RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
9326 RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
9327 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
9328 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
9329 return;
9330 }
9331
9332 RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
9333 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
9334 Ptr: Info.RTArgs.BasePointersArray,
9335 /*Idx0=*/0, /*Idx1=*/0);
9336 RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
9337 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
9338 /*Idx0=*/0,
9339 /*Idx1=*/0);
9340 RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
9341 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
9342 /*Idx0=*/0, /*Idx1=*/0);
9343 RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
9344 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
9345 Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
9346 : Info.RTArgs.MapTypesArray,
9347 /*Idx0=*/0,
9348 /*Idx1=*/0);
9349
9350 // Only emit the mapper information arrays if debug information is
9351 // requested.
9352 if (!Info.EmitDebug)
9353 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
9354 else
9355 RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
9356 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
9357 /*Idx0=*/0,
9358 /*Idx1=*/0);
9359 // If there is no user-defined mapper, set the mapper array to nullptr to
9360 // avoid an unnecessary data privatization.
9361 if (!Info.HasMapper)
9362 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
9363 else
9364 RTArgs.MappersArray =
9365 Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
9366}
9367
9368void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
9369 InsertPointTy CodeGenIP,
9370 MapInfosTy &CombinedInfo,
9371 TargetDataInfo &Info) {
9372 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
9373 CombinedInfo.NonContigInfo;
9374
9375 // Build an array of struct descriptor_dim and then assign it to
9376 // offload_args.
9377 //
9378 // struct descriptor_dim {
9379 // uint64_t offset;
9380 // uint64_t count;
9381 // uint64_t stride
9382 // };
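 //
 // For illustration, mapping every other element of a ten-element array
 // would fill one descriptor_dim with count 5 and a stride of two elements;
 // these numbers are purely illustrative, the byte-scaled values actually
 // come from the MapInfo callbacks.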
9383 Type *Int64Ty = Builder.getInt64Ty();
9384 StructType *DimTy = StructType::create(
9385 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
9386 Name: "struct.descriptor_dim");
9387
9388 enum { OffsetFD = 0, CountFD, StrideFD };
9389 // We need two index variables here: the size of "Dims" matches the size of
9390 // Components, whereas the offset, count, and stride arrays have one entry
9391 // per non-contiguous base declaration.
9392 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
9393 // Skip emitting IR if the dimension size is 1, since it cannot be
9394 // non-contiguous.
9395 if (NonContigInfo.Dims[I] == 1)
9396 continue;
9397 Builder.restoreIP(IP: AllocaIP);
9398 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
9399 AllocaInst *DimsAddr =
9400 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
9401 Builder.restoreIP(IP: CodeGenIP);
9402 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
9403 unsigned RevIdx = EE - II - 1;
9404 Value *DimsLVal = Builder.CreateInBoundsGEP(
9405 Ty: DimsAddr->getAllocatedType(), Ptr: DimsAddr,
9406 IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
9407 // Offset
9408 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
9409 Builder.CreateAlignedStore(
9410 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
9411 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
9412 // Count
9413 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
9414 Builder.CreateAlignedStore(
9415 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
9416 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9417 // Stride
9418 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
9419 Builder.CreateAlignedStore(
9420 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
9421 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9422 }
9423 // args[I] = &dims
9424 Builder.restoreIP(IP: CodeGenIP);
9425 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
9426 V: DimsAddr, DestTy: Builder.getPtrTy());
9427 Value *P = Builder.CreateConstInBoundsGEP2_32(
9428 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
9429 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
9430 Builder.CreateAlignedStore(
9431 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
9432 ++L;
9433 }
9434}
9435
9436void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
9437 Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
9438 Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
9439 BasicBlock *ExitBB, bool IsInit) {
9440 StringRef Prefix = IsInit ? ".init" : ".del";
9441
9442 // Evaluate if this is an array section.
9443 BasicBlock *BodyBB = BasicBlock::Create(
9444 Context&: M.getContext(), Name: createPlatformSpecificName(Parts: {"omp.array", Prefix}));
9445 Value *IsArray =
9446 Builder.CreateICmpSGT(LHS: Size, RHS: Builder.getInt64(C: 1), Name: "omp.arrayinit.isarray");
9447 Value *DeleteBit = Builder.CreateAnd(
9448 LHS: MapType,
9449 RHS: Builder.getInt64(
9450 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9451 OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
9452 Value *DeleteCond;
9453 Value *Cond;
9454 if (IsInit) {
9455 // base != begin?
9456 Value *BaseIsBegin = Builder.CreateICmpNE(LHS: Base, RHS: Begin);
9457 Cond = Builder.CreateOr(LHS: IsArray, RHS: BaseIsBegin);
9458 DeleteCond = Builder.CreateIsNull(
9459 Arg: DeleteBit,
9460 Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
9461 } else {
9462 Cond = IsArray;
9463 DeleteCond = Builder.CreateIsNotNull(
9464 Arg: DeleteBit,
9465 Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
9466 }
9467 Cond = Builder.CreateAnd(LHS: Cond, RHS: DeleteCond);
9468 Builder.CreateCondBr(Cond, True: BodyBB, False: ExitBB);
9469
9470 emitBlock(BB: BodyBB, CurFn: MapperFn);
9471 // Get the array size by multiplying element size and element number (i.e., \p
9472 // Size).
9473 Value *ArraySize = Builder.CreateNUWMul(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
9474 // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that the call
9475 // performs memory allocation/deletion only.
9476 Value *MapTypeArg = Builder.CreateAnd(
9477 LHS: MapType,
9478 RHS: Builder.getInt64(
9479 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9480 OpenMPOffloadMappingFlags::OMP_MAP_TO |
9481 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9482 MapTypeArg = Builder.CreateOr(
9483 LHS: MapTypeArg,
9484 RHS: Builder.getInt64(
9485 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9486 OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));
9487
9488 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
9489 // data structure.
9490 Value *OffloadingArgs[] = {MapperHandle, Base, Begin,
9491 ArraySize, MapTypeArg, MapName};
9492 createRuntimeFunctionCall(
9493 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
9494 Args: OffloadingArgs);
9495}
9496
9497Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
9498 function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
9499 llvm::Value *BeginArg)>
9500 GenMapInfoCB,
9501 Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
9502 SmallVector<Type *> Params;
9503 Params.emplace_back(Args: Builder.getPtrTy());
9504 Params.emplace_back(Args: Builder.getPtrTy());
9505 Params.emplace_back(Args: Builder.getPtrTy());
9506 Params.emplace_back(Args: Builder.getInt64Ty());
9507 Params.emplace_back(Args: Builder.getInt64Ty());
9508 Params.emplace_back(Args: Builder.getPtrTy());
9509
9510 auto *FnTy =
9511 FunctionType::get(Result: Builder.getVoidTy(), Params, /* IsVarArg */ isVarArg: false);
9512
9513 SmallString<64> TyStr;
9514 raw_svector_ostream Out(TyStr);
9515 Function *MapperFn =
9516 Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
9517 MapperFn->addFnAttr(Kind: Attribute::NoInline);
9518 MapperFn->addFnAttr(Kind: Attribute::NoUnwind);
9519 MapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
9520 MapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
9521 MapperFn->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
9522 MapperFn->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
9523 MapperFn->addParamAttr(ArgNo: 4, Kind: Attribute::NoUndef);
9524 MapperFn->addParamAttr(ArgNo: 5, Kind: Attribute::NoUndef);
9525
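 // For illustration, the generated mapper roughly has this shape
 // (pseudocode, not the literal output):
 //   void <FuncName>(ptr handle, ptr base, ptr begin, i64 size, i64 type,
 //                   ptr name) {
 //     maybe allocate the array section (init);
 //     for (p = begin; p != end; ++p)
 //       push each component of *p (or recurse into a child mapper);
 //     maybe delete the array section (del);
 //   }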
9526 // Start the mapper function code generation.
9527 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: MapperFn);
9528 auto SavedIP = Builder.saveIP();
9529 Builder.SetInsertPoint(EntryBB);
9530
9531 Value *MapperHandle = MapperFn->getArg(i: 0);
9532 Value *BaseIn = MapperFn->getArg(i: 1);
9533 Value *BeginIn = MapperFn->getArg(i: 2);
9534 Value *Size = MapperFn->getArg(i: 3);
9535 Value *MapType = MapperFn->getArg(i: 4);
9536 Value *MapName = MapperFn->getArg(i: 5);
9537
9538 // Compute the starting and end addresses of array elements.
9539 // Prepare common arguments for array initiation and deletion.
9540 // Convert the size in bytes into the number of array elements.
9541 TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(Ty: ElemTy);
9542 Size = Builder.CreateExactUDiv(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
9543 Value *PtrBegin = BeginIn;
9544 Value *PtrEnd = Builder.CreateGEP(Ty: ElemTy, Ptr: PtrBegin, IdxList: Size);
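 // For illustration (hypothetical numbers): if ElemTy is a 24-byte struct
 // and the incoming Size argument is 240 bytes, the exact division yields
 // 10 elements, and PtrEnd points one element past the last one.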
9545
9546 // Emit array initiation if this is an array section and \p MapType indicates
9547 // that memory allocation is required.
9548 BasicBlock *HeadBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.head");
9549 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
9550 MapType, MapName, ElementSize, ExitBB: HeadBB,
9551 /*IsInit=*/true);
9552
9553 // Emit a for loop that iterates over Size elements and maps each of them.
9554
9555 // Emit the loop header block.
9556 emitBlock(BB: HeadBB, CurFn: MapperFn);
9557 BasicBlock *BodyBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.body");
9558 BasicBlock *DoneBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.done");
9559 // Evaluate whether the initial condition is satisfied.
9560 Value *IsEmpty =
9561 Builder.CreateICmpEQ(LHS: PtrBegin, RHS: PtrEnd, Name: "omp.arraymap.isempty");
9562 Builder.CreateCondBr(Cond: IsEmpty, True: DoneBB, False: BodyBB);
9563
9564 // Emit the loop body block.
9565 emitBlock(BB: BodyBB, CurFn: MapperFn);
9566 BasicBlock *LastBB = BodyBB;
9567 PHINode *PtrPHI =
9568 Builder.CreatePHI(Ty: PtrBegin->getType(), NumReservedValues: 2, Name: "omp.arraymap.ptrcurrent");
9569 PtrPHI->addIncoming(V: PtrBegin, BB: HeadBB);
9570
9571 // Get map clause information. Fill up the arrays with all mapped variables.
9572 MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
9573 if (!Info)
9574 return Info.takeError();
9575
9576 // Call the runtime API __tgt_mapper_num_components to get the number of
9577 // pre-existing components.
9578 Value *OffloadingArgs[] = {MapperHandle};
9579 Value *PreviousSize = createRuntimeFunctionCall(
9580 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_mapper_num_components),
9581 Args: OffloadingArgs);
9582 Value *ShiftedPreviousSize =
9583 Builder.CreateShl(LHS: PreviousSize, RHS: Builder.getInt64(C: getFlagMemberOffset()));
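 // Note: getFlagMemberOffset() is the bit position of the MEMBER_OF field
 // within the map-type flags, so shifting the component count left by it
 // yields the MEMBER_OF(<count>) encoding that is added onto each member's
 // map type below.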
9584
9585 // Fill up the runtime mapper handle for all components.
9586 for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
9587 Value *CurBaseArg = Info->BasePointers[I];
9588 Value *CurBeginArg = Info->Pointers[I];
9589 Value *CurSizeArg = Info->Sizes[I];
9590 Value *CurNameArg = Info->Names.size()
9591 ? Info->Names[I]
9592 : Constant::getNullValue(Ty: Builder.getPtrTy());
9593
9594 // Extract the MEMBER_OF field from the map type.
9595 Value *OriMapType = Builder.getInt64(
9596 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9597 Info->Types[I]));
9598 Value *MemberMapType =
9599 Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);
9600
9601 // Combine the map type inherited from user-defined mapper with that
9602 // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
9603 // bits of the \a MapType, which is the input argument of the mapper
9604 // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
9605 // bits of MemberMapType.
9606 // [OpenMP 5.0], 1.2.6. map-type decay.
9607 // | alloc | to | from | tofrom | release | delete
9608 // ----------------------------------------------------------
9609 // alloc | alloc | alloc | alloc | alloc | release | delete
9610 // to | alloc | to | alloc | to | release | delete
9611 // from | alloc | alloc | from | from | release | delete
9612 // tofrom | alloc | to | from | tofrom | release | delete
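 // For illustration: if the mapper is invoked for `map(to: ...)` (only
 // OMP_MAP_TO set in \a MapType) and a member was declared `tofrom`, the
 // branches below take the "to" path and clear OMP_MAP_FROM from
 // MemberMapType, matching the "to" row of the table above.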
9613 Value *LeftToFrom = Builder.CreateAnd(
9614 LHS: MapType,
9615 RHS: Builder.getInt64(
9616 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9617 OpenMPOffloadMappingFlags::OMP_MAP_TO |
9618 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9619 BasicBlock *AllocBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc");
9620 BasicBlock *AllocElseBB =
9621 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc.else");
9622 BasicBlock *ToBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to");
9623 BasicBlock *ToElseBB =
9624 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to.else");
9625 BasicBlock *FromBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.from");
9626 BasicBlock *EndBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.end");
9627 Value *IsAlloc = Builder.CreateIsNull(Arg: LeftToFrom);
9628 Builder.CreateCondBr(Cond: IsAlloc, True: AllocBB, False: AllocElseBB);
9629 // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
9630 emitBlock(BB: AllocBB, CurFn: MapperFn);
9631 Value *AllocMapType = Builder.CreateAnd(
9632 LHS: MemberMapType,
9633 RHS: Builder.getInt64(
9634 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9635 OpenMPOffloadMappingFlags::OMP_MAP_TO |
9636 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9637 Builder.CreateBr(Dest: EndBB);
9638 emitBlock(BB: AllocElseBB, CurFn: MapperFn);
9639 Value *IsTo = Builder.CreateICmpEQ(
9640 LHS: LeftToFrom,
9641 RHS: Builder.getInt64(
9642 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9643 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
9644 Builder.CreateCondBr(Cond: IsTo, True: ToBB, False: ToElseBB);
9645 // In case of to, clear OMP_MAP_FROM.
9646 emitBlock(BB: ToBB, CurFn: MapperFn);
9647 Value *ToMapType = Builder.CreateAnd(
9648 LHS: MemberMapType,
9649 RHS: Builder.getInt64(
9650 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9651 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9652 Builder.CreateBr(Dest: EndBB);
9653 emitBlock(BB: ToElseBB, CurFn: MapperFn);
9654 Value *IsFrom = Builder.CreateICmpEQ(
9655 LHS: LeftToFrom,
9656 RHS: Builder.getInt64(
9657 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9658 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9659 Builder.CreateCondBr(Cond: IsFrom, True: FromBB, False: EndBB);
9660 // In case of from, clear OMP_MAP_TO.
9661 emitBlock(BB: FromBB, CurFn: MapperFn);
9662 Value *FromMapType = Builder.CreateAnd(
9663 LHS: MemberMapType,
9664 RHS: Builder.getInt64(
9665 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9666 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
9667 // In case of tofrom, do nothing.
9668 emitBlock(BB: EndBB, CurFn: MapperFn);
9669 LastBB = EndBB;
9670 PHINode *CurMapType =
9671 Builder.CreatePHI(Ty: Builder.getInt64Ty(), NumReservedValues: 4, Name: "omp.maptype");
9672 CurMapType->addIncoming(V: AllocMapType, BB: AllocBB);
9673 CurMapType->addIncoming(V: ToMapType, BB: ToBB);
9674 CurMapType->addIncoming(V: FromMapType, BB: FromBB);
9675 CurMapType->addIncoming(V: MemberMapType, BB: ToElseBB);
9676
9677 Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
9678 CurSizeArg, CurMapType, CurNameArg};
9679
9680 auto ChildMapperFn = CustomMapperCB(I);
9681 if (!ChildMapperFn)
9682 return ChildMapperFn.takeError();
9683 if (*ChildMapperFn) {
9684 // Call the corresponding mapper function.
9685 createRuntimeFunctionCall(Callee: *ChildMapperFn, Args: OffloadingArgs)
9686 ->setDoesNotThrow();
9687 } else {
9688 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
9689 // data structure.
9690 createRuntimeFunctionCall(
9691 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
9692 Args: OffloadingArgs);
9693 }
9694 }
9695
9696 // Update the pointer to point to the next element that needs to be mapped,
9697 // and check whether we have mapped all elements.
9698 Value *PtrNext = Builder.CreateConstGEP1_32(Ty: ElemTy, Ptr: PtrPHI, /*Idx0=*/1,
9699 Name: "omp.arraymap.next");
9700 PtrPHI->addIncoming(V: PtrNext, BB: LastBB);
9701 Value *IsDone = Builder.CreateICmpEQ(LHS: PtrNext, RHS: PtrEnd, Name: "omp.arraymap.isdone");
9702 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.exit");
9703 Builder.CreateCondBr(Cond: IsDone, True: ExitBB, False: BodyBB);
9704
9705 emitBlock(BB: ExitBB, CurFn: MapperFn);
9706 // Emit array deletion if this is an array section and \p MapType indicates
9707 // that deletion is required.
9708 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
9709 MapType, MapName, ElementSize, ExitBB: DoneBB,
9710 /*IsInit=*/false);
9711
9712 // Emit the function exit block.
9713 emitBlock(BB: DoneBB, CurFn: MapperFn, /*IsFinished=*/true);
9714
9715 Builder.CreateRetVoid();
9716 Builder.restoreIP(IP: SavedIP);
9717 return MapperFn;
9718}
9719
9720Error OpenMPIRBuilder::emitOffloadingArrays(
9721 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
9722 TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
9723 bool IsNonContiguous,
9724 function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
9725
9726 // Reset the array information.
9727 Info.clearArrayInfo();
9728 Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
9729
9730 if (Info.NumberOfPtrs == 0)
9731 return Error::success();
9732
9733 Builder.restoreIP(IP: AllocaIP);
9734 // Detect whether any captured size requires runtime evaluation, so that a
9735 // constant array can be used where possible.
9736 ArrayType *PointerArrayType =
9737 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);
9738
9739 Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
9740 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
9741
9742 Info.RTArgs.PointersArray = Builder.CreateAlloca(
9743 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
9744 AllocaInst *MappersArray = Builder.CreateAlloca(
9745 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
9746 Info.RTArgs.MappersArray = MappersArray;
9747
9748 // If we don't have any VLA types or other types that require runtime
9749 // evaluation, we can use a constant array for the map sizes, otherwise we
9750 // need to fill up the arrays as we do for the pointers.
9751 Type *Int64Ty = Builder.getInt64Ty();
9752 SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
9753 ConstantInt::get(Ty: Int64Ty, V: 0));
9754 SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
9755 for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
9756 if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
9757 if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
9758 if (IsNonContiguous &&
9759 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9760 CombinedInfo.Types[I] &
9761 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
9762 ConstSizes[I] =
9763 ConstantInt::get(Ty: Int64Ty, V: CombinedInfo.NonContigInfo.Dims[I]);
9764 else
9765 ConstSizes[I] = CI;
9766 continue;
9767 }
9768 }
9769 RuntimeSizes.set(I);
9770 }
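 // For illustration: mapping `int a[100]` has a compile-time size and ends
 // up in ConstSizes, whereas a VLA like `int b[n]` only has a runtime size,
 // so its RuntimeSizes bit is set and the value is stored into the sizes
 // array at run time further below.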
9771
9772 if (RuntimeSizes.all()) {
9773 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
9774 Info.RTArgs.SizesArray = Builder.CreateAlloca(
9775 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
9776 restoreIPandDebugLoc(Builder, IP: CodeGenIP);
9777 } else {
9778 auto *SizesArrayInit = ConstantArray::get(
9779 T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
9780 std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
9781 auto *SizesArrayGbl =
9782 new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
9783 GlobalValue::PrivateLinkage, SizesArrayInit, Name);
9784 SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
9785
9786 if (!RuntimeSizes.any()) {
9787 Info.RTArgs.SizesArray = SizesArrayGbl;
9788 } else {
9789 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
9790 Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
9791 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
9792 AllocaInst *Buffer = Builder.CreateAlloca(
9793 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
9794 Buffer->setAlignment(OffloadSizeAlign);
9795 restoreIPandDebugLoc(Builder, IP: CodeGenIP);
9796 Builder.CreateMemCpy(
9797 Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
9798 Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
9799 Size: Builder.getIntN(
9800 N: IndexSize,
9801 C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));
9802
9803 Info.RTArgs.SizesArray = Buffer;
9804 }
9805 restoreIPandDebugLoc(Builder, IP: CodeGenIP);
9806 }
9807
9808 // The map types are always constant so we don't need to generate code to
9809 // fill arrays. Instead, we create an array constant.
9810 SmallVector<uint64_t, 4> Mapping;
9811 for (auto mapFlag : CombinedInfo.Types)
9812 Mapping.push_back(
9813 Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9814 mapFlag));
9815 std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
9816 auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
9817 Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
9818
9819 // The information types are only built if provided.
9820 if (!CombinedInfo.Names.empty()) {
9821 auto *MapNamesArrayGbl = createOffloadMapnames(
9822 Names&: CombinedInfo.Names, VarName: createPlatformSpecificName(Parts: {"offload_mapnames"}));
9823 Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
9824 Info.EmitDebug = true;
9825 } else {
9826 Info.RTArgs.MapNamesArray =
9827 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
9828 Info.EmitDebug = false;
9829 }
9830
9831 // If there's a present map type modifier, it must not be applied to the end
9832 // of a region, so generate a separate map type array in that case.
9833 if (Info.separateBeginEndCalls()) {
9834 bool EndMapTypesDiffer = false;
9835 for (uint64_t &Type : Mapping) {
9836 if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9837 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
9838 Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9839 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
9840 EndMapTypesDiffer = true;
9841 }
9842 }
9843 if (EndMapTypesDiffer) {
9844 MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
9845 Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
9846 }
9847 }
9848
9849 PointerType *PtrTy = Builder.getPtrTy();
9850 for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
9851 Value *BPVal = CombinedInfo.BasePointers[I];
9852 Value *BP = Builder.CreateConstInBoundsGEP2_32(
9853 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
9854 Idx0: 0, Idx1: I);
9855 Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
9856 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
9857
9858 if (Info.requiresDevicePointerInfo()) {
9859 if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
9860 CodeGenIP = Builder.saveIP();
9861 Builder.restoreIP(IP: AllocaIP);
9862 Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
9863 Builder.restoreIP(IP: CodeGenIP);
9864 if (DeviceAddrCB)
9865 DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
9866 } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
9867 Info.DevicePtrInfoMap[BPVal] = {BP, BP};
9868 if (DeviceAddrCB)
9869 DeviceAddrCB(I, BP);
9870 }
9871 }
9872
9873 Value *PVal = CombinedInfo.Pointers[I];
9874 Value *P = Builder.CreateConstInBoundsGEP2_32(
9875 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
9876 Idx1: I);
9877 // TODO: Check that the alignment is correct.
9878 Builder.CreateAlignedStore(Val: PVal, Ptr: P,
9879 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
9880
9881 if (RuntimeSizes.test(Idx: I)) {
9882 Value *S = Builder.CreateConstInBoundsGEP2_32(
9883 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
9884 /*Idx0=*/0,
9885 /*Idx1=*/I);
9886 Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
9887 DestTy: Int64Ty,
9888 /*isSigned=*/true),
9889 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
9890 }
9891 // Fill up the mapper array.
9892 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
9893 Value *MFunc = ConstantPointerNull::get(T: PtrTy);
9894
9895 auto CustomMFunc = CustomMapperCB(I);
9896 if (!CustomMFunc)
9897 return CustomMFunc.takeError();
9898 if (*CustomMFunc)
9899 MFunc = Builder.CreatePointerCast(V: *CustomMFunc, DestTy: PtrTy);
9900
9901 Value *MAddr = Builder.CreateInBoundsGEP(
9902 Ty: MappersArray->getAllocatedType(), Ptr: MappersArray,
9903 IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
9904 Builder.CreateAlignedStore(
9905 Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
9906 }
9907
9908 if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
9909 Info.NumberOfPtrs == 0)
9910 return Error::success();
9911 emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
9912 return Error::success();
9913}
9914
9915void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
9916 BasicBlock *CurBB = Builder.GetInsertBlock();
9917
9918 if (!CurBB || CurBB->getTerminator()) {
9919 // If there is no insert point or the previous block is already
9920 // terminated, don't touch it.
9921 } else {
9922 // Otherwise, create a fall-through branch.
9923 Builder.CreateBr(Dest: Target);
9924 }
9925
9926 Builder.ClearInsertionPoint();
9927}
9928
9929void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
9930 bool IsFinished) {
9931 BasicBlock *CurBB = Builder.GetInsertBlock();
9932
9933 // Fall out of the current block (if necessary).
9934 emitBranch(Target: BB);
9935
9936 if (IsFinished && BB->use_empty()) {
9937 BB->eraseFromParent();
9938 return;
9939 }
9940
9941 // Place the block after the current block, if possible, or else at
9942 // the end of the function.
9943 if (CurBB && CurBB->getParent())
9944 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
9945 else
9946 CurFn->insert(Position: CurFn->end(), BB);
9947 Builder.SetInsertPoint(BB);
9948}
9949
9950Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
9951 BodyGenCallbackTy ElseGen,
9952 InsertPointTy AllocaIP) {
9953 // If the condition constant folds and can be elided, try to avoid emitting
9954 // the condition and the dead arm of the if/else.
9955 if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
9956 auto CondConstant = CI->getSExtValue();
9957 if (CondConstant)
9958 return ThenGen(AllocaIP, Builder.saveIP());
9959
9960 return ElseGen(AllocaIP, Builder.saveIP());
9961 }
9962
9963 Function *CurFn = Builder.GetInsertBlock()->getParent();
9964
9965 // Otherwise, the condition did not fold, or we couldn't elide it. Just
9966 // emit the conditional branch.
9967 BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
9968 BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
9969 BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
9970 Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
9971 // Emit the 'then' code.
9972 emitBlock(BB: ThenBlock, CurFn);
9973 if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
9974 return Err;
9975 emitBranch(Target: ContBlock);
9976 // Emit the 'else' code if present.
9977 // There is no need to emit line number for unconditional branch.
9978 emitBlock(BB: ElseBlock, CurFn);
9979 if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
9980 return Err;
9981 // There is no need to emit line number for unconditional branch.
9982 emitBranch(Target: ContBlock);
9983 // Emit the continuation block for code after the if.
9984 emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
9985 return Error::success();
9986}
9987
9988bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
9989 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
9990 assert(!(AO == AtomicOrdering::NotAtomic ||
9991 AO == llvm::AtomicOrdering::Unordered) &&
9992 "Unexpected Atomic Ordering.");
9993
9994 bool Flush = false;
9995 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
9996
9997 switch (AK) {
9998 case Read:
9999 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
10000 AO == AtomicOrdering::SequentiallyConsistent) {
10001 FlushAO = AtomicOrdering::Acquire;
10002 Flush = true;
10003 }
10004 break;
10005 case Write:
10006 case Compare:
10007 case Update:
10008 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
10009 AO == AtomicOrdering::SequentiallyConsistent) {
10010 FlushAO = AtomicOrdering::Release;
10011 Flush = true;
10012 }
10013 break;
10014 case Capture:
10015 switch (AO) {
10016 case AtomicOrdering::Acquire:
10017 FlushAO = AtomicOrdering::Acquire;
10018 Flush = true;
10019 break;
10020 case AtomicOrdering::Release:
10021 FlushAO = AtomicOrdering::Release;
10022 Flush = true;
10023 break;
10024 case AtomicOrdering::AcquireRelease:
10025 case AtomicOrdering::SequentiallyConsistent:
10026 FlushAO = AtomicOrdering::AcquireRelease;
10027 Flush = true;
10028 break;
10029 default:
10030 // do nothing - leave silently.
10031 break;
10032 }
10033 }
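 // For illustration: an `atomic read seq_cst` flushes with acquire
 // semantics, an `atomic write release` flushes with release semantics, and
 // an `atomic capture seq_cst` needs both directions, hence acquire-release.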
10034
10035 if (Flush) {
10036 // The flush runtime call does not yet take a memory ordering argument.
10037 // Until it does, we still resolve which ordering the flush would need,
10038 // but issue the plain flush call.
10039 // TODO: pass `FlushAO` after memory ordering support is added
10040 (void)FlushAO;
10041 emitFlush(Loc);
10042 }
10043
10044 // For AO == AtomicOrdering::Monotonic and all other combinations of
10045 // ordering and kind, no flush is needed.
10046 return Flush;
10047}
10048
10049OpenMPIRBuilder::InsertPointTy
10050OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
10051 AtomicOpValue &X, AtomicOpValue &V,
10052 AtomicOrdering AO, InsertPointTy AllocaIP) {
10053 if (!updateToLocation(Loc))
10054 return Loc.IP;
10055
10056 assert(X.Var->getType()->isPointerTy() &&
10057 "OMP Atomic expects a pointer to target memory");
10058 Type *XElemTy = X.ElemTy;
10059 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10060 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
10061 "OMP atomic read expected a scalar type");
10062
10063 Value *XRead = nullptr;
10064
10065 if (XElemTy->isIntegerTy()) {
10066 LoadInst *XLD =
10067 Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
10068 XLD->setAtomic(Ordering: AO);
10069 XRead = cast<Value>(Val: XLD);
10070 } else if (XElemTy->isStructTy()) {
10071 // FIXME: Add checks to ensure __atomic_load is emitted iff the
10072 // target does not support `atomicrmw` of the size of the struct
10073 LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
10074 OldVal->setAtomic(Ordering: AO);
10075 const DataLayout &DL = OldVal->getModule()->getDataLayout();
10076 unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
10077 OpenMPIRBuilder::AtomicInfo atomicInfo(
10078 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10079 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
10080 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
10081 XRead = AtomicLoadRes.first;
10082 OldVal->eraseFromParent();
10083 } else {
10084 // We need to perform the atomic op as an integer.
10085 IntegerType *IntCastTy =
10086 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10087 LoadInst *XLoad =
10088 Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
10089 XLoad->setAtomic(Ordering: AO);
10090 if (XElemTy->isFloatingPointTy()) {
10091 XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
10092 } else {
10093 XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
10094 }
10095 }
10096 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
10097 Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
10098 return Builder.saveIP();
10099}
10100
10101OpenMPIRBuilder::InsertPointTy
10102OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
10103 AtomicOpValue &X, Value *Expr,
10104 AtomicOrdering AO, InsertPointTy AllocaIP) {
10105 if (!updateToLocation(Loc))
10106 return Loc.IP;
10107
10108 assert(X.Var->getType()->isPointerTy() &&
10109 "OMP Atomic expects a pointer to target memory");
10110 Type *XElemTy = X.ElemTy;
10111 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10112 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
10113 "OMP atomic write expected a scalar type");
10114
10115 if (XElemTy->isIntegerTy()) {
10116 StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
10117 XSt->setAtomic(Ordering: AO);
10118 } else if (XElemTy->isStructTy()) {
10119 LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
10120 const DataLayout &DL = OldVal->getModule()->getDataLayout();
10121 unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
10122 OpenMPIRBuilder::AtomicInfo atomicInfo(
10123 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10124 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
10125 atomicInfo.EmitAtomicStoreLibcall(AO, Source: Expr);
10126 OldVal->eraseFromParent();
10127 } else {
10128 // We need to bitcast and perform the atomic op as an integer.
10129 IntegerType *IntCastTy =
10130 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10131 Value *ExprCast =
10132 Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
10133 StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
10134 XSt->setAtomic(Ordering: AO);
10135 }
10136
10137 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
10138 return Builder.saveIP();
10139}
10140
10141OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
10142 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
10143 Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
10144 AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr,
10145 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
10146 assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
10147 if (!updateToLocation(Loc))
10148 return Loc.IP;
10149
10150 LLVM_DEBUG({
10151 Type *XTy = X.Var->getType();
10152 assert(XTy->isPointerTy() &&
10153 "OMP Atomic expects a pointer to target memory");
10154 Type *XElemTy = X.ElemTy;
10155 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10156 XElemTy->isPointerTy()) &&
10157 "OMP atomic update expected a scalar type");
10158 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
10159 (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
10160 "OpenMP atomic does not support LT or GT operations");
10161 });
10162
10163 Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
10164 AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp, VolatileX: X.IsVolatile,
10165 IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
10166 if (!AtomicResult)
10167 return AtomicResult.takeError();
10168 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
10169 return Builder.saveIP();
10170}
10171
10172// FIXME: Duplicating AtomicExpand
10173Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
10174 AtomicRMWInst::BinOp RMWOp) {
10175 switch (RMWOp) {
10176 case AtomicRMWInst::Add:
10177 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
10178 case AtomicRMWInst::Sub:
10179 return Builder.CreateSub(LHS: Src1, RHS: Src2);
10180 case AtomicRMWInst::And:
10181 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
10182 case AtomicRMWInst::Nand:
10183 return Builder.CreateNot(V: Builder.CreateAnd(LHS: Src1, RHS: Src2));
10184 case AtomicRMWInst::Or:
10185 return Builder.CreateOr(LHS: Src1, RHS: Src2);
10186 case AtomicRMWInst::Xor:
10187 return Builder.CreateXor(LHS: Src1, RHS: Src2);
10188 case AtomicRMWInst::Xchg:
10189 case AtomicRMWInst::FAdd:
10190 case AtomicRMWInst::FSub:
10191 case AtomicRMWInst::BAD_BINOP:
10192 case AtomicRMWInst::Max:
10193 case AtomicRMWInst::Min:
10194 case AtomicRMWInst::UMax:
10195 case AtomicRMWInst::UMin:
10196 case AtomicRMWInst::FMax:
10197 case AtomicRMWInst::FMin:
10198 case AtomicRMWInst::FMaximum:
10199 case AtomicRMWInst::FMinimum:
10200 case AtomicRMWInst::UIncWrap:
10201 case AtomicRMWInst::UDecWrap:
10202 case AtomicRMWInst::USubCond:
10203 case AtomicRMWInst::USubSat:
10204 llvm_unreachable("Unsupported atomic update operation");
10205 }
10206 llvm_unreachable("Unsupported atomic update operation");
10207}
10208
10209Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
10210 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
10211 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
10212 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
10213 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
10214 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
10215 // or a complex datatype.
10216 bool emitRMWOp = false;
10217 switch (RMWOp) {
10218 case AtomicRMWInst::Add:
10219 case AtomicRMWInst::And:
10220 case AtomicRMWInst::Nand:
10221 case AtomicRMWInst::Or:
10222 case AtomicRMWInst::Xor:
10223 case AtomicRMWInst::Xchg:
10224 emitRMWOp = XElemTy;
10225 break;
10226 case AtomicRMWInst::Sub:
10227 emitRMWOp = (IsXBinopExpr && XElemTy);
10228 break;
10229 default:
10230 emitRMWOp = false;
10231 }
10232 emitRMWOp &= XElemTy->isIntegerTy();
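 // For illustration: `x = x - expr` can be emitted as a plain atomicrmw
 // sub, but `x = expr - x` cannot, since subtraction does not commute;
 // that is why Sub additionally requires IsXBinopExpr above.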
10233
10234 std::pair<Value *, Value *> Res;
10235 if (emitRMWOp) {
10236 AtomicRMWInst *RMWInst =
10237 Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
10238 if (T.isAMDGPU()) {
10239 if (IsIgnoreDenormalMode)
10240 RMWInst->setMetadata(Kind: "amdgpu.ignore.denormal.mode",
10241 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10242 if (!IsFineGrainedMemory)
10243 RMWInst->setMetadata(Kind: "amdgpu.no.fine.grained.memory",
10244 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10245 if (!IsRemoteMemory)
10246 RMWInst->setMetadata(Kind: "amdgpu.no.remote.memory",
10247 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10248 }
10249 Res.first = RMWInst;
10250 // Res.second is only needed for postfix captures; generate it anyway for
10251 // consistency with the else part, since any DCE pass will remove it.
10252 // AtomicRMWInst::Xchg does not have a corresponding non-atomic instruction.
10253 if (RMWOp == AtomicRMWInst::Xchg)
10254 Res.second = Res.first;
10255 else
10256 Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
10257 } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
10258 XElemTy->isStructTy()) {
10259 LoadInst *OldVal =
10260 Builder.CreateLoad(Ty: XElemTy, Ptr: X, Name: X->getName() + ".atomic.load");
10261 OldVal->setAtomic(Ordering: AO);
10262 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
10263 unsigned LoadSize =
10264 LoadDL.getTypeStoreSize(Ty: XElemTy);
10265
10266 OpenMPIRBuilder::AtomicInfo atomicInfo(
10267 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10268 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
10269 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
10270 BasicBlock *CurBB = Builder.GetInsertBlock();
10271 Instruction *CurBBTI = CurBB->getTerminator();
10272 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10273 BasicBlock *ExitBB =
10274 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
10275 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
10276 BBName: X->getName() + ".atomic.cont");
10277 ContBB->getTerminator()->eraseFromParent();
10278 Builder.restoreIP(IP: AllocaIP);
10279 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
10280 NewAtomicAddr->setName(X->getName() + "x.new.val");
10281 Builder.SetInsertPoint(ContBB);
10282 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
10283 PHI->addIncoming(V: AtomicLoadRes.first, BB: CurBB);
10284 Value *OldExprVal = PHI;
10285 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
10286 if (!CBResult)
10287 return CBResult.takeError();
10288 Value *Upd = *CBResult;
10289 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
10290 AtomicOrdering Failure =
10291 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10292 auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
10293 ExpectedVal: AtomicLoadRes.second, DesiredVal: NewAtomicAddr, Success: AO, Failure);
10294 LoadInst *PHILoad = Builder.CreateLoad(Ty: XElemTy, Ptr: Result.first);
10295 PHI->addIncoming(V: PHILoad, BB: Builder.GetInsertBlock());
10296 Builder.CreateCondBr(Cond: Result.second, True: ExitBB, False: ContBB);
10297 OldVal->eraseFromParent();
10298 Res.first = OldExprVal;
10299 Res.second = Upd;
10300
10301 Instruction *ExitTI = ExitBB->getTerminator();
10302 if (isa<UnreachableInst>(Val: ExitTI)) {
10303 CurBBTI->eraseFromParent();
10304 Builder.SetInsertPoint(ExitBB);
10305 } else {
10306 Builder.SetInsertPoint(ExitTI);
10307 }
10308 } else {
10309 IntegerType *IntCastTy =
10310 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10311 LoadInst *OldVal =
10312 Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
10313 OldVal->setAtomic(Ordering: AO);
10314 // CurBB
10315 // | /---\
10316 // ContBB |
10317 // | \---/
10318 // ExitBB
10319 BasicBlock *CurBB = Builder.GetInsertBlock();
10320 Instruction *CurBBTI = CurBB->getTerminator();
10321 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10322 BasicBlock *ExitBB =
10323 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
10324 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
10325 BBName: X->getName() + ".atomic.cont");
10326 ContBB->getTerminator()->eraseFromParent();
10327 Builder.restoreIP(IP: AllocaIP);
10328 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
10329 NewAtomicAddr->setName(X->getName() + "x.new.val");
10330 Builder.SetInsertPoint(ContBB);
10331 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
10332 PHI->addIncoming(V: OldVal, BB: CurBB);
10333 bool IsIntTy = XElemTy->isIntegerTy();
10334 Value *OldExprVal = PHI;
10335 if (!IsIntTy) {
10336 if (XElemTy->isFloatingPointTy()) {
10337 OldExprVal = Builder.CreateBitCast(V: PHI, DestTy: XElemTy,
10338 Name: X->getName() + ".atomic.fltCast");
10339 } else {
10340 OldExprVal = Builder.CreateIntToPtr(V: PHI, DestTy: XElemTy,
10341 Name: X->getName() + ".atomic.ptrCast");
10342 }
10343 }
10344
10345 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
10346 if (!CBResult)
10347 return CBResult.takeError();
10348 Value *Upd = *CBResult;
10349 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
10350 LoadInst *DesiredVal = Builder.CreateLoad(Ty: IntCastTy, Ptr: NewAtomicAddr);
10351 AtomicOrdering Failure =
10352 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10353 AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
10354 Ptr: X, Cmp: PHI, New: DesiredVal, Align: llvm::MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
10355 Result->setVolatile(VolatileX);
10356 Value *PreviousVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
10357 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10358 PHI->addIncoming(V: PreviousVal, BB: Builder.GetInsertBlock());
10359 Builder.CreateCondBr(Cond: SuccessFailureVal, True: ExitBB, False: ContBB);
10360
10361 Res.first = OldExprVal;
10362 Res.second = Upd;
10363
    // Set the insertion point in the exit block.
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      // The block already had a real terminator; ExitTI is null here, so
      // insert before the existing terminator instead.
      Builder.SetInsertPoint(ExitBB->getTerminator());
    }
10372 }
10373
10374 return Res;
10375}
10376
10377OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
10378 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
10379 AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
10380 AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
10381 bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr,
10382 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
10383 if (!updateToLocation(Loc))
10384 return Loc.IP;
10385
10386 LLVM_DEBUG({
10387 Type *XTy = X.Var->getType();
10388 assert(XTy->isPointerTy() &&
10389 "OMP Atomic expects a pointer to target memory");
10390 Type *XElemTy = X.ElemTy;
10391 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10392 XElemTy->isPointerTy()) &&
10393 "OMP atomic capture expected a scalar type");
10394 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
10395 "OpenMP atomic does not support LT or GT operations");
10396 });
10397
  // If UpdateExpr is false, i.e. 'x' is assigned some `expr` that does not
  // depend on 'x' itself, 'x' is simply atomically exchanged with 'expr'.
10400 AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
10401 Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
10402 AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp: AtomicOp, UpdateOp, VolatileX: X.IsVolatile,
10403 IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
10404 if (!AtomicResult)
10405 return AtomicResult.takeError();
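  // For a postfix update (v = x; x = x op expr) 'v' captures the old value of
  // 'x'; otherwise it captures the updated value.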
10406 Value *CapturedVal =
10407 (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
10408 Builder.CreateStore(Val: CapturedVal, Ptr: V.Var, isVolatile: V.IsVolatile);
10409
10410 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Capture);
10411 return Builder.saveIP();
10412}
10413
10414OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
10415 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
10416 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
10417 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
10418 bool IsFailOnly) {
10419
10420 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10421 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
10422 IsPostfixUpdate, IsFailOnly, Failure);
10423}
10424
10425OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
10426 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
10427 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
10428 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
10429 bool IsFailOnly, AtomicOrdering Failure) {
10430
10431 if (!updateToLocation(Loc))
10432 return Loc.IP;
10433
10434 assert(X.Var->getType()->isPointerTy() &&
10435 "OMP atomic expects a pointer to target memory");
10436 // compare capture
10437 if (V.Var) {
10438 assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
10439 assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
10440 }
10441
10442 bool IsInteger = E->getType()->isIntegerTy();
10443
10444 if (Op == OMPAtomicCompareOp::EQ) {
10445 AtomicCmpXchgInst *Result = nullptr;
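    // cmpxchg only operates on integers, so reinterpret floating-point
    // operands as integers of the same width.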
10446 if (!IsInteger) {
10447 IntegerType *IntCastTy =
10448 IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
10449 Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
10450 Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
10451 Result = Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: EBCast, New: DBCast, Align: MaybeAlign(),
10452 SuccessOrdering: AO, FailureOrdering: Failure);
10453 } else {
10454 Result =
10455 Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: E, New: D, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
10456 }
10457
10458 if (V.Var) {
10459 Value *OldValue = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
10460 if (!IsInteger)
10461 OldValue = Builder.CreateBitCast(V: OldValue, DestTy: X.ElemTy);
10462 assert(OldValue->getType() == V.ElemTy &&
10463 "OldValue and V must be of same type");
10464 if (IsPostfixUpdate) {
10465 Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
10466 } else {
10467 Value *SuccessOrFail = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10468 if (IsFailOnly) {
10469 // CurBB----
10470 // | |
10471 // v |
10472 // ContBB |
10473 // | |
10474 // v |
10475 // ExitBB <-
10476 //
10477 // where ContBB only contains the store of old value to 'v'.
10478 BasicBlock *CurBB = Builder.GetInsertBlock();
10479 Instruction *CurBBTI = CurBB->getTerminator();
10480 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10481 BasicBlock *ExitBB = CurBB->splitBasicBlock(
10482 I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
10483 BasicBlock *ContBB = CurBB->splitBasicBlock(
10484 I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
10485 ContBB->getTerminator()->eraseFromParent();
10486 CurBB->getTerminator()->eraseFromParent();
10487
10488 Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);
10489
10490 Builder.SetInsertPoint(ContBB);
10491 Builder.CreateStore(Val: OldValue, Ptr: V.Var);
10492 Builder.CreateBr(Dest: ExitBB);
10493
10494 if (UnreachableInst *ExitTI =
10495 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
10496 CurBBTI->eraseFromParent();
10497 Builder.SetInsertPoint(ExitBB);
          } else {
            // The block already had a real terminator; ExitTI is null here,
            // so insert before the existing terminator instead.
            Builder.SetInsertPoint(ExitBB->getTerminator());
          }
10501 } else {
10502 Value *CapturedValue =
10503 Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
10504 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
10505 }
10506 }
10507 }
10508 // The comparison result has to be stored.
10509 if (R.Var) {
10510 assert(R.Var->getType()->isPointerTy() &&
10511 "r.var must be of pointer type");
10512 assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
10513
10514 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10515 Value *ResultCast = R.IsSigned
10516 ? Builder.CreateSExt(V: SuccessFailureVal, DestTy: R.ElemTy)
10517 : Builder.CreateZExt(V: SuccessFailureVal, DestTy: R.ElemTy);
10518 Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
10519 }
10520 } else {
10521 assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
10522 "Op should be either max or min at this point");
10523 assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
10524
    // Reverse the ordop as the OpenMP forms are different from LLVM forms.
    // Take max as an example:
    // OpenMP form:
    //   x = x > expr ? expr : x;
    // LLVM form:
    //   *ptr = *ptr > val ? *ptr : val;
    // We need to transform the OpenMP form to the LLVM form:
    //   x = x <= expr ? x : expr;
10533 AtomicRMWInst::BinOp NewOp;
10534 if (IsXBinopExpr) {
10535 if (IsInteger) {
10536 if (X.IsSigned)
10537 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
10538 : AtomicRMWInst::Max;
10539 else
10540 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
10541 : AtomicRMWInst::UMax;
10542 } else {
10543 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
10544 : AtomicRMWInst::FMax;
10545 }
10546 } else {
10547 if (IsInteger) {
10548 if (X.IsSigned)
10549 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
10550 : AtomicRMWInst::Min;
10551 else
10552 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
10553 : AtomicRMWInst::UMin;
10554 } else {
10555 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
10556 : AtomicRMWInst::FMin;
10557 }
10558 }
10559
10560 AtomicRMWInst *OldValue =
10561 Builder.CreateAtomicRMW(Op: NewOp, Ptr: X.Var, Val: E, Align: MaybeAlign(), Ordering: AO);
10562 if (V.Var) {
10563 Value *CapturedValue = nullptr;
10564 if (IsPostfixUpdate) {
10565 CapturedValue = OldValue;
10566 } else {
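        // The atomicrmw returned the old value; redo the comparison locally to
        // reconstruct the value that was actually stored.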
10567 CmpInst::Predicate Pred;
10568 switch (NewOp) {
10569 case AtomicRMWInst::Max:
10570 Pred = CmpInst::ICMP_SGT;
10571 break;
10572 case AtomicRMWInst::UMax:
10573 Pred = CmpInst::ICMP_UGT;
10574 break;
10575 case AtomicRMWInst::FMax:
10576 Pred = CmpInst::FCMP_OGT;
10577 break;
10578 case AtomicRMWInst::Min:
10579 Pred = CmpInst::ICMP_SLT;
10580 break;
10581 case AtomicRMWInst::UMin:
10582 Pred = CmpInst::ICMP_ULT;
10583 break;
10584 case AtomicRMWInst::FMin:
10585 Pred = CmpInst::FCMP_OLT;
10586 break;
10587 default:
10588 llvm_unreachable("unexpected comparison op");
10589 }
10590 Value *NonAtomicCmp = Builder.CreateCmp(Pred, LHS: OldValue, RHS: E);
10591 CapturedValue = Builder.CreateSelect(C: NonAtomicCmp, True: E, False: OldValue);
10592 }
10593 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
10594 }
10595 }
10596
10597 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Compare);
10598
10599 return Builder.saveIP();
10600}
10601
10602OpenMPIRBuilder::InsertPointOrErrorTy
10603OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
10604 BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
10605 Value *NumTeamsUpper, Value *ThreadLimit,
10606 Value *IfExpr) {
10607 if (!updateToLocation(Loc))
10608 return InsertPointTy();
10609
10610 uint32_t SrcLocStrSize;
10611 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
10612 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
10613 Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
10614
  // The outer allocation basic block is the entry block of the current
  // function.
10616 BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
10617 if (&OuterAllocaBB == Builder.GetInsertBlock()) {
10618 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.entry");
10619 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
10620 }
10621
10622 // The current basic block is split into four basic blocks. After outlining,
10623 // they will be mapped as follows:
10624 // ```
10625 // def current_fn() {
10626 // current_basic_block:
10627 // br label %teams.exit
10628 // teams.exit:
10629 // ; instructions after teams
10630 // }
10631 //
10632 // def outlined_fn() {
10633 // teams.alloca:
10634 // br label %teams.body
10635 // teams.body:
10636 // ; instructions within teams body
10637 // }
10638 // ```
10639 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.exit");
10640 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.body");
10641 BasicBlock *AllocaBB =
10642 splitBB(Builder, /*CreateBranch=*/true, Name: "teams.alloca");
10643
10644 bool SubClausesPresent =
10645 (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
10646 // Push num_teams
10647 if (!Config.isTargetDevice() && SubClausesPresent) {
10648 assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
10649 "if lowerbound is non-null, then upperbound must also be non-null "
10650 "for bounds on num_teams");
10651
10652 if (NumTeamsUpper == nullptr)
10653 NumTeamsUpper = Builder.getInt32(C: 0);
10654
10655 if (NumTeamsLower == nullptr)
10656 NumTeamsLower = NumTeamsUpper;
10657
10658 if (IfExpr) {
10659 assert(IfExpr->getType()->isIntegerTy() &&
10660 "argument to if clause must be an integer value");
10661
10662 // upper = ifexpr ? upper : 1
10663 if (IfExpr->getType() != Int1)
10664 IfExpr = Builder.CreateICmpNE(LHS: IfExpr,
10665 RHS: ConstantInt::get(Ty: IfExpr->getType(), V: 0));
10666 NumTeamsUpper = Builder.CreateSelect(
10667 C: IfExpr, True: NumTeamsUpper, False: Builder.getInt32(C: 1), Name: "numTeamsUpper");
10668
10669 // lower = ifexpr ? lower : 1
10670 NumTeamsLower = Builder.CreateSelect(
10671 C: IfExpr, True: NumTeamsLower, False: Builder.getInt32(C: 1), Name: "numTeamsLower");
10672 }
10673
10674 if (ThreadLimit == nullptr)
10675 ThreadLimit = Builder.getInt32(C: 0);
10676
    // The __kmpc_push_num_teams_51 runtime function expects i32 arguments, so
    // truncate or sign-extend the passed values to match its parameters.
10679 Value *NumTeamsLowerInt32 =
10680 Builder.CreateSExtOrTrunc(V: NumTeamsLower, DestTy: Builder.getInt32Ty());
10681 Value *NumTeamsUpperInt32 =
10682 Builder.CreateSExtOrTrunc(V: NumTeamsUpper, DestTy: Builder.getInt32Ty());
10683 Value *ThreadLimitInt32 =
10684 Builder.CreateSExtOrTrunc(V: ThreadLimit, DestTy: Builder.getInt32Ty());
10685
10686 Value *ThreadNum = getOrCreateThreadID(Ident);
10687
10688 createRuntimeFunctionCall(
10689 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_teams_51),
10690 Args: {Ident, ThreadNum, NumTeamsLowerInt32, NumTeamsUpperInt32,
10691 ThreadLimitInt32});
10692 }
10693 // Generate the body of teams.
10694 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
10695 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
10696 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
10697 return Err;
10698
10699 OutlineInfo OI;
10700 OI.EntryBB = AllocaBB;
10701 OI.ExitBB = ExitBB;
10702 OI.OuterAllocaBB = &OuterAllocaBB;
10703
10704 // Insert fake values for global tid and bound tid.
10705 SmallVector<Instruction *, 8> ToBeDeleted;
10706 InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
10707 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
10708 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "gid", AsPtr: true));
10709 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
10710 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "tid", AsPtr: true));
10711
10712 auto HostPostOutlineCB = [this, Ident,
10713 ToBeDeleted](Function &OutlinedFn) mutable {
    // Replace the stale call instruction with a new call to the runtime
    // function, passing the outlined function.
10716
10717 assert(OutlinedFn.hasOneUse() &&
10718 "there must be a single user for the outlined function");
10719 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
10720 ToBeDeleted.push_back(Elt: StaleCI);
10721
10722 assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
10723 "Outlined function must have two or three arguments only");
10724
10725 bool HasShared = OutlinedFn.arg_size() == 3;
10726
10727 OutlinedFn.getArg(i: 0)->setName("global.tid.ptr");
10728 OutlinedFn.getArg(i: 1)->setName("bound.tid.ptr");
10729 if (HasShared)
10730 OutlinedFn.getArg(i: 2)->setName("data");
10731
10732 // Call to the runtime function for teams in the current function.
10733 assert(StaleCI && "Error while outlining - no CallInst user found for the "
10734 "outlined function.");
10735 Builder.SetInsertPoint(StaleCI);
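    // __kmpc_fork_teams takes the ident, the number of captured arguments
    // (excluding the two tid pointers), and the outlined function.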
10736 SmallVector<Value *> Args = {
10737 Ident, Builder.getInt32(C: StaleCI->arg_size() - 2), &OutlinedFn};
10738 if (HasShared)
10739 Args.push_back(Elt: StaleCI->getArgOperand(i: 2));
10740 createRuntimeFunctionCall(
10741 Callee: getOrCreateRuntimeFunctionPtr(
10742 FnID: omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
10743 Args);
10744
10745 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
10746 I->eraseFromParent();
10747 };
10748
10749 if (!Config.isTargetDevice())
10750 OI.PostOutlineCB = HostPostOutlineCB;
10751
10752 addOutlineInfo(OI: std::move(OI));
10753
10754 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
10755
10756 return Builder.saveIP();
10757}
10758
10759OpenMPIRBuilder::InsertPointOrErrorTy
10760OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
10761 InsertPointTy OuterAllocaIP,
10762 BodyGenCallbackTy BodyGenCB) {
10763 if (!updateToLocation(Loc))
10764 return InsertPointTy();
10765
10766 BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();
10767
10768 if (OuterAllocaBB == Builder.GetInsertBlock()) {
10769 BasicBlock *BodyBB =
10770 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.entry");
10771 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
10772 }
10773 BasicBlock *ExitBB =
10774 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.exit");
10775 BasicBlock *BodyBB =
10776 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.body");
10777 BasicBlock *AllocaBB =
10778 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.alloca");
10779
  // Generate the body of the distribute region.
10781 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
10782 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
10783 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
10784 return Err;
10785
  // When compiling for a target device we use different runtime functions,
  // which requires the distribute region to be outlined.
10788 if (Config.isTargetDevice()) {
10789 OutlineInfo OI;
10790 OI.OuterAllocaBB = OuterAllocaIP.getBlock();
10791 OI.EntryBB = AllocaBB;
10792 OI.ExitBB = ExitBB;
10793
10794 addOutlineInfo(OI: std::move(OI));
10795 }
10796 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
10797
10798 return Builder.saveIP();
10799}
10800
10801GlobalVariable *
10802OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
10803 std::string VarName) {
10804 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
10805 T: llvm::ArrayType::get(ElementType: llvm::PointerType::getUnqual(C&: M.getContext()),
10806 NumElements: Names.size()),
10807 V: Names);
10808 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
10809 M, MapNamesArrayInit->getType(),
10810 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
10811 VarName);
10812 return MapNamesArrayGlobal;
10813}
10814
// Create all simple and struct types exposed by the runtime and remember
// their llvm::PointerTypes for easy access later.
10817void OpenMPIRBuilder::initializeTypes(Module &M) {
10818 LLVMContext &Ctx = M.getContext();
10819 StructType *T;
10820 unsigned DefaultTargetAS = Config.getDefaultTargetAS();
10821 unsigned ProgramAS = M.getDataLayout().getProgramAddressSpace();
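  // The actual type definitions are driven by the X-macro entries in
  // OMPKinds.def, included at the bottom of this function.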
10822#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
10823#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
10824 VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
10825 VarName##PtrTy = PointerType::get(Ctx, DefaultTargetAS);
10826#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
10827 VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
10828 VarName##Ptr = PointerType::get(Ctx, ProgramAS);
10829#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
10830 T = StructType::getTypeByName(Ctx, StructName); \
10831 if (!T) \
10832 T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed); \
10833 VarName = T; \
10834 VarName##Ptr = PointerType::get(Ctx, DefaultTargetAS);
10835#include "llvm/Frontend/OpenMP/OMPKinds.def"
10836}
10837
10838void OpenMPIRBuilder::OutlineInfo::collectBlocks(
10839 SmallPtrSetImpl<BasicBlock *> &BlockSet,
10840 SmallVectorImpl<BasicBlock *> &BlockVector) {
10841 SmallVector<BasicBlock *, 32> Worklist;
10842 BlockSet.insert(Ptr: EntryBB);
10843 BlockSet.insert(Ptr: ExitBB);
10844
10845 Worklist.push_back(Elt: EntryBB);
10846 while (!Worklist.empty()) {
10847 BasicBlock *BB = Worklist.pop_back_val();
10848 BlockVector.push_back(Elt: BB);
10849 for (BasicBlock *SuccBB : successors(BB))
10850 if (BlockSet.insert(Ptr: SuccBB).second)
10851 Worklist.push_back(Elt: SuccBB);
10852 }
10853}
10854
10855void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
10856 uint64_t Size, int32_t Flags,
10857 GlobalValue::LinkageTypes,
10858 StringRef Name) {
10859 if (!Config.isGPU()) {
10860 llvm::offloading::emitOffloadingEntry(
10861 M, Kind: object::OffloadKind::OFK_OpenMP, Addr: ID,
10862 Name: Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
10863 return;
10864 }
10865 // TODO: Add support for global variables on the device after declare target
10866 // support.
10867 Function *Fn = dyn_cast<Function>(Val: Addr);
10868 if (!Fn)
10869 return;
10870
10871 // Add a function attribute for the kernel.
10872 Fn->addFnAttr(Kind: "kernel");
10873 if (T.isAMDGCN())
10874 Fn->addFnAttr(Kind: "uniform-work-group-size", Val: "true");
10875 Fn->addFnAttr(Kind: Attribute::MustProgress);
10876}
10877
// We only generate metadata for functions that contain target regions.
10879void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
10880 EmitMetadataErrorReportFunctionTy &ErrorFn) {
10881
10882 // If there are no entries, we don't need to do anything.
10883 if (OffloadInfoManager.empty())
10884 return;
10885
10886 LLVMContext &C = M.getContext();
10887 SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
10888 TargetRegionEntryInfo>,
10889 16>
10890 OrderedEntries(OffloadInfoManager.size());
10891
  // Auxiliary helpers to create metadata values and strings.
10893 auto &&GetMDInt = [this](unsigned V) {
10894 return ConstantAsMetadata::get(C: ConstantInt::get(Ty: Builder.getInt32Ty(), V));
10895 };
10896
10897 auto &&GetMDString = [&C](StringRef V) { return MDString::get(Context&: C, Str: V); };
10898
10899 // Create the offloading info metadata node.
10900 NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "omp_offload.info");
10901 auto &&TargetRegionMetadataEmitter =
10902 [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
10903 const TargetRegionEntryInfo &EntryInfo,
10904 const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
10905 // Generate metadata for target regions. Each entry of this metadata
10906 // contains:
10907 // - Entry 0 -> Kind of this type of metadata (0).
10908 // - Entry 1 -> Device ID of the file where the entry was identified.
10909 // - Entry 2 -> File ID of the file where the entry was identified.
10910 // - Entry 3 -> Mangled name of the function where the entry was
10911 // identified.
10912 // - Entry 4 -> Line in the file where the entry was identified.
10913 // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
10914 // - Entry 6 -> Order the entry was created.
10915 // The first element of the metadata node is the kind.
10916 Metadata *Ops[] = {
10917 GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
10918 GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
10919 GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
10920 GetMDInt(E.getOrder())};
10921
10922 // Save this entry in the right position of the ordered entries array.
10923 OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y: EntryInfo);
10924
10925 // Add metadata to the named metadata node.
10926 MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
10927 };
10928
10929 OffloadInfoManager.actOnTargetRegionEntriesInfo(Action: TargetRegionMetadataEmitter);
10930
  // Create a function that emits metadata for each device global variable
  // entry.
10932 auto &&DeviceGlobalVarMetadataEmitter =
10933 [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
10934 StringRef MangledName,
10935 const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
10936 // Generate metadata for global variables. Each entry of this metadata
10937 // contains:
10938 // - Entry 0 -> Kind of this type of metadata (1).
10939 // - Entry 1 -> Mangled name of the variable.
10940 // - Entry 2 -> Declare target kind.
10941 // - Entry 3 -> Order the entry was created.
10942 // The first element of the metadata node is the kind.
10943 Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
10944 GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
10945
10946 // Save this entry in the right position of the ordered entries array.
10947 TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
10948 OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y&: varInfo);
10949
10950 // Add metadata to the named metadata node.
10951 MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
10952 };
10953
10954 OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
10955 Action: DeviceGlobalVarMetadataEmitter);
10956
10957 for (const auto &E : OrderedEntries) {
10958 assert(E.first && "All ordered entries must exist!");
10959 if (const auto *CE =
10960 dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
10961 Val: E.first)) {
10962 if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
10964 TargetRegionEntryInfo EntryInfo = E.second;
10965 StringRef FnName = EntryInfo.ParentName;
10966 if (!M.getNamedValue(Name: FnName))
10967 continue;
10968 ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
10969 continue;
10970 }
10971 createOffloadEntry(ID: CE->getID(), Addr: CE->getAddress(),
10972 /*Size=*/0, Flags: CE->getFlags(),
10973 GlobalValue::WeakAnyLinkage);
10974 } else if (const auto *CE = dyn_cast<
10975 OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
10976 Val: E.first)) {
10977 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
10978 static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
10979 CE->getFlags());
10980 switch (Flags) {
10981 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
10982 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
10983 if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
10984 continue;
10985 if (!CE->getAddress()) {
10986 ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
10987 continue;
10988 }
      // The variable has no definition - no need to add the entry.
10990 if (CE->getVarSize() == 0)
10991 continue;
10992 break;
10993 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
      assert(((Config.isTargetDevice() && !CE->getAddress()) ||
              (!Config.isTargetDevice() && CE->getAddress())) &&
             "Declare target link address is set.");
10997 if (Config.isTargetDevice())
10998 continue;
10999 if (!CE->getAddress()) {
11000 ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
11001 continue;
11002 }
11003 break;
11004 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect:
11005 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable:
11006 if (!CE->getAddress()) {
11007 ErrorFn(EMIT_MD_GLOBAL_VAR_INDIRECT_ERROR, E.second);
11008 continue;
11009 }
11010 break;
11011 default:
11012 break;
11013 }
11014
11015 // Hidden or internal symbols on the device are not externally visible.
11016 // We should not attempt to register them by creating an offloading
11017 // entry. Indirect variables are handled separately on the device.
11018 if (auto *GV = dyn_cast<GlobalValue>(Val: CE->getAddress()))
11019 if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
11020 (Flags !=
11021 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect &&
11022 Flags != OffloadEntriesInfoManager::
11023 OMPTargetGlobalVarEntryIndirectVTable))
11024 continue;
11025
11026 // Indirect globals need to use a special name that doesn't match the name
11027 // of the associated host global.
11028 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
11029 Flags ==
11030 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
11031 createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
11032 Flags, CE->getLinkage(), Name: CE->getVarName());
11033 else
11034 createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
11035 Flags, CE->getLinkage());
11036
11037 } else {
11038 llvm_unreachable("Unsupported entry kind.");
11039 }
11040 }
11041
11042 // Emit requires directive globals to a special entry so the runtime can
11043 // register them when the device image is loaded.
11044 // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
11045 // entries should be redesigned to better suit this use-case.
11046 if (Config.hasRequiresFlags() && !Config.isTargetDevice())
11047 offloading::emitOffloadingEntry(
11048 M, Kind: object::OffloadKind::OFK_OpenMP,
11049 Addr: Constant::getNullValue(Ty: PointerType::getUnqual(C&: M.getContext())),
11050 Name: ".requires", /*Size=*/0,
11051 Flags: OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
11052 Data: Config.getRequiresFlags());
11053}
11054
11055void TargetRegionEntryInfo::getTargetRegionEntryFnName(
11056 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
11057 unsigned FileID, unsigned Line, unsigned Count) {
11058 raw_svector_ostream OS(Name);
11059 OS << KernelNamePrefix << llvm::format(Fmt: "%x", Vals: DeviceID)
11060 << llvm::format(Fmt: "_%x_", Vals: FileID) << ParentName << "_l" << Line;
11061 if (Count)
11062 OS << "_" << Count;
11063}
11064
11065void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
11066 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
11067 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
11068 TargetRegionEntryInfo::getTargetRegionEntryFnName(
11069 Name, ParentName: EntryInfo.ParentName, DeviceID: EntryInfo.DeviceID, FileID: EntryInfo.FileID,
11070 Line: EntryInfo.Line, Count: NewCount);
11071}
11072
11073TargetRegionEntryInfo
11074OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
11075 vfs::FileSystem &VFS,
11076 StringRef ParentName) {
11077 sys::fs::UniqueID ID(0xdeadf17e, 0);
11078 auto FileIDInfo = CallBack();
11079 uint64_t FileID = 0;
11080 if (ErrorOr<vfs::Status> Status = VFS.status(Path: std::get<0>(t&: FileIDInfo))) {
11081 ID = Status->getUniqueID();
11082 FileID = Status->getUniqueID().getFile();
11083 } else {
    // If the inode ID could not be determined, create a hash value of the
    // current file name and use that as an ID.
11086 FileID = hash_value(arg: std::get<0>(t&: FileIDInfo));
11087 }
11088
11089 return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
11090 std::get<1>(t&: FileIDInfo));
11091}
11092
11093unsigned OpenMPIRBuilder::getFlagMemberOffset() {
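  // Determine the bit position of the MEMBER_OF field by counting the
  // trailing zero bits of the OMP_MAP_MEMBER_OF mask.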
11094 unsigned Offset = 0;
11095 for (uint64_t Remain =
11096 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11097 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
11098 !(Remain & 1); Remain = Remain >> 1)
11099 Offset++;
11100 return Offset;
11101}
11102
11103omp::OpenMPOffloadMappingFlags
11104OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
  // Shift the one-based position by getFlagMemberOffset() bits to place it in
  // the MEMBER_OF field.
11106 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
11107 << getFlagMemberOffset());
11108}
11109
11110void OpenMPIRBuilder::setCorrectMemberOfFlag(
11111 omp::OpenMPOffloadMappingFlags &Flags,
11112 omp::OpenMPOffloadMappingFlags MemberOfFlag) {
11113 // If the entry is PTR_AND_OBJ but has not been marked with the special
11114 // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
11115 // marked as MEMBER_OF.
11116 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11117 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
11118 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11119 (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
11120 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
11121 return;
11122
11123 // Entries with ATTACH are not members-of anything. They are handled
11124 // separately by the runtime after other maps have been handled.
11125 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11126 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH))
11127 return;
11128
11129 // Reset the placeholder value to prepare the flag for the assignment of the
11130 // proper MEMBER_OF value.
11131 Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
11132 Flags |= MemberOfFlag;
11133}
11134
11135Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
11136 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
11137 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
11138 bool IsDeclaration, bool IsExternallyVisible,
11139 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
11140 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
11141 std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
11142 std::function<Constant *()> GlobalInitializer,
11143 std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
11144 // TODO: convert this to utilise the IRBuilder Config rather than
11145 // a passed down argument.
11146 if (OpenMPSIMD)
11147 return nullptr;
11148
11149 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
11150 ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
11151 CaptureClause ==
11152 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
11153 Config.hasRequiresUnifiedSharedMemory())) {
11154 SmallString<64> PtrName;
11155 {
11156 raw_svector_ostream OS(PtrName);
11157 OS << MangledName;
11158 if (!IsExternallyVisible)
11159 OS << format(Fmt: "_%x", Vals: EntryInfo.FileID);
11160 OS << "_decl_tgt_ref_ptr";
11161 }
11162
11163 Value *Ptr = M.getNamedValue(Name: PtrName);
11164
11165 if (!Ptr) {
11166 GlobalValue *GlobalValue = M.getNamedValue(Name: MangledName);
11167 Ptr = getOrCreateInternalVariable(Ty: LlvmPtrTy, Name: PtrName);
11168
11169 auto *GV = cast<GlobalVariable>(Val: Ptr);
11170 GV->setLinkage(GlobalValue::WeakAnyLinkage);
11171
11172 if (!Config.isTargetDevice()) {
11173 if (GlobalInitializer)
11174 GV->setInitializer(GlobalInitializer());
11175 else
11176 GV->setInitializer(GlobalValue);
11177 }
11178
11179 registerTargetGlobalVariable(
11180 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
11181 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
11182 GlobalInitializer, VariableLinkage, LlvmPtrTy, Addr: cast<Constant>(Val: Ptr));
11183 }
11184
11185 return cast<Constant>(Val: Ptr);
11186 }
11187
11188 return nullptr;
11189}
11190
11191void OpenMPIRBuilder::registerTargetGlobalVariable(
11192 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
11193 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
11194 bool IsDeclaration, bool IsExternallyVisible,
11195 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
11196 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
11197 std::vector<Triple> TargetTriple,
11198 std::function<Constant *()> GlobalInitializer,
11199 std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
11200 Constant *Addr) {
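  // Only 'device_type(any)' entries are registered, and only when compiling
  // for a device or when offload target triples are present.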
11201 if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
11202 (TargetTriple.empty() && !Config.isTargetDevice()))
11203 return;
11204
11205 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
11206 StringRef VarName;
11207 int64_t VarSize;
11208 GlobalValue::LinkageTypes Linkage;
11209
11210 if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
11211 CaptureClause ==
11212 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
11213 !Config.hasRequiresUnifiedSharedMemory()) {
11214 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
11215 VarName = MangledName;
11216 GlobalValue *LlvmVal = M.getNamedValue(Name: VarName);
11217
11218 if (!IsDeclaration)
11219 VarSize = divideCeil(
11220 Numerator: M.getDataLayout().getTypeSizeInBits(Ty: LlvmVal->getValueType()), Denominator: 8);
11221 else
11222 VarSize = 0;
11223 Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
11224
11225 // This is a workaround carried over from Clang which prevents undesired
11226 // optimisation of internal variables.
11227 if (Config.isTargetDevice() &&
11228 (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
11229 // Do not create a "ref-variable" if the original is not also available
11230 // on the host.
11231 if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
11232 return;
11233
11234 std::string RefName = createPlatformSpecificName(Parts: {VarName, "ref"});
11235
11236 if (!M.getNamedValue(Name: RefName)) {
11237 Constant *AddrRef =
11238 getOrCreateInternalVariable(Ty: Addr->getType(), Name: RefName);
11239 auto *GvAddrRef = cast<GlobalVariable>(Val: AddrRef);
11240 GvAddrRef->setConstant(true);
11241 GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
11242 GvAddrRef->setInitializer(Addr);
11243 GeneratedRefs.push_back(x: GvAddrRef);
11244 }
11245 }
11246 } else {
11247 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
11248 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
11249 else
11250 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
11251
11252 if (Config.isTargetDevice()) {
11253 VarName = (Addr) ? Addr->getName() : "";
11254 Addr = nullptr;
11255 } else {
11256 Addr = getAddrOfDeclareTargetVar(
11257 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
11258 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
11259 LlvmPtrTy, GlobalInitializer, VariableLinkage);
11260 VarName = (Addr) ? Addr->getName() : "";
11261 }
11262 VarSize = M.getDataLayout().getPointerSize();
11263 Linkage = GlobalValue::WeakAnyLinkage;
11264 }
11265
11266 OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
11267 Flags, Linkage);
11268}
11269
11270/// Loads all the offload entries information from the host IR
11271/// metadata.
11272void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
11273 // If we are in target mode, load the metadata from the host IR. This code has
11274 // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
11275
11276 NamedMDNode *MD = M.getNamedMetadata(Name: ompOffloadInfoName);
11277 if (!MD)
11278 return;
11279
11280 for (MDNode *MN : MD->operands()) {
11281 auto &&GetMDInt = [MN](unsigned Idx) {
11282 auto *V = cast<ConstantAsMetadata>(Val: MN->getOperand(I: Idx));
11283 return cast<ConstantInt>(Val: V->getValue())->getZExtValue();
11284 };
11285
11286 auto &&GetMDString = [MN](unsigned Idx) {
11287 auto *V = cast<MDString>(Val: MN->getOperand(I: Idx));
11288 return V->getString();
11289 };
11290
11291 switch (GetMDInt(0)) {
11292 default:
11293 llvm_unreachable("Unexpected metadata!");
11294 break;
11295 case OffloadEntriesInfoManager::OffloadEntryInfo::
11296 OffloadingEntryInfoTargetRegion: {
11297 TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
11298 /*DeviceID=*/GetMDInt(1),
11299 /*FileID=*/GetMDInt(2),
11300 /*Line=*/GetMDInt(4),
11301 /*Count=*/GetMDInt(5));
11302 OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
11303 /*Order=*/GetMDInt(6));
11304 break;
11305 }
11306 case OffloadEntriesInfoManager::OffloadEntryInfo::
11307 OffloadingEntryInfoDeviceGlobalVar:
11308 OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
11309 /*MangledName=*/Name: GetMDString(1),
11310 Flags: static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
11311 /*Flags=*/GetMDInt(2)),
11312 /*Order=*/GetMDInt(3));
11313 break;
11314 }
11315 }
11316}
11317
11318void OpenMPIRBuilder::loadOffloadInfoMetadata(vfs::FileSystem &VFS,
11319 StringRef HostFilePath) {
11320 if (HostFilePath.empty())
11321 return;
11322
11323 auto Buf = VFS.getBufferForFile(Name: HostFilePath);
11324 if (std::error_code Err = Buf.getError()) {
11325 report_fatal_error(reason: ("error opening host file from host file path inside of "
11326 "OpenMPIRBuilder: " +
11327 Err.message())
11328 .c_str());
11329 }
11330
11331 LLVMContext Ctx;
11332 auto M = expectedToErrorOrAndEmitErrors(
11333 Ctx, Val: parseBitcodeFile(Buffer: Buf.get()->getMemBufferRef(), Context&: Ctx));
11334 if (std::error_code Err = M.getError()) {
11335 report_fatal_error(
11336 reason: ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
11337 .c_str());
11338 }
11339
11340 loadOffloadInfoMetadata(M&: *M.get());
11341}
11342
11343//===----------------------------------------------------------------------===//
11344// OffloadEntriesInfoManager
11345//===----------------------------------------------------------------------===//
11346
11347bool OffloadEntriesInfoManager::empty() const {
11348 return OffloadEntriesTargetRegion.empty() &&
11349 OffloadEntriesDeviceGlobalVar.empty();
11350}
11351
11352unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
11353 const TargetRegionEntryInfo &EntryInfo) const {
11354 auto It = OffloadEntriesTargetRegionCount.find(
11355 x: getTargetRegionEntryCountKey(EntryInfo));
11356 if (It == OffloadEntriesTargetRegionCount.end())
11357 return 0;
11358 return It->second;
11359}
11360
11361void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
11362 const TargetRegionEntryInfo &EntryInfo) {
11363 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
11364 EntryInfo.Count + 1;
11365}
11366
11367/// Initialize target region entry.
11368void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
11369 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
11370 OffloadEntriesTargetRegion[EntryInfo] =
11371 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
11372 OMPTargetRegionEntryTargetRegion);
11373 ++OffloadingEntriesNum;
11374}
11375
11376void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
11377 TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
11378 OMPTargetRegionEntryKind Flags) {
11379 assert(EntryInfo.Count == 0 && "expected default EntryInfo");
11380
11381 // Update the EntryInfo with the next available count for this location.
11382 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11383
  // If we are emitting code for a target device, the entry is already
  // initialized and only has to be registered.
11386 if (OMPBuilder->Config.isTargetDevice()) {
11387 // This could happen if the device compilation is invoked standalone.
11388 if (!hasTargetRegionEntryInfo(EntryInfo)) {
11389 return;
11390 }
11391 auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
11392 Entry.setAddress(Addr);
11393 Entry.setID(ID);
11394 Entry.setFlags(Flags);
11395 } else {
11396 if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
11397 hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
11398 return;
11399 assert(!hasTargetRegionEntryInfo(EntryInfo) &&
11400 "Target region entry already registered!");
11401 OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
11402 OffloadEntriesTargetRegion[EntryInfo] = Entry;
11403 ++OffloadingEntriesNum;
11404 }
11405 incrementTargetRegionEntryInfoCount(EntryInfo);
11406}
11407
11408bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
11409 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
11410
11411 // Update the EntryInfo with the next available count for this location.
11412 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11413
11414 auto It = OffloadEntriesTargetRegion.find(x: EntryInfo);
11415 if (It == OffloadEntriesTargetRegion.end()) {
11416 return false;
11417 }
11418 // Fail if this entry is already registered.
11419 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
11420 return false;
11421 return true;
11422}
11423
11424void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
11425 const OffloadTargetRegionEntryInfoActTy &Action) {
11426 // Scan all target region entries and perform the provided action.
11427 for (const auto &It : OffloadEntriesTargetRegion) {
11428 Action(It.first, It.second);
11429 }
11430}
11431
11432void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
11433 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
11434 OffloadEntriesDeviceGlobalVar.try_emplace(Key: Name, Args&: Order, Args&: Flags);
11435 ++OffloadingEntriesNum;
11436}
11437
11438void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
11439 StringRef VarName, Constant *Addr, int64_t VarSize,
11440 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
11441 if (OMPBuilder->Config.isTargetDevice()) {
11442 // This could happen if the device compilation is invoked standalone.
11443 if (!hasDeviceGlobalVarEntryInfo(VarName))
11444 return;
11445 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11446 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
11447 if (Entry.getVarSize() == 0) {
11448 Entry.setVarSize(VarSize);
11449 Entry.setLinkage(Linkage);
11450 }
11451 return;
11452 }
11453 Entry.setVarSize(VarSize);
11454 Entry.setLinkage(Linkage);
11455 Entry.setAddress(Addr);
11456 } else {
11457 if (hasDeviceGlobalVarEntryInfo(VarName)) {
11458 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11459 assert(Entry.isValid() && Entry.getFlags() == Flags &&
11460 "Entry not initialized!");
11461 if (Entry.getVarSize() == 0) {
11462 Entry.setVarSize(VarSize);
11463 Entry.setLinkage(Linkage);
11464 }
11465 return;
11466 }
11467 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
11468 Flags ==
11469 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
11470 OffloadEntriesDeviceGlobalVar.try_emplace(Key: VarName, Args&: OffloadingEntriesNum,
11471 Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage,
11472 Args: VarName.str());
11473 else
11474 OffloadEntriesDeviceGlobalVar.try_emplace(
11475 Key: VarName, Args&: OffloadingEntriesNum, Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage, Args: "");
11476 ++OffloadingEntriesNum;
11477 }
11478}
11479
11480void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
11481 const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
  // Scan all device global variable entries and perform the provided action.
11483 for (const auto &E : OffloadEntriesDeviceGlobalVar)
11484 Action(E.getKey(), E.getValue());
11485}
11486
11487//===----------------------------------------------------------------------===//
11488// CanonicalLoopInfo
11489//===----------------------------------------------------------------------===//
11490
11491void CanonicalLoopInfo::collectControlBlocks(
11492 SmallVectorImpl<BasicBlock *> &BBs) {
  // We only count those BBs as control blocks for which we do not need to
  // traverse the CFG, i.e. not the loop body which can contain arbitrary
  // control flow. For consistency, this also means we do not add the Body
  // block, which is just the entry to the body code.
11497 BBs.reserve(N: BBs.size() + 6);
11498 BBs.append(IL: {getPreheader(), Header, Cond, Latch, Exit, getAfter()});
11499}
11500
11501BasicBlock *CanonicalLoopInfo::getPreheader() const {
11502 assert(isValid() && "Requires a valid canonical loop");
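  // The header of a canonical loop has exactly two predecessors, the latch and
  // the preheader; the preheader is the one that is not the latch.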
11503 for (BasicBlock *Pred : predecessors(BB: Header)) {
11504 if (Pred != Latch)
11505 return Pred;
11506 }
11507 llvm_unreachable("Missing preheader");
11508}
11509
11510void CanonicalLoopInfo::setTripCount(Value *TripCount) {
11511 assert(isValid() && "Requires a valid canonical loop");
11512
11513 Instruction *CmpI = &getCond()->front();
11514 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
11515 CmpI->setOperand(i: 1, Val: TripCount);
11516
11517#ifndef NDEBUG
11518 assertOK();
11519#endif
11520}
11521
11522void CanonicalLoopInfo::mapIndVar(
11523 llvm::function_ref<Value *(Instruction *)> Updater) {
11524 assert(isValid() && "Requires a valid canonical loop");
11525
11526 Instruction *OldIV = getIndVar();
11527
  // Record the existing uses now so that uses introduced later by the updater
  // are excluded. Uses by the CanonicalLoopInfo itself (in the condition and
  // latch blocks) to keep track of the number of iterations are also excluded.
  SmallVector<Use *> ReplaceableUses;
  for (Use &U : OldIV->uses()) {
    auto *User = dyn_cast<Instruction>(Val: U.getUser());
    if (!User)
      continue;
    if (User->getParent() == getCond())
      continue;
    if (User->getParent() == getLatch())
      continue;
    ReplaceableUses.push_back(Elt: &U);
  }

  // Run the updater, which may introduce new uses of the induction variable.
  Value *NewIV = Updater(OldIV);

  // Replace the old uses with the value returned by the updater.
  for (Use *U : ReplaceableUses)
    U->set(NewIV);
11549
11550#ifndef NDEBUG
11551 assertOK();
11552#endif
11553}
11554
11555void CanonicalLoopInfo::assertOK() const {
11556#ifndef NDEBUG
11557 // No constraints if this object currently does not describe a loop.
11558 if (!isValid())
11559 return;
11560
11561 BasicBlock *Preheader = getPreheader();
11562 BasicBlock *Body = getBody();
11563 BasicBlock *After = getAfter();
11564
11565 // Verify standard control-flow we use for OpenMP loops.
11566 assert(Preheader);
11567 assert(isa<BranchInst>(Preheader->getTerminator()) &&
11568 "Preheader must terminate with unconditional branch");
11569 assert(Preheader->getSingleSuccessor() == Header &&
11570 "Preheader must jump to header");
11571
11572 assert(Header);
11573 assert(isa<BranchInst>(Header->getTerminator()) &&
11574 "Header must terminate with unconditional branch");
11575 assert(Header->getSingleSuccessor() == Cond &&
11576 "Header must jump to exiting block");
11577
11578 assert(Cond);
11579 assert(Cond->getSinglePredecessor() == Header &&
11580 "Exiting block only reachable from header");
11581
11582 assert(isa<BranchInst>(Cond->getTerminator()) &&
11583 "Exiting block must terminate with conditional branch");
11584 assert(size(successors(Cond)) == 2 &&
11585 "Exiting block must have two successors");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
11588 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
11589 "Exiting block's second successor must exit the loop");
11590
11591 assert(Body);
11592 assert(Body->getSinglePredecessor() == Cond &&
11593 "Body only reachable from exiting block");
11594 assert(!isa<PHINode>(Body->front()));
11595
11596 assert(Latch);
11597 assert(isa<BranchInst>(Latch->getTerminator()) &&
11598 "Latch must terminate with unconditional branch");
11599 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of the body code when it
  // has multiple predecessors, introduce another auxiliary basic block like
  // the preheader and after blocks.
11602 assert(Latch->getSinglePredecessor() != nullptr);
11603 assert(!isa<PHINode>(Latch->front()));
11604
11605 assert(Exit);
11606 assert(isa<BranchInst>(Exit->getTerminator()) &&
11607 "Exit block must terminate with unconditional branch");
11608 assert(Exit->getSingleSuccessor() == After &&
11609 "Exit block must jump to after block");
11610
11611 assert(After);
11612 assert(After->getSinglePredecessor() == Exit &&
11613 "After block only reachable from exit block");
11614 assert(After->empty() || !isa<PHINode>(After->front()));
11615
11616 Instruction *IndVar = getIndVar();
11617 assert(IndVar && "Canonical induction variable not found?");
11618 assert(isa<IntegerType>(IndVar->getType()) &&
11619 "Induction variable must be an integer");
11620 assert(cast<PHINode>(IndVar)->getParent() == Header &&
11621 "Induction variable must be a PHI in the loop header");
11622 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
11623 assert(
11624 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
11625 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
11626
11627 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
11628 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
11629 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
11630 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
11631 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
11632 ->isOne());
11633
11634 Value *TripCount = getTripCount();
11635 assert(TripCount && "Loop trip count not found?");
11636 assert(IndVar->getType() == TripCount->getType() &&
11637 "Trip count and induction variable must have the same type");
11638
11639 auto *CmpI = cast<CmpInst>(&Cond->front());
  assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
11642 assert(CmpI->getOperand(0) == IndVar &&
11643 "Exit condition must compare the induction variable");
11644 assert(CmpI->getOperand(1) == TripCount &&
11645 "Exit condition must compare with the trip count");
11646#endif
11647}
11648
11649void CanonicalLoopInfo::invalidate() {
11650 Header = nullptr;
11651 Cond = nullptr;
11652 Latch = nullptr;
11653 Exit = nullptr;
11654}
11655