1//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9///
10/// This file implements the OpenMPIRBuilder class, which is used as a
11/// convenient way to create LLVM instructions for OpenMP directives.
12///
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/SmallSet.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/ADT/StringRef.h"
20#include "llvm/Analysis/AssumptionCache.h"
21#include "llvm/Analysis/CodeMetrics.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/Analysis/OptimizationRemarkEmitter.h"
24#include "llvm/Analysis/PostDominators.h"
25#include "llvm/Analysis/ScalarEvolution.h"
26#include "llvm/Analysis/TargetLibraryInfo.h"
27#include "llvm/Bitcode/BitcodeReader.h"
28#include "llvm/Frontend/Offloading/Utility.h"
29#include "llvm/Frontend/OpenMP/OMPGridValues.h"
30#include "llvm/IR/Attributes.h"
31#include "llvm/IR/BasicBlock.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/CallingConv.h"
34#include "llvm/IR/Constant.h"
35#include "llvm/IR/Constants.h"
36#include "llvm/IR/DIBuilder.h"
37#include "llvm/IR/DebugInfoMetadata.h"
38#include "llvm/IR/DerivedTypes.h"
39#include "llvm/IR/Function.h"
40#include "llvm/IR/GlobalVariable.h"
41#include "llvm/IR/IRBuilder.h"
42#include "llvm/IR/InstIterator.h"
43#include "llvm/IR/IntrinsicInst.h"
44#include "llvm/IR/LLVMContext.h"
45#include "llvm/IR/MDBuilder.h"
46#include "llvm/IR/Metadata.h"
47#include "llvm/IR/PassInstrumentation.h"
48#include "llvm/IR/PassManager.h"
49#include "llvm/IR/ReplaceConstant.h"
50#include "llvm/IR/Value.h"
51#include "llvm/MC/TargetRegistry.h"
52#include "llvm/Support/CommandLine.h"
53#include "llvm/Support/Error.h"
54#include "llvm/Support/ErrorHandling.h"
55#include "llvm/Support/FileSystem.h"
56#include "llvm/Support/VirtualFileSystem.h"
57#include "llvm/Target/TargetMachine.h"
58#include "llvm/Target/TargetOptions.h"
59#include "llvm/Transforms/Utils/BasicBlockUtils.h"
60#include "llvm/Transforms/Utils/Cloning.h"
61#include "llvm/Transforms/Utils/CodeExtractor.h"
62#include "llvm/Transforms/Utils/LoopPeel.h"
63#include "llvm/Transforms/Utils/UnrollLoop.h"
64
65#include <cstdint>
66#include <optional>
67
68#define DEBUG_TYPE "openmp-ir-builder"
69
70using namespace llvm;
71using namespace omp;
72
// Command-line switch controlling whether runtime calls are decorated with
// optimistic attributes that describe their "as-if" behavior.
static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(Val: false));

// Scaling factor applied to the unroll threshold, compensating for later
// simplifications that shrink the loop body after unrolling decisions.
static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(Val: 1.5));
84
85#ifndef NDEBUG
86/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
87/// at position IP1 may change the meaning of IP2 or vice-versa. This is because
88/// an InsertPoint stores the instruction before something is inserted. For
89/// instance, if both point to the same instruction, two IRBuilders alternating
90/// creating instruction will cause the instructions to be interleaved.
91static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
92 IRBuilder<>::InsertPoint IP2) {
93 if (!IP1.isSet() || !IP2.isSet())
94 return false;
95 return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
96}
97
/// Check (asserts-only helper) that \p SchedType encodes a supported
/// combination of base scheduling algorithm, ordered/unordered flag and
/// monotonicity modifiers.
static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
  case OMPScheduleType::OrderedDistributeChunked:
  case OMPScheduleType::OrderedDistribute:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
157#endif
158
159/// This is wrapper over IRBuilderBase::restoreIP that also restores the current
160/// debug location to the last instruction in the specified basic block if the
161/// insert point points to the end of the block.
162static void restoreIPandDebugLoc(llvm::IRBuilderBase &Builder,
163 llvm::IRBuilderBase::InsertPoint IP) {
164 Builder.restoreIP(IP);
165 llvm::BasicBlock *BB = Builder.GetInsertBlock();
166 llvm::BasicBlock::iterator I = Builder.GetInsertPoint();
167 if (!BB->empty() && I == BB->end())
168 Builder.SetCurrentDebugLocation(BB->back().getStableDebugLoc());
169}
170
171static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
172 if (T.isAMDGPU()) {
173 StringRef Features =
174 Kernel->getFnAttribute(Kind: "target-features").getValueAsString();
175 if (Features.count(Str: "+wavefrontsize64"))
176 return omp::getAMDGPUGridValues<64>();
177 return omp::getAMDGPUGridValues<32>();
178 }
179 if (T.isNVPTX())
180 return omp::NVPTXGridValues;
181 if (T.isSPIRV())
182 return omp::SPIRVGridValues;
183 llvm_unreachable("No grid value available for this architecture!");
184}
185
/// Determine which scheduling algorithm to use, determined from schedule clause
/// arguments.
///
/// \param ClauseKind            Schedule kind named in the schedule clause.
/// \param HasChunks             Whether an explicit chunk size was specified.
/// \param HasSimdModifier       Whether the simd modifier is present.
/// \param HasDistScheduleChunks Whether dist_schedule specified a chunk size.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasDistScheduleChunks) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  case OMP_SCHEDULE_Distribute:
    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
                                 : OMPScheduleType::BaseDistribute;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
213
214/// Adds ordering modifier flags to schedule type.
215static OMPScheduleType
216getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
217 bool HasOrderedClause) {
218 assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
219 OMPScheduleType::None &&
220 "Must not have ordering nor monotonicity flags already set");
221
222 OMPScheduleType OrderingModifier = HasOrderedClause
223 ? OMPScheduleType::ModifierOrdered
224 : OMPScheduleType::ModifierUnordered;
225 OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
226
227 // Unsupported combinations
228 if (OrderingScheduleType ==
229 (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
230 return OMPScheduleType::OrderedGuidedChunked;
231 else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
232 OMPScheduleType::ModifierOrdered))
233 return OMPScheduleType::OrderedRuntime;
234
235 return OrderingScheduleType;
236}
237
238/// Adds monotonicity modifier flags to schedule type.
239static OMPScheduleType
240getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
241 bool HasSimdModifier, bool HasMonotonic,
242 bool HasNonmonotonic, bool HasOrderedClause) {
243 assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
244 OMPScheduleType::None &&
245 "Must not have monotonicity flags already set");
246 assert((!HasMonotonic || !HasNonmonotonic) &&
247 "Monotonic and Nonmonotonic are contradicting each other");
248
249 if (HasMonotonic) {
250 return ScheduleType | OMPScheduleType::ModifierMonotonic;
251 } else if (HasNonmonotonic) {
252 return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
253 } else {
254 // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
255 // If the static schedule kind is specified or if the ordered clause is
256 // specified, and if the nonmonotonic modifier is not specified, the
257 // effect is as if the monotonic modifier is specified. Otherwise, unless
258 // the monotonic modifier is specified, the effect is as if the
259 // nonmonotonic modifier is specified.
260 OMPScheduleType BaseScheduleType =
261 ScheduleType & ~OMPScheduleType::ModifierMask;
262 if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
263 (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
264 HasOrderedClause) {
265 // The monotonic is used by default in openmp runtime library, so no need
266 // to set it.
267 return ScheduleType;
268 } else {
269 return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
270 }
271 }
272}
273
274/// Determine the schedule type using schedule and ordering clause arguments.
275static OMPScheduleType
276computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
277 bool HasSimdModifier, bool HasMonotonicModifier,
278 bool HasNonmonotonicModifier, bool HasOrderedClause,
279 bool HasDistScheduleChunks) {
280 OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
281 ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
282 OMPScheduleType OrderedSchedule =
283 getOpenMPOrderingScheduleType(BaseScheduleType: BaseSchedule, HasOrderedClause);
284 OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
285 ScheduleType: OrderedSchedule, HasSimdModifier, HasMonotonic: HasMonotonicModifier,
286 HasNonmonotonic: HasNonmonotonicModifier, HasOrderedClause);
287
288 assert(isValidWorkshareLoopScheduleType(Result));
289 return Result;
290}
291
292/// Make \p Source branch to \p Target.
293///
294/// Handles two situations:
295/// * \p Source already has an unconditional branch.
296/// * \p Source is a degenerate block (no terminator because the BB is
297/// the current head of the IR construction).
298static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
299 if (Instruction *Term = Source->getTerminator()) {
300 auto *Br = cast<BranchInst>(Val: Term);
301 assert(!Br->isConditional() &&
302 "BB's terminator must be an unconditional branch (or degenerate)");
303 BasicBlock *Succ = Br->getSuccessor(Idx: 0);
304 Succ->removePredecessor(Pred: Source, /*KeepOneInputPHIs=*/true);
305 Br->setSuccessor(Idx: 0, BB: Target);
306 return;
307 }
308
309 auto *NewBr = BranchInst::Create(IfTrue: Target, InsertBefore: Source);
310 NewBr->setDebugLoc(DL);
311}
312
/// Move the instructions from insert point \p IP to the end of its block into
/// \p New; if \p CreateBranch is set, terminate the old block with an
/// unconditional branch to \p New carrying debug location \p DL.
void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch, DebugLoc DL) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  // If the `Old` block is empty then there are no instructions to move. But in
  // the new debug scheme, it could have trailing debug records which will be
  // moved to `New` in `spliceDebugInfoEmptyBlock`. We don't want that for 2
  // reasons:
  // 1. If `New` is also empty, `BasicBlock::splice` crashes.
  // 2. Even if `New` is not empty, the rationale to move those records to `New`
  // (in `spliceDebugInfoEmptyBlock`) does not apply here. That function
  // assumes that `Old` is optimized out and is going away. This is not the case
  // here. The `Old` block is still being used e.g. a branch instruction is
  // added to it later in this function.
  // So we call `BasicBlock::splice` only when `Old` is not empty.
  if (!Old->empty())
    New->splice(ToIt: New->begin(), FromBB: Old, FromBeginIt: IP.getPoint(), FromEndIt: Old->end());

  if (CreateBranch) {
    auto *NewBr = BranchInst::Create(IfTrue: New, InsertBefore: Old);
    NewBr->setDebugLoc(DL);
  }
}
339
340void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
341 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
342 BasicBlock *Old = Builder.GetInsertBlock();
343
344 spliceBB(IP: Builder.saveIP(), New, CreateBranch, DL: DebugLoc);
345 if (CreateBranch)
346 Builder.SetInsertPoint(Old->getTerminator());
347 else
348 Builder.SetInsertPoint(Old);
349
350 // SetInsertPoint also updates the Builder's debug location, but we want to
351 // keep the one the Builder was configured to use.
352 Builder.SetCurrentDebugLocation(DebugLoc);
353}
354
355BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
356 DebugLoc DL, llvm::Twine Name) {
357 BasicBlock *Old = IP.getBlock();
358 BasicBlock *New = BasicBlock::Create(
359 Context&: Old->getContext(), Name: Name.isTriviallyEmpty() ? Old->getName() : Name,
360 Parent: Old->getParent(), InsertBefore: Old->getNextNode());
361 spliceBB(IP, New, CreateBranch, DL);
362 New->replaceSuccessorsPhiUsesWith(Old, New);
363 return New;
364}
365
366BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
367 llvm::Twine Name) {
368 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
369 BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
370 if (CreateBranch)
371 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
372 else
373 Builder.SetInsertPoint(Builder.GetInsertBlock());
374 // SetInsertPoint also updates the Builder's debug location, but we want to
375 // keep the one the Builder was configured to use.
376 Builder.SetCurrentDebugLocation(DebugLoc);
377 return New;
378}
379
/// Overload of splitBB for a concrete IRBuilder<>; behaves like the
/// IRBuilderBase overload: splits the current block, returns the new tail
/// block, and leaves the builder in the old block with its debug location
/// preserved.
BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}
393
394BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
395 llvm::Twine Suffix) {
396 BasicBlock *Old = Builder.GetInsertBlock();
397 return splitBB(Builder, CreateBranch, Name: Old->getName() + Suffix);
398}
399
// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
//
// \param OuterAllocaIP Insert point for the backing alloca (and the load when
//                      AsPtr is false).
// \param ToBeDeleted   Receives every fake instruction so the caller can
//                      delete them after outlining.
// \param InnerAllocaIP Insert point at which the fake use is emitted.
// \param Name          Base name for the created values.
// \param AsPtr         If true, the returned fake value is the alloca itself;
//                      otherwise it is a load of the alloca.
// \param Is64Bit       Use i64 instead of i32 for the fake value.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true,
                        bool Is64Bit = false) {
  Builder.restoreIP(IP: OuterAllocaIP);
  IntegerType *IntTy = Is64Bit ? Builder.getInt64Ty() : Builder.getInt32Ty();
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Ty: IntTy, ArraySize: nullptr, Name: Name + ".addr");
  ToBeDeleted.push_back(Elt: FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeValAddr, Name: Name + ".val");
    ToBeDeleted.push_back(Elt: FakeVal);
  }

  // Generate a fake use of this value so it is not optimized away.
  Builder.restoreIP(IP: InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeVal, Name: Name + ".use");
  } else {
    // Any side-effect-free computation works as a use; an add of 10 is used.
    UseFakeVal = cast<BinaryOperator>(Val: Builder.CreateAdd(
        LHS: FakeVal, RHS: Is64Bit ? Builder.getInt64(C: 10) : Builder.getInt32(C: 10)));
  }
  ToBeDeleted.push_back(Elt: UseFakeVal);
  return FakeVal;
}
435
436//===----------------------------------------------------------------------===//
437// OpenMPIRBuilderConfig
438//===----------------------------------------------------------------------===//
439
namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
/// They are OR'ed together into OpenMPIRBuilderConfig::RequiresFlags.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace
460
/// Default configuration: no 'requires' directive has been seen yet.
OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}
463
464OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
465 bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
466 bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
467 bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
468 : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
469 OpenMPOffloadMandatory(OpenMPOffloadMandatory),
470 RequiresFlags(OMP_REQ_UNDEFINED) {
471 if (HasRequiresReverseOffload)
472 RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
473 if (HasRequiresUnifiedAddress)
474 RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
475 if (HasRequiresUnifiedSharedMemory)
476 RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
477 if (HasRequiresDynamicAllocators)
478 RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
479}
480
481bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
482 return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
483}
484
485bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
486 return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
487}
488
489bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
490 return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
491}
492
493bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
494 return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
495}
496
497int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
498 return hasRequiresFlags() ? RequiresFlags
499 : static_cast<int64_t>(OMP_REQ_NONE);
500}
501
502void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
503 if (Value)
504 RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
505 else
506 RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
507}
508
509void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
510 if (Value)
511 RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
512 else
513 RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
514}
515
516void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
517 if (Value)
518 RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
519 else
520 RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
521}
522
523void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
524 if (Value)
525 RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
526 else
527 RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
528}
529
530//===----------------------------------------------------------------------===//
531// OpenMPIRBuilder
532//===----------------------------------------------------------------------===//
533
/// Populate \p ArgsVector with the values forming the kernel-argument
/// structure handed to the offload runtime for \p KernelArgs.
void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(C: KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(C&: Builder.getContext());
  // Teams/threads are passed as 3-element arrays; unset dimensions remain 0.
  constexpr size_t MaxDim = 3;
  Value *ZeroArray = Constant::getNullValue(Ty: ArrayType::get(ElementType: Int32Ty, NumElements: MaxDim));

  // Flags word layout (as built below): bit 0 carries HasNoWait, the
  // DynCGroupMemFallback value is shifted into bits 2 and up.
  Value *HasNoWaitFlag = Builder.getInt64(C: KernelArgs.HasNoWait);

  Value *DynCGroupMemFallbackFlag =
      Builder.getInt64(C: static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
  DynCGroupMemFallbackFlag = Builder.CreateShl(LHS: DynCGroupMemFallbackFlag, RHS: 2);
  Value *Flags = Builder.CreateOr(LHS: HasNoWaitFlag, RHS: DynCGroupMemFallbackFlag);

  assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());

  // Insert the provided dimensions (at most MaxDim) into the zeroed arrays.
  Value *NumTeams3D =
      Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumTeams[0], Idxs: {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumThreads[0], Idxs: {0});
  for (unsigned I :
       seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumTeams.size(), b: MaxDim)))
    NumTeams3D =
        Builder.CreateInsertValue(Agg: NumTeams3D, Val: KernelArgs.NumTeams[I], Idxs: {I})
  for (unsigned I :
       seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumThreads.size(), b: MaxDim)))
    NumThreads3D =
        Builder.CreateInsertValue(Agg: NumThreads3D, Val: KernelArgs.NumThreads[I], Idxs: {I});

  // Order must match the runtime's expected kernel-argument layout.
  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGroupMem};
}
579
/// Attach the attribute sets declared in OMPKinds.def for runtime function
/// \p FnID to the declaration \p Fn (function, return and parameter
/// attributes). Functions without an OMP_RTL_ATTRS entry are left unchanged.
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Args: Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions: a bare
  // SExt/ZExt from the .def tables is translated through TargetLibraryInfo so
  // the extension matches what the target actually requires for i32 values.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Kind: Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Kind: Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, Signed: HasSignExt))
          FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, Signed: HasSignExt))
        FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
    } else {
      FnAS = FnAS.addAttributes(C&: Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets) \
  case Enum: \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet); \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false); \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo) \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]); \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs)); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}
629
/// Return the declaration of OpenMP runtime function \p FnID in module \p M,
/// creating it (with calling convention, callback metadata and attributes)
/// if it does not exist yet.
FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...) \
  case Enum: \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__}, \
                             IsVarArg); \
    Fn = M.getFunction(Str); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...) \
  case Enum: \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }
    Fn->setCallingConv(Config.getRuntimeCC());
    // Add information if the runtime function takes a callback function
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(KindID: LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            KindID: LLVMContext::MD_callback,
            MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
                                 CalleeArgNo: 2, Arguments: {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, Fn&: *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
686
687Expected<BasicBlock *>
688OpenMPIRBuilder::FinalizationInfo::getFiniBB(IRBuilderBase &Builder) {
689 if (!FiniBB) {
690 Function *ParentFunc = Builder.GetInsertBlock()->getParent();
691 IRBuilderBase::InsertPointGuard Guard(Builder);
692 FiniBB = BasicBlock::Create(Context&: Builder.getContext(), Name: ".fini", Parent: ParentFunc);
693 Builder.SetInsertPoint(FiniBB);
694 // FiniCB adds the branch to the exit stub.
695 if (Error Err = FiniCB(Builder.saveIP()))
696 return Err;
697 }
698 return FiniBB;
699}
700
/// Merge this finalization block with \p OtherFiniBB so that a single block
/// performs both finalizations, and make \p OtherFiniBB the canonical one.
Error OpenMPIRBuilder::FinalizationInfo::mergeFiniBB(IRBuilderBase &Builder,
                                                     BasicBlock *OtherFiniBB) {
  // Simple case: FiniBB does not exist yet: re-use OtherFiniBB.
  if (!FiniBB) {
    FiniBB = OtherFiniBB;

    // Emit this info's finalization at the top of the adopted block.
    Builder.SetInsertPoint(FiniBB->getFirstNonPHIIt());
    if (Error Err = FiniCB(Builder.saveIP()))
      return Err;

    return Error::success();
  }

  // Move instructions from FiniBB to the start of OtherFiniBB, leaving
  // FiniBB's terminator (if any) behind so it is erased with the block.
  auto EndIt = FiniBB->end();
  if (FiniBB->size() >= 1)
    if (auto Prev = std::prev(x: EndIt); Prev->isTerminator())
      EndIt = Prev;
  OtherFiniBB->splice(ToIt: OtherFiniBB->getFirstNonPHIIt(), FromBB: FiniBB, FromBeginIt: FiniBB->begin(),
                      FromEndIt: EndIt);

  // Redirect all references to the old block before deleting it.
  FiniBB->replaceAllUsesWith(V: OtherFiniBB);
  FiniBB->eraseFromParent();
  FiniBB = OtherFiniBB;
  return Error::success();
}
727
728Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
729 FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
730 auto *Fn = dyn_cast<llvm::Function>(Val: RTLFn.getCallee());
731 assert(Fn && "Failed to create OpenMP runtime function pointer");
732 return Fn;
733}
734
735CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
736 ArrayRef<Value *> Args,
737 StringRef Name) {
738 CallInst *Call = Builder.CreateCall(Callee, Args, Name);
739 Call->setCallingConv(Config.getRuntimeCC());
740 return Call;
741}
742
743void OpenMPIRBuilder::initialize() { initializeTypes(M); }
744
745static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
746 Function *Function) {
747 BasicBlock &EntryBlock = Function->getEntryBlock();
748 BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();
749
750 // Loop over blocks looking for constant allocas, skipping the entry block
751 // as any allocas there are already in the desired location.
752 for (auto Block = std::next(x: Function->begin(), n: 1); Block != Function->end();
753 Block++) {
754 for (auto Inst = Block->getReverseIterator()->begin();
755 Inst != Block->getReverseIterator()->end();) {
756 if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Val&: Inst)) {
757 Inst++;
758 if (!isa<ConstantData>(Val: AllocaInst->getArraySize()))
759 continue;
760 AllocaInst->moveBeforePreserving(MovePos: MoveLocInst);
761 } else {
762 Inst++;
763 }
764 }
765 }
766}
767
/// Move the (non-array) allocas found in \p Block to just before the
/// terminator of the enclosing function's entry block.
static void hoistNonEntryAllocasToEntryBlock(llvm::BasicBlock &Block) {
  llvm::SmallVector<llvm::Instruction *> AllocasToMove;

  auto ShouldHoistAlloca = [](const llvm::AllocaInst &AllocaInst) {
    // TODO: For now, we support simple static allocations, we might need to
    // move non-static ones as well. However, this will need further analysis to
    // move the length arguments as well.
    return !AllocaInst.isArrayAllocation();
  };

  // Collect first, then move, to avoid mutating the block while iterating it.
  for (llvm::Instruction &Inst : Block)
    if (auto *AllocaInst = llvm::dyn_cast<llvm::AllocaInst>(Val: &Inst))
      if (ShouldHoistAlloca(*AllocaInst))
        AllocasToMove.push_back(Elt: AllocaInst);

  auto InsertPoint =
      Block.getParent()->getEntryBlock().getTerminator()->getIterator();

  for (llvm::Instruction *AllocaInst : AllocasToMove)
    AllocaInst->moveBefore(InsertPos: InsertPoint);
}
789
790static void hoistNonEntryAllocasToEntryBlock(llvm::Function *Func) {
791 PostDominatorTree PostDomTree(*Func);
792 for (llvm::BasicBlock &BB : *Func)
793 if (PostDomTree.properlyDominates(A: &BB, B: &Func->getEntryBlock()))
794 hoistNonEntryAllocasToEntryBlock(Block&: BB);
795}
796
/// Finalize the pending OpenMP regions recorded in OutlineInfos: extract each
/// region into its own function via CodeExtractor, clean up the artificial
/// entry block, run post-outline callbacks, hoist/raise allocas as requested,
/// and emit offload entries/metadata. If \p Fn is non-null, only outline
/// infos belonging to \p Fn are processed; the rest are deferred until a
/// later finalize() call.
void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not finalized yet; may happen with nested
    // function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(Elt: OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate
    // struct for aggregate params in the device default alloca address space.
    // OpenMP runtime requires that the params of the extracted functions are
    // passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(Arg: V);

    Function *OutlinedFn =
        Extractor.extractCodeRegion(CEAC, Inputs&: OI.Inputs, Outputs&: OI.Outputs);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute(Kind: "target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(Attr: TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute(Kind: "target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(Attr: TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(where: OuterFn->getIterator(), New: OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away, we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined region
      // and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      // Iterate in reverse and pre-advance the iterator so moving the current
      // instruction does not invalidate the traversal.
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator()) {
          // Absorb any debug value that terminator may have
          if (OI.EntryBB->getTerminator())
            OI.EntryBB->getTerminator()->adoptDbgRecords(
                BB: &ArtificialEntry, It: I.getIterator(), InsertAtHead: false);
          continue;
        }

        I.moveBeforePreserving(BB&: *OI.EntryBB, I: OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(MovePos: &ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->hasNUses(1));

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);

    if (OI.FixUpNonEntryAllocas)
      hoistNonEntryAllocasToEntryBlock(Func: OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embeds user written code into
  // the target region which may inject allocas which need to
  // be moved to the entry block of our target or risk malformed
  // optimisations by later passes, this is only relevant for
  // the device pass which appears to be a little more delicate
  // when it comes to optimisations (however, we do not block on
  // that here, it's up to the inserter to the list to do so).
  // This notably has to occur after the OutlinedInfo candidates
  // have been extracted so we have an end product that will not
  // be implicitly adversely affected by any raises unless
  // intentionally appended to the list.
  // NOTE: This only does so for ConstantData, it could be extended
  // to ConstantExpr's with further effort, however, they should
  // largely be folded when they get here. Extending it to runtime
  // defined/read+writeable allocation sizes would be non-trivial
  // (need to factor in movement of any stores to variables the
  // allocation size depends on, as well as the usual loads,
  // otherwise it'll yield the wrong result after movement) and
  // likely be more suitable as an LLVM optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, Function: F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFunction&: ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(u: false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable(Name: "__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed(Name: "llvm.compiler.used", List: LLVMCompilerUsed);
  }

  IsFinalized = true;
}
951
/// Return whether finalize() has already completed for this builder.
bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }
953
OpenMPIRBuilder::~OpenMPIRBuilder() {
  // finalize() must have consumed all outline requests before destruction;
  // otherwise recorded regions would silently never be outlined.
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}
957
958GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
959 IntegerType *I32Ty = Type::getInt32Ty(C&: M.getContext());
960 auto *GV =
961 new GlobalVariable(M, I32Ty,
962 /* isConstant = */ true, GlobalValue::WeakODRLinkage,
963 ConstantInt::get(Ty: I32Ty, V: Value), Name);
964 GV->setVisibility(GlobalValue::HiddenVisibility);
965
966 return GV;
967}
968
969void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
970 if (List.empty())
971 return;
972
973 // Convert List to what ConstantArray needs.
974 SmallVector<Constant *, 8> UsedArray;
975 UsedArray.resize(N: List.size());
976 for (unsigned I = 0, E = List.size(); I != E; ++I)
977 UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
978 C: cast<Constant>(Val: &*List[I]), Ty: Builder.getPtrTy());
979
980 if (UsedArray.empty())
981 return;
982 ArrayType *ATy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: UsedArray.size());
983
984 auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
985 ConstantArray::get(T: ATy, V: UsedArray), Name);
986
987 GV->setSection("llvm.metadata");
988}
989
990GlobalVariable *
991OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
992 OMPTgtExecModeFlags Mode) {
993 auto *Int8Ty = Builder.getInt8Ty();
994 auto *GVMode = new GlobalVariable(
995 M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
996 ConstantInt::get(Ty: Int8Ty, V: Mode), Twine(KernelName, "_exec_mode"));
997 GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
998 return GVMode;
999}
1000
/// Return (creating and caching on first use) the `ident_t` structure for the
/// given source-location string and flags, as a constant cast to IdentPtr.
/// Results are memoized in IdentMap keyed on the string plus the packed flags.
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  // Cache key packs the flag bits above the reserve2 bits into one 64-bit
  // value alongside the source-location string.
  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Ty: Int32, V: uint32_t(LocFlags)),
                             ConstantInt::get(Ty: Int32, V: Reserve2Flags),
                             ConstantInt::get(Ty: Int32, V: SrcLocStrSize), SrcLocStr};

    // The string may live in a non-default globals address space (e.g. on
    // some offload targets); cast it to the address space the Ident struct
    // member expects.
    size_t SrcLocStrArgIdx = 4;
    if (OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx)
            ->getPointerAddressSpace() !=
        IdentData[SrcLocStrArgIdx]->getType()->getPointerAddressSpace())
      IdentData[SrcLocStrArgIdx] = ConstantExpr::getAddrSpaceCast(
          C: SrcLocStr, Ty: OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx));
    Constant *Initializer =
        ConstantStruct::get(T: OpenMPIRBuilder::Ident, V: IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(C: Ident, Ty: IdentPtr);
}
1047
/// Return (creating and caching on first use) a constant global string holding
/// \p LocStr. \p SrcLocStrSize is set to the string's length (excluding the
/// implicit terminator added by the global).
Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(Context&: M.getContext(), Initializer: LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(C: &GV, Ty: Int8Ptr);

    SrcLocStr = Builder.CreateGlobalString(
        Str: LocStr, /*Name=*/"", AddressSpace: M.getDataLayout().getDefaultGlobalsAddressSpace(),
        M: &M);
  }
  return SrcLocStr;
}
1069
1070Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
1071 StringRef FileName,
1072 unsigned Line, unsigned Column,
1073 uint32_t &SrcLocStrSize) {
1074 SmallString<128> Buffer;
1075 Buffer.push_back(Elt: ';');
1076 Buffer.append(RHS: FileName);
1077 Buffer.push_back(Elt: ';');
1078 Buffer.append(RHS: FunctionName);
1079 Buffer.push_back(Elt: ';');
1080 Buffer.append(RHS: std::to_string(val: Line));
1081 Buffer.push_back(Elt: ';');
1082 Buffer.append(RHS: std::to_string(val: Column));
1083 Buffer.push_back(Elt: ';');
1084 Buffer.push_back(Elt: ';');
1085 return getOrCreateSrcLocStr(LocStr: Buffer.str(), SrcLocStrSize);
1086}
1087
1088Constant *
1089OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
1090 StringRef UnknownLoc = ";unknown;unknown;0;0;;";
1091 return getOrCreateSrcLocStr(LocStr: UnknownLoc, SrcLocStrSize);
1092}
1093
1094Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
1095 uint32_t &SrcLocStrSize,
1096 Function *F) {
1097 DILocation *DIL = DL.get();
1098 if (!DIL)
1099 return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1100 StringRef FileName = M.getName();
1101 if (DIFile *DIF = DIL->getFile())
1102 if (std::optional<StringRef> Source = DIF->getSource())
1103 FileName = *Source;
1104 StringRef Function = DIL->getScope()->getSubprogram()->getName();
1105 if (Function.empty() && F)
1106 Function = F->getName();
1107 return getOrCreateSrcLocStr(FunctionName: Function, FileName, Line: DIL->getLine(),
1108 Column: DIL->getColumn(), SrcLocStrSize);
1109}
1110
1111Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
1112 uint32_t &SrcLocStrSize) {
1113 return getOrCreateSrcLocStr(DL: Loc.DL, SrcLocStrSize,
1114 F: Loc.IP.getBlock()->getParent());
1115}
1116
1117Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
1118 return createRuntimeFunctionCall(
1119 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num), Args: Ident,
1120 Name: "omp_global_thread_num");
1121}
1122
/// Emit a barrier at \p Loc. Depending on the enclosing directive \p Kind the
/// barrier ident is flagged as implicit-for/sections/single or explicit. In a
/// cancellable parallel region (and unless \p ForceSimpleCall) the
/// cancellation-aware __kmpc_cancel_barrier is used, and when
/// \p CheckCancelFlag is set the cancellation check is emitted afterwards.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  // __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // Note: the barrier ident carries the barrier flags, while the thread-id
  // query uses a plain (unflagged) ident for the same location.
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: BarrierLocFlags),
      getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(DK: OMPD_parallel);

  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: UseCancelBarrier
                                        ? OMPRTL___kmpc_cancel_barrier
                                        : OMPRTL___kmpc_barrier),
      Args);

  if (UseCancelBarrier && CheckCancelFlag)
    if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective: OMPD_parallel))
      return Err;

  return Builder.saveIP();
}
1175
/// Emit an `omp cancel` construct for \p CanceledDirective at \p Loc. When
/// \p IfCondition is given, the cancel call is guarded by it, while the else
/// branch still emits a cancellation point (an `if` that evaluates to false
/// must still act as a cancellation point per the spec).
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition) {
    SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: UI, ThenTerm: &ThenTI, ElseTerm: &ElseTI);

    // Even if the if condition evaluates to false, this should count as a
    // cancellation point
    Builder.SetInsertPoint(ElseTI);
    auto ElseIP = Builder.saveIP();

    InsertPointOrErrorTy IPOrErr = createCancellationPoint(
        Loc: LocationDescription{ElseIP, Loc.DL}, CanceledDirective);
    if (!IPOrErr)
      return IPOrErr;
  }

  Builder.SetInsertPoint(ThenTI);

  // Map the directive to the runtime's cancel-kind constant via OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancel), Args);

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1231
/// Emit an `omp cancellation point` for \p CanceledDirective at \p Loc: call
/// __kmpc_cancellationpoint and branch to the finalization path if a
/// cancellation was observed.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
                                         omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();
  Builder.SetInsertPoint(UI);

  // Map the directive to the runtime's cancel-kind constant via OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancellationpoint), Args);

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1270
/// Emit a call to __tgt_target_kernel: allocate the kernel-argument struct at
/// \p AllocaIP, store each of \p KernelArgs into its field at \p Loc, and
/// issue the runtime call. The call's result (offload success/failure code)
/// is returned through \p Return.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // The argument struct must live in the alloca block; switch there for the
  // alloca and back to Loc for the stores and the call.
  Builder.restoreIP(IP: AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(Ty: OpenMPIRBuilder::KernelArgs, ArraySize: nullptr, Name: "kernel_args");
  updateToLocation(Loc);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(Ty: OpenMPIRBuilder::KernelArgs, Ptr: KernelArgsPtr, Idx: I);
    Builder.CreateAlignedStore(
        Val: KernelArgs[I], Ptr: Arg,
        Align: M.getDataLayout().getPrefTypeAlign(Ty: KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident, DeviceID, NumTeams,
                                      NumThreads, HostPtr, KernelArgsPtr};

  Return = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_target_kernel),
      Args: OffloadingArgs);

  return Builder.saveIP();
}
1300
/// Emit the host-side launch sequence for a target region: call
/// __tgt_target_kernel and, if the offload fails (non-zero return), fall back
/// to the host version produced by \p EmitTargetCallFallbackCB before
/// continuing in a common continuation block.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Value *OutlinedFnID,
    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(KernelArgs&: Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(IP: emitTargetKernel(
      Loc: Builder, AllocaIP, Return, Ident: RTLoc, DeviceID, NumTeams: Args.NumTeams.front(),
      NumThreads: Args.NumThreads.front(), HostPtr: OutlinedFnID, KernelArgs: ArgsVector));

  // Branch on the runtime's result: non-zero means the offload did not run
  // and the host fallback must execute.
  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Arg: Return);
  Builder.CreateCondBr(Cond: Failed, True: OffloadFailedBlock, False: OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(BB: OffloadFailedBlock, CurFn);
  InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  emitBranch(Target: OffloadContBlock);
  emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}
1363
/// Emit the control flow that reacts to a runtime cancellation result:
/// branch on \p CancelFlag to either a continuation block (flag == 0) or a
/// cancellation block that jumps to the current finalization block. Leaves
/// the builder positioned at the start of the continuation block.
Error OpenMPIRBuilder::emitCancelationCheckImpl(
    Value *CancelFlag, omp::Directive CanceledDirective) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        Context&: BB->getContext(), Name: BB->getName() + ".cont", Parent: BB->getParent());
  } else {
    // Split at the insertion point; the split leaves an unconditional branch
    // we replace below with the conditional one.
    NonCancellationBlock = SplitBlock(Old: BB, SplitPt: &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      Context&: BB->getContext(), Name: BB->getName() + ".cncl", Parent: BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(Arg: CancelFlag);
  Builder.CreateCondBr(Cond: Cmp, True: NonCancellationBlock, False: CancellationBlock,
                       /* TODO weight */ BranchWeights: nullptr, Unpredictable: nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  auto &FI = FinalizationStack.back();
  Expected<BasicBlock *> FiniBBOrErr = FI.getFiniBB(Builder);
  if (!FiniBBOrErr)
    return FiniBBOrErr.takeError();
  Builder.SetInsertPoint(CancellationBlock);
  Builder.CreateBr(Dest: *FiniBBOrErr);

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(TheBB: NonCancellationBlock, IP: NonCancellationBlock->begin());
  return Error::success();
}
1403
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_60).
// The captured variables of the outlined function are marshalled through a
// pointer array allocated in OuterAllocaBB, and the now-redundant direct call
// to the outlined function (plus any instructions in ToBeDeleted) is erased.
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);

  // Add alloca for kernel args
  OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(TheBB: OuterAllocaBB, IP: OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add address space cast if array for storing arguments is not allocated
  // in address space 0
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(V: ArgsAlloca, DestTy: PtrTy);
  Builder.restoreIP(IP: CurrentIP);

  // Store captured vars which are used by kmpc_parallel_60
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    // Captured values start after the two tid arguments of the original call.
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars), Ptr: Args, Idx0: 0, Idx1: Idx);
    Builder.CreateStore(Val: V, Ptr: StoreAddress);
  }

  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32)
                  : Builder.getInt32(C: 1);

  // Build kmpc_parallel_60 call
  Value *Parallel60CallArgs[] = {
      /* identifier*/ Ident,
      /* global thread num*/ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(C: -1),
      /* Proc bind */ Builder.getInt32(C: -1),
      /* outlined function */ &OutlinedFn,
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function*/ Args,
      /* number of arguments */ Builder.getInt64(C: NumCapturedVars),
      /* strict for number of threads */ Builder.getInt32(C: 0)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_parallel_60);

  OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: Parallel60CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_60 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
                      Ptr: PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1491
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the host.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if]).
// Captured variables are forwarded as variadic arguments of the fork call,
// and the now-redundant direct call (plus ToBeDeleted) is erased.
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call);
  }
  if (auto *F = dyn_cast<Function>(Val: RTLFn.getCallee())) {
    if (!F->hasMetadata(KindID: LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      //  - The callback callee is argument number 2 (microtask).
      //  - The first two arguments of the callback callee are unknown (-1).
      //  - All variadic arguments to the __kmpc_fork_call are passed to the
      //    callback callee.
      F->addMetadata(KindID: LLVMContext::MD_callback,
                     MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
                                          CalleeArgNo: 2, Arguments: {-1, -1},
                                          /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
  OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
  Value *ForkCallArgs[] = {Ident, Builder.getInt32(C: NumCapturedVars),
                           &OutlinedFn};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(in_start: std::begin(arr&: ForkCallArgs), in_end: std::end(arr&: ForkCallArgs));
  if (IfCondition) {
    // The _if variant takes the condition right after the fixed arguments.
    Value *Cond = Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32);
    RealArgs.push_back(Elt: Cond);
  }
  RealArgs.append(in_start: CI->arg_begin() + /* tid & bound tid */ 2, in_end: CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
    RealArgs.push_back(Elt: NullPtrValue);
  }

  OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
                      Ptr: PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1576
/// Generate IR for an OpenMP `parallel` region.
///
/// The region body is produced by \p BodyGenCB between freshly created
/// entry/exit blocks, privatization of captured values is delegated to
/// \p PrivCB, and region finalization to \p FiniCB. The blocks forming the
/// region are collected and registered as an OutlineInfo; the actual
/// outlining (and emission of the fork-call / target-parallel runtime call)
/// happens later via the PostOutlineCB installed here.
///
/// \returns the insertion point after the parallel region, or an error
/// propagated from one of the callbacks.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // Only materialize the global thread ID when a runtime call below (or the
  // target-device path) will actually consume it.
  const bool NeedThreadID = NumThreads || Config.isTargetDevice() ||
                            (ProcBind != OMP_PROC_BIND_default);
  Value *ThreadID = NeedThreadID ? getOrCreateThreadID(Ident) : nullptr;
  // If we generate code for the target device, we need to allocate
  // struct for aggregate params in the device default alloca address space.
  // OpenMP runtime requires that the params of the extracted functions are
  // passed as zero address space pointers. This flag ensures that extracted
  // function arguments are declared in zero address space
  bool ArgsInZeroAddressSpace = Config.isTargetDevice();

  // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
  // only if we compile for host side.
  if (NumThreads && !Config.isTargetDevice()) {
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
  Builder.restoreIP(IP: NewOuter);
  AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
  AllocaInst *ZeroAddrAlloca =
      Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
  Instruction *TIDAddr = TIDAddrAlloca;
  Instruction *ZeroAddr = ZeroAddrAlloca;
  if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
    // Add additional casts to enforce pointers in zero address space
    TIDAddr = new AddrSpaceCastInst(
        TIDAddrAlloca, PointerType ::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
    TIDAddr->insertAfter(InsertPos: TIDAddrAlloca->getIterator());
    ToBeDeleted.push_back(Elt: TIDAddr);
    ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
                                     PointerType ::get(C&: M.getContext(), AddressSpace: 0),
                                     "zero.addr.ascast");
    ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca->getIterator());
    ToBeDeleted.push_back(Elt: ZeroAddr);
  }

  // We only need TIDAddr and ZeroAddr for modeling purposes to get the
  // associated arguments in the outlined function, so we delete them later.
  ToBeDeleted.push_back(Elt: TIDAddrAlloca);
  ToBeDeleted.push_back(Elt: ZeroAddrAlloca);

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  BasicBlock *EntryBB = UI->getParent();
  BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
  BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
  BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");

  // Adapt the user's finalization callback so it always observes a block that
  // terminates with a single branch to the region exit block.
  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  FinalizationStack.push_back(Elt: {FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");

  // Add some fake uses for OpenMP provided arguments. These keep TIDAddr and
  // ZeroAddr alive as inputs so the outlined function receives them as its
  // leading arguments; the loads themselves are deleted later.
  ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
  ToBeDeleted.push_back(Elt: ZeroAddrUse);

  // EntryBB
  //   |
  //   V
  // PRegionEntryBB         <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB          <- BodyGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB          <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB          <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
    return Err;

  LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");

  OutlineInfo OI;
  if (Config.isTargetDevice()) {
    // Generate OpenMP target specific runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                             ThreadID, ToBeDeleted: ToBeDeletedVec);
    };
    OI.FixUpNonEntryAllocas = true;
  } else {
    // Generate OpenMP host runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
                           PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
    };
    OI.FixUpNonEntryAllocas = true;
  }

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);

  // The CodeExtractor is only used for its input/output analysis here; the
  // real extraction happens later when the registered OutlineInfo is
  // processed.
  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

  // Find inputs to, outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);

  Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands,
                              /*CollectGlobalInputs=*/true);

  // Globals of ident_t type are source-location descriptors, not user data;
  // do not treat them as captured inputs.
  Inputs.remove_if(P: [&](Value *I) {
    if (auto *GV = dyn_cast_if_present<GlobalVariable>(Val: I))
      return GV->getValueType() == OpenMPIRBuilder::Ident;

    return false;
  });

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);

  // For a single captured value, decide how it enters the outlined region:
  // the modeling TID/zero addresses are excluded from aggregation, non-pointer
  // values are spilled to a stack slot and reloaded inside the region, thread
  // ID queries are replaced by the local TID, and everything else goes
  // through the user's privatization callback.
  auto PrivHelper = [&](Value &V) -> Error {
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(Elt: &V);
      return Error::success();
    }

    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
        if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
          Uses.insert(X: &U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(IP: OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr, Name: V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the entry
      // block of the to-be-outlined region.
      Builder.SetInsertPoint(TheBB: InsertBB,
                             IP: InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(Val: &V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(IP: InnerAllocaIP);
      Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(Val: &V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      ReplacementValue = PrivTID;
    } else {
      InsertPointOrErrorTy AfterIP =
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
      // The callback may have emitted allocas; re-anchor the inner alloca
      // insertion point before the entry block terminator.
      InnerAllocaIP = {
          InnerAllocaIP.getBlock(),
          InnerAllocaIP.getBlock()->getTerminator()->getIterator()};

      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      if (ReplacementValue == &V)
        return Error::success();
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);

    return Error::success();
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of zero address so
  // that they are available in the generated body and so that the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    if (Error Err = PrivHelper(*Input))
      return Err;
  }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  Expected<BasicBlock *> FiniBBOrErr = FiniInfo.getFiniBB(Builder);
  if (!FiniBBOrErr)
    return FiniBBOrErr.takeError();
  {
    IRBuilderBase::InsertPointGuard Guard(Builder);
    Builder.restoreIP(IP: PreFiniIP);
    Builder.CreateBr(Dest: *FiniBBOrErr);
    // There's currently a branch to omp.par.exit. Delete it. We will get there
    // via the fini block
    if (Instruction *Term = Builder.GetInsertBlock()->getTerminator())
      Term->eraseFromParent();
  }

  // Register the outlined info.
  addOutlineInfo(OI: std::move(OI));

  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
1912
1913void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1914 // Build call void __kmpc_flush(ident_t *loc)
1915 uint32_t SrcLocStrSize;
1916 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1917 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1918
1919 createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush),
1920 Args);
1921}
1922
1923void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1924 if (!updateToLocation(Loc))
1925 return;
1926 emitFlush(Loc);
1927}
1928
1929void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1930 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1931 // global_tid);
1932 uint32_t SrcLocStrSize;
1933 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1934 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1935 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1936
1937 // Ignore return result until untied tasks are supported.
1938 createRuntimeFunctionCall(
1939 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait), Args);
1940}
1941
1942void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1943 if (!updateToLocation(Loc))
1944 return;
1945 emitTaskwaitImpl(Loc);
1946}
1947
1948void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1949 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1950 uint32_t SrcLocStrSize;
1951 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1952 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1953 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1954 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1955
1956 createRuntimeFunctionCall(
1957 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield), Args);
1958}
1959
1960void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1961 if (!updateToLocation(Loc))
1962 return;
1963 emitTaskyieldImpl(Loc);
1964}
1965
1966// Processes the dependencies in Dependencies and does the following
1967// - Allocates space on the stack of an array of DependInfo objects
1968// - Populates each DependInfo object with relevant information of
1969// the corresponding dependence.
1970// - All code is inserted in the entry block of the current function.
1971static Value *emitTaskDependencies(
1972 OpenMPIRBuilder &OMPBuilder,
1973 const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1974 // Early return if we have no dependencies to process
1975 if (Dependencies.empty())
1976 return nullptr;
1977
1978 // Given a vector of DependData objects, in this function we create an
1979 // array on the stack that holds kmp_dep_info objects corresponding
1980 // to each dependency. This is then passed to the OpenMP runtime.
1981 // For example, if there are 'n' dependencies then the following psedo
1982 // code is generated. Assume the first dependence is on a variable 'a'
1983 //
1984 // \code{c}
1985 // DepArray = alloc(n x sizeof(kmp_depend_info);
1986 // idx = 0;
1987 // DepArray[idx].base_addr = ptrtoint(&a);
1988 // DepArray[idx].len = 8;
1989 // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
1990 // ++idx;
1991 // DepArray[idx].base_addr = ...;
1992 // \endcode
1993
1994 IRBuilderBase &Builder = OMPBuilder.Builder;
1995 Type *DependInfo = OMPBuilder.DependInfo;
1996 Module &M = OMPBuilder.M;
1997
1998 Value *DepArray = nullptr;
1999 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
2000 Builder.SetInsertPoint(
2001 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
2002
2003 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
2004 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
2005
2006 Builder.restoreIP(IP: OldIP);
2007
2008 for (const auto &[DepIdx, Dep] : enumerate(First: Dependencies)) {
2009 Value *Base =
2010 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
2011 // Store the pointer to the variable
2012 Value *Addr = Builder.CreateStructGEP(
2013 Ty: DependInfo, Ptr: Base,
2014 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
2015 Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
2016 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
2017 // Store the size of the variable
2018 Value *Size = Builder.CreateStructGEP(
2019 Ty: DependInfo, Ptr: Base, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
2020 Builder.CreateStore(
2021 Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
2022 Ptr: Size);
2023 // Store the dependency kind
2024 Value *Flags = Builder.CreateStructGEP(
2025 Ty: DependInfo, Ptr: Base,
2026 Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
2027 Builder.CreateStore(
2028 Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
2029 V: static_cast<unsigned int>(Dep.DepKind)),
2030 Ptr: Flags);
2031 }
2032 return DepArray;
2033}
2034
/// Create the task duplication function passed to kmpc_taskloop.
///
/// The created function matches the runtime's p_task_dup_t signature:
///   void (kmp_task_t *dest, kmp_task_t *src, kmp_int32 lastprivate_flag)
/// Its body is generated by \p DupCB, which receives pointers to the task
/// context struct inside the privates area of both the destination and the
/// source task. \p PrivatesTy is the type of the privates area that follows
/// the kmp_task_t header and \p PrivatesIndex is the field index of the
/// context pointer within it.
///
/// \returns the new function, a null pointer constant when no \p DupCB is
/// provided, or an error propagated from \p DupCB.
Expected<Value *> OpenMPIRBuilder::createTaskDuplicationFunction(
    Type *PrivatesTy, int32_t PrivatesIndex, TaskDupCallbackTy DupCB) {
  unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
  // Without a callback, the runtime accepts a null task_dup pointer.
  if (!DupCB)
    return Constant::getNullValue(
        Ty: PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace));

  // From OpenMP Runtime p_task_dup_t:
  // Routine optionally generated by the compiler for setting the lastprivate
  // flag and calling needed constructors for private/firstprivate objects (used
  // to form taskloop tasks from pattern task) Parameters: dest task, src task,
  // lastprivate flag.
  // typedef void (*p_task_dup_t)(kmp_task_t *, kmp_task_t *, kmp_int32);

  auto *VoidPtrTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);

  FunctionType *DupFuncTy = FunctionType::get(
      Result: Builder.getVoidTy(), Params: {VoidPtrTy, VoidPtrTy, Builder.getInt32Ty()},
      /*isVarArg=*/false);

  Function *DupFunction = Function::Create(Ty: DupFuncTy, Linkage: Function::InternalLinkage,
                                           N: "omp_taskloop_dup", M);
  Value *DestTaskArg = DupFunction->getArg(i: 0);
  Value *SrcTaskArg = DupFunction->getArg(i: 1);
  Value *LastprivateFlagArg = DupFunction->getArg(i: 2);
  DestTaskArg->setName("dest_task");
  SrcTaskArg->setName("src_task");
  LastprivateFlagArg->setName("lastprivate_flag");

  // Build the body in a fresh entry block; the guard restores the caller's
  // insertion point when we are done.
  IRBuilderBase::InsertPointGuard Guard(Builder);
  Builder.SetInsertPoint(
      BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: DupFunction));

  // GEP through {kmp_task_t, PrivatesTy} to the context-pointer field of the
  // privates area of the given task argument.
  auto GetTaskContextPtrFromArg = [&](Value *Arg) -> Value * {
    Type *TaskWithPrivatesTy =
        StructType::get(Context&: Builder.getContext(), Elements: {Task, PrivatesTy});
    Value *TaskPrivates = Builder.CreateGEP(
        Ty: TaskWithPrivatesTy, Ptr: Arg, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
    Value *ContextPtr = Builder.CreateGEP(
        Ty: PrivatesTy, Ptr: TaskPrivates,
        IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: PrivatesIndex)});
    return ContextPtr;
  };

  Value *DestTaskContextPtr = GetTaskContextPtrFromArg(DestTaskArg);
  Value *SrcTaskContextPtr = GetTaskContextPtrFromArg(SrcTaskArg);

  DestTaskContextPtr->setName("destPtr");
  SrcTaskContextPtr->setName("srcPtr");

  // Hand both context pointers to the callback to emit the actual copy logic.
  InsertPointTy AllocaIP(&DupFunction->getEntryBlock(),
                         DupFunction->getEntryBlock().begin());
  InsertPointTy CodeGenIP = Builder.saveIP();
  Expected<IRBuilderBase::InsertPoint> AfterIPOrError =
      DupCB(AllocaIP, CodeGenIP, DestTaskContextPtr, SrcTaskContextPtr);
  if (!AfterIPOrError)
    return AfterIPOrError.takeError();
  Builder.restoreIP(IP: *AfterIPOrError);

  Builder.CreateRetVoid();

  return DupFunction;
}
2099
2100OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
2101 const LocationDescription &Loc, InsertPointTy AllocaIP,
2102 BodyGenCallbackTy BodyGenCB,
2103 llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> LoopInfo,
2104 Value *LBVal, Value *UBVal, Value *StepVal, bool Untied, Value *IfCond,
2105 Value *GrainSize, bool NoGroup, int Sched, Value *Final, bool Mergeable,
2106 Value *Priority, uint64_t NumOfCollapseLoops, TaskDupCallbackTy DupCB,
2107 Value *TaskContextStructPtrVal) {
2108
2109 if (!updateToLocation(Loc))
2110 return InsertPointTy();
2111
2112 uint32_t SrcLocStrSize;
2113 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2114 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2115
2116 BasicBlock *TaskloopExitBB =
2117 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.exit");
2118 BasicBlock *TaskloopBodyBB =
2119 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.body");
2120 BasicBlock *TaskloopAllocaBB =
2121 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.alloca");
2122
2123 InsertPointTy TaskloopAllocaIP =
2124 InsertPointTy(TaskloopAllocaBB, TaskloopAllocaBB->begin());
2125 InsertPointTy TaskloopBodyIP =
2126 InsertPointTy(TaskloopBodyBB, TaskloopBodyBB->begin());
2127
2128 if (Error Err = BodyGenCB(TaskloopAllocaIP, TaskloopBodyIP))
2129 return Err;
2130
2131 llvm::Expected<llvm::CanonicalLoopInfo *> result = LoopInfo();
2132 if (!result) {
2133 return result.takeError();
2134 }
2135
2136 llvm::CanonicalLoopInfo *CLI = result.get();
2137 OutlineInfo OI;
2138 OI.EntryBB = TaskloopAllocaBB;
2139 OI.OuterAllocaBB = AllocaIP.getBlock();
2140 OI.ExitBB = TaskloopExitBB;
2141
2142 // Add the thread ID argument.
2143 SmallVector<Instruction *> ToBeDeleted;
2144 // dummy instruction to be used as a fake argument
2145 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2146 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskloopAllocaIP, Name: "global.tid", AsPtr: false));
2147 Value *FakeLB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2148 InnerAllocaIP: TaskloopAllocaIP, Name: "lb", AsPtr: false, Is64Bit: true);
2149 Value *FakeUB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2150 InnerAllocaIP: TaskloopAllocaIP, Name: "ub", AsPtr: false, Is64Bit: true);
2151 Value *FakeStep = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2152 InnerAllocaIP: TaskloopAllocaIP, Name: "step", AsPtr: false, Is64Bit: true);
2153 // For Taskloop, we want to force the bounds being the first 3 inputs in the
2154 // aggregate struct
2155 OI.Inputs.insert(X: FakeLB);
2156 OI.Inputs.insert(X: FakeUB);
2157 OI.Inputs.insert(X: FakeStep);
2158 if (TaskContextStructPtrVal)
2159 OI.Inputs.insert(X: TaskContextStructPtrVal);
2160 assert(((TaskContextStructPtrVal && DupCB) ||
2161 (!TaskContextStructPtrVal && !DupCB)) &&
2162 "Task context struct ptr and duplication callback must be both set "
2163 "or both null");
2164
2165 // It isn't safe to run the duplication bodygen callback inside the post
2166 // outlining callback so this has to be run now before we know the real task
2167 // shareds structure type.
2168 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2169 Type *PointerTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2170 Type *FakeSharedsTy = StructType::get(
2171 Context&: Builder.getContext(),
2172 Elements: {FakeLB->getType(), FakeUB->getType(), FakeStep->getType(), PointerTy});
2173 Expected<Value *> TaskDupFnOrErr = createTaskDuplicationFunction(
2174 PrivatesTy: FakeSharedsTy,
2175 /*PrivatesIndex: the pointer after the three indices above*/ PrivatesIndex: 3, DupCB);
2176 if (!TaskDupFnOrErr) {
2177 return TaskDupFnOrErr.takeError();
2178 }
2179 Value *TaskDupFn = *TaskDupFnOrErr;
2180
2181 OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
2182 TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
2183 IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
2184 FakeStep, FakeSharedsTy, Final, Mergeable, Priority,
2185 NumOfCollapseLoops](Function &OutlinedFn) mutable {
2186 // Replace the Stale CI by appropriate RTL function call.
2187 assert(OutlinedFn.hasOneUse() &&
2188 "there must be a single user for the outlined function");
2189 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2190
2191 /* Create the casting for the Bounds Values that can be used when outlining
2192 * to replace the uses of the fakes with real values */
2193 BasicBlock *CodeReplBB = StaleCI->getParent();
2194 IRBuilderBase::InsertPoint CurrentIp = Builder.saveIP();
2195 Builder.SetInsertPoint(CodeReplBB->getFirstInsertionPt());
2196 Value *CastedLBVal =
2197 Builder.CreateIntCast(V: LBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "lb64");
2198 Value *CastedUBVal =
2199 Builder.CreateIntCast(V: UBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "ub64");
2200 Value *CastedStepVal =
2201 Builder.CreateIntCast(V: StepVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "step64");
2202 Builder.restoreIP(IP: CurrentIp);
2203
2204 Builder.SetInsertPoint(StaleCI);
2205
2206 // Gather the arguments for emitting the runtime call for
2207 // @__kmpc_omp_task_alloc
2208 Function *TaskAllocFn =
2209 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2210
2211 Value *ThreadID = getOrCreateThreadID(Ident);
2212
2213 if (!NoGroup) {
2214 // Emit runtime call for @__kmpc_taskgroup
2215 Function *TaskgroupFn =
2216 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2217 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2218 }
2219
2220 // `flags` Argument Configuration
2221 // Task is tied if (Flags & 1) == 1.
2222 // Task is untied if (Flags & 1) == 0.
2223 // Task is final if (Flags & 2) == 2.
2224 // Task is not final if (Flags & 2) == 0.
2225 // Task is mergeable if (Flags & 4) == 4.
2226 // Task is not mergeable if (Flags & 4) == 0.
2227 // Task is priority if (Flags & 32) == 32.
2228 // Task is not priority if (Flags & 32) == 0.
2229 Value *Flags = Builder.getInt32(C: Untied ? 0 : 1);
2230 if (Final)
2231 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 2), RHS: Flags);
2232 if (Mergeable)
2233 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2234 if (Priority)
2235 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2236
2237 Value *TaskSize = Builder.getInt64(
2238 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2239
2240 AllocaInst *ArgStructAlloca =
2241 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2242 assert(ArgStructAlloca &&
2243 "Unable to find the alloca instruction corresponding to arguments "
2244 "for extracted function");
2245 std::optional<TypeSize> ArgAllocSize =
2246 ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
2247 assert(ArgAllocSize &&
2248 "Unable to determine size of arguments for extracted function");
2249 Value *SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
2250
2251 // Emit the @__kmpc_omp_task_alloc runtime call
2252 // The runtime call returns a pointer to an area where the task captured
2253 // variables must be copied before the task is run (TaskData)
2254 CallInst *TaskData = Builder.CreateCall(
2255 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2256 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2257 /*task_func=*/&OutlinedFn});
2258
2259 Value *Shareds = StaleCI->getArgOperand(i: 1);
2260 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2261 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2262 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2263 Size: SharedsSize);
2264 // Get the pointer to loop lb, ub, step from task ptr
2265 // and set up the lowerbound,upperbound and step values
2266 llvm::Value *Lb = Builder.CreateGEP(
2267 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
2268
2269 llvm::Value *Ub = Builder.CreateGEP(
2270 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2271
2272 llvm::Value *Step = Builder.CreateGEP(
2273 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2)});
2274 llvm::Value *Loadstep = Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: Step);
2275
2276 // set up the arguments for emitting kmpc_taskloop runtime call
2277 // setting values for ifval, nogroup, sched, grainsize, task_dup
2278 Value *IfCondVal =
2279 IfCond ? Builder.CreateIntCast(V: IfCond, DestTy: Builder.getInt32Ty(), isSigned: true)
2280 : Builder.getInt32(C: 1);
2281 // As __kmpc_taskgroup is called manually in OMPIRBuilder, NoGroupVal should
2282 // always be 1 when calling __kmpc_taskloop to ensure it is not called again
2283 Value *NoGroupVal = Builder.getInt32(C: 1);
2284 Value *SchedVal = Builder.getInt32(C: Sched);
2285 Value *GrainSizeVal =
2286 GrainSize ? Builder.CreateIntCast(V: GrainSize, DestTy: Builder.getInt64Ty(), isSigned: true)
2287 : Builder.getInt64(C: 0);
2288 Value *TaskDup = TaskDupFn;
2289
2290 Value *Args[] = {Ident, ThreadID, TaskData, IfCondVal, Lb, Ub,
2291 Loadstep, NoGroupVal, SchedVal, GrainSizeVal, TaskDup};
2292
2293 // taskloop runtime call
2294 Function *TaskloopFn =
2295 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskloop);
2296 Builder.CreateCall(Callee: TaskloopFn, Args);
2297
2298 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup if
2299 // nogroup is not defined
2300 if (!NoGroup) {
2301 Function *EndTaskgroupFn =
2302 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2303 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2304 }
2305
2306 StaleCI->eraseFromParent();
2307
2308 Builder.SetInsertPoint(TheBB: TaskloopAllocaBB, IP: TaskloopAllocaBB->begin());
2309
2310 LoadInst *SharedsOutlined =
2311 Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2312 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2313 New: SharedsOutlined,
2314 ShouldReplace: [SharedsOutlined](Use &U) { return U.getUser() != SharedsOutlined; });
2315
2316 Value *IV = CLI->getIndVar();
2317 Type *IVTy = IV->getType();
2318 Constant *One = ConstantInt::get(Ty: Builder.getInt64Ty(), V: 1);
2319
2320 // When outlining, CodeExtractor will create GEP's to the LowerBound and
2321 // UpperBound. These GEP's can be reused for loading the tasks respective
2322 // bounds.
2323 Value *TaskLB = nullptr;
2324 Value *TaskUB = nullptr;
2325 Value *LoadTaskLB = nullptr;
2326 Value *LoadTaskUB = nullptr;
2327 for (Instruction &I : *TaskloopAllocaBB) {
2328 if (I.getOpcode() == Instruction::GetElementPtr) {
2329 GetElementPtrInst &Gep = cast<GetElementPtrInst>(Val&: I);
2330 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: Gep.getOperand(i_nocapture: 2))) {
2331 switch (CI->getZExtValue()) {
2332 case 0:
2333 TaskLB = &I;
2334 break;
2335 case 1:
2336 TaskUB = &I;
2337 break;
2338 }
2339 }
2340 } else if (I.getOpcode() == Instruction::Load) {
2341 LoadInst &Load = cast<LoadInst>(Val&: I);
2342 if (Load.getPointerOperand() == TaskLB) {
2343 assert(TaskLB != nullptr && "Expected value for TaskLB");
2344 LoadTaskLB = &I;
2345 } else if (Load.getPointerOperand() == TaskUB) {
2346 assert(TaskUB != nullptr && "Expected value for TaskUB");
2347 LoadTaskUB = &I;
2348 }
2349 }
2350 }
2351
2352 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2353
2354 assert(LoadTaskLB != nullptr && "Expected value for LoadTaskLB");
2355 assert(LoadTaskUB != nullptr && "Expected value for LoadTaskUB");
2356 Value *TripCountMinusOne =
2357 Builder.CreateSDiv(LHS: Builder.CreateSub(LHS: LoadTaskUB, RHS: LoadTaskLB), RHS: FakeStep);
2358 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One, Name: "trip_cnt");
2359 Value *CastedTripCount = Builder.CreateIntCast(V: TripCount, DestTy: IVTy, isSigned: true);
2360 Value *CastedTaskLB = Builder.CreateIntCast(V: LoadTaskLB, DestTy: IVTy, isSigned: true);
2361 // set the trip count in the CLI
2362 CLI->setTripCount(CastedTripCount);
2363
2364 Builder.SetInsertPoint(TheBB: CLI->getBody(),
2365 IP: CLI->getBody()->getFirstInsertionPt());
2366
2367 if (NumOfCollapseLoops > 1) {
2368 llvm::SmallVector<User *> UsersToReplace;
2369 // When using the collapse clause, the bounds of the loop have to be
2370 // adjusted to properly represent the iterator of the outer loop.
2371 Value *IVPlusTaskLB = Builder.CreateAdd(
2372 LHS: CLI->getIndVar(),
2373 RHS: Builder.CreateSub(LHS: CastedTaskLB, RHS: ConstantInt::get(Ty: IVTy, V: 1)));
2374 // To ensure every Use is correctly captured, we first want to record
2375 // which users to replace the value in, and then replace the value.
2376 for (auto IVUse = CLI->getIndVar()->uses().begin();
2377 IVUse != CLI->getIndVar()->uses().end(); IVUse++) {
2378 User *IVUser = IVUse->getUser();
2379 if (auto *Op = dyn_cast<BinaryOperator>(Val: IVUser)) {
2380 if (Op->getOpcode() == Instruction::URem ||
2381 Op->getOpcode() == Instruction::UDiv) {
2382 UsersToReplace.push_back(Elt: IVUser);
2383 }
2384 }
2385 }
2386 for (User *User : UsersToReplace) {
2387 User->replaceUsesOfWith(From: CLI->getIndVar(), To: IVPlusTaskLB);
2388 }
2389 } else {
2390 // The canonical loop is generated with a fixed lower bound. We need to
2391 // update the index calculation code to use the task's lower bound. The
2392 // generated code looks like this:
2393 // %omp_loop.iv = phi ...
2394 // ...
2395 // %tmp = mul [type] %omp_loop.iv, step
2396 // %user_index = add [type] tmp, lb
2397 // OpenMPIRBuilder constructs canonical loops to have exactly three uses
2398 // of the normalised induction variable:
2399 // 1. This one: converting the normalised IV to the user IV
2400 // 2. The increment (add)
2401 // 3. The comparison against the trip count (icmp)
2402 // (1) is the only use that is a mul followed by an add so this cannot
2403 // match other IR.
2404 assert(CLI->getIndVar()->getNumUses() == 3 &&
2405 "Canonical loop should have exactly three uses of the ind var");
2406 for (User *IVUser : CLI->getIndVar()->users()) {
2407 if (auto *Mul = dyn_cast<BinaryOperator>(Val: IVUser)) {
2408 if (Mul->getOpcode() == Instruction::Mul) {
2409 for (User *MulUser : Mul->users()) {
2410 if (auto *Add = dyn_cast<BinaryOperator>(Val: MulUser)) {
2411 if (Add->getOpcode() == Instruction::Add) {
2412 Add->setOperand(i_nocapture: 1, Val_nocapture: CastedTaskLB);
2413 }
2414 }
2415 }
2416 }
2417 }
2418 }
2419 }
2420
2421 FakeLB->replaceAllUsesWith(V: CastedLBVal);
2422 FakeUB->replaceAllUsesWith(V: CastedUBVal);
2423 FakeStep->replaceAllUsesWith(V: CastedStepVal);
2424 for (Instruction *I : llvm::reverse(C&: ToBeDeleted)) {
2425 I->eraseFromParent();
2426 }
2427 };
2428
2429 addOutlineInfo(OI: std::move(OI));
2430 Builder.SetInsertPoint(TheBB: TaskloopExitBB, IP: TaskloopExitBB->begin());
2431 return Builder.saveIP();
2432}
2433
2434llvm::StructType *OpenMPIRBuilder::getKmpTaskAffinityInfoTy() {
2435 llvm::Type *IntPtrTy = llvm::Type::getIntNTy(
2436 C&: M.getContext(), N: M.getDataLayout().getPointerSizeInBits());
2437 return llvm::StructType::get(elt1: IntPtrTy, elts: IntPtrTy,
2438 elts: llvm::Type::getInt32Ty(C&: M.getContext()));
2439}
2440
/// Generate an OpenMP `task` construct. The region produced by \p BodyGenCB
/// is outlined into a separate function; once outlining has run, the stale
/// call to that function is rewritten (in PostOutlineCB below) into the
/// __kmpc_omp_task_alloc plus task-spawn runtime sequence, honouring the
/// clause operands (Tied/Final/IfCondition/Dependencies/Affinities/
/// Mergeable/EventHandle/Priority) captured here.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
    SmallVector<DependData> Dependencies, AffinityData Affinities,
    bool Mergeable, Value *EventHandle, Value *Priority) {

  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %task.exit
  //   task.exit:
  //     ; instructions after task
  // }
  // def outlined_fn() {
  //   task.alloca:
  //     br label %task.body
  //   task.body:
  //     ret void
  // }
  // ```
  BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
  BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
  BasicBlock *TaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");

  // Let the caller emit the task body between the alloca and body blocks.
  InsertPointTy TaskAllocaIP =
      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
  if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
    return Err;

  // Describe the region [task.alloca, task.exit) to the outliner.
  OutlineInfo OI;
  OI.EntryBB = TaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();
  OI.ExitBB = TaskExitBB;

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));

  // This callback runs after CodeExtractor has outlined the region, i.e.
  // after createTask has returned — hence the by-value (mutable) captures.
  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
                      Affinities, Mergeable, Priority, EventHandle,
                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
    // Replace the Stale CI by appropriate RTL function call.
    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());

    // HasShareds is true if any variables are captured in the outlined region,
    // false otherwise.
    bool HasShareds = StaleCI->arg_size() > 1;
    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call for
    // @__kmpc_omp_task_alloc
    Function *TaskAllocFn =
        getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // Task is mergeable iff (Flags & 4) == 4.
    // Task is not mergeable iff (Flags & 4) == 0.
    // Task is priority iff (Flags & 32) == 32.
    // Task is not priority iff (Flags & 32) == 0.
    // TODO: Handle the other flags.
    Value *Flags = Builder.getInt32(C: Tied);
    if (Final) {
      // `final` may be a runtime value, so select between 2 and 0 at runtime.
      Value *FinalFlag =
          Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
      Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
    }

    if (Mergeable)
      Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
    if (Priority)
      Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // Tasksize refers to the size in bytes of kmp_task_t data structure
    // including private vars accessed in task.
    // TODO: add kmp_task_t_with_privates (privates)
    // NOTE(review): `Task` is presumably the cached kmp_task_t struct type —
    // confirm against the member's declaration.
    Value *TaskSize = Builder.getInt64(
        C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(C: 0);
    if (HasShareds) {
      AllocaInst *ArgStructAlloca =
          dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      std::optional<TypeSize> ArgAllocSize =
          ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
      assert(ArgAllocSize &&
             "Unable to determine size of arguments for extracted function");
      SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
    }
    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = createRuntimeFunctionCall(
        Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
        /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
        /*task_func=*/&OutlinedFn});

    // Register the affinity information with the runtime, if present.
    if (Affinities.Count && Affinities.Info) {
      Function *RegAffFn = getOrCreateRuntimeFunctionPtr(
          FnID: OMPRTL___kmpc_omp_reg_task_with_affinity);

      createRuntimeFunctionCall(Callee: RegAffFn, Args: {Ident, ThreadID, TaskData,
                                Affinities.Count, Affinities.Info});
    }

    // Emit detach clause initialization.
    // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
    // task_descriptor);
    if (EventHandle) {
      Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
          FnID: OMPRTL___kmpc_task_allow_completion_event);
      llvm::Value *EventVal =
          createRuntimeFunctionCall(Callee: TaskDetachFn, Args: {Ident, ThreadID, TaskData});
      llvm::Value *EventHandleAddr =
          Builder.CreatePointerBitCastOrAddrSpaceCast(V: EventHandle,
                                                      DestTy: Builder.getPtrTy(AddrSpace: 0));
      // The event handle variable stores the event as a pointer-sized int.
      EventVal = Builder.CreatePtrToInt(V: EventVal, DestTy: Builder.getInt64Ty());
      Builder.CreateStore(Val: EventVal, Ptr: EventHandleAddr);
    }
    // Copy the arguments for outlined function
    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(i: 1);
      Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
      Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
      Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
                           Size: SharedsSize);
    }

    if (Priority) {
      //
      // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
      // we populate the priority information into the "kmp_task_t" here
      //
      // The struct "kmp_task_t" definition is available in kmp.h
      // kmp_task_t = { shareds, routine, part_id, data1, data2 }
      // data2 is used for priority
      //
      Type *Int32Ty = Builder.getInt32Ty();
      Constant *Zero = ConstantInt::get(Ty: Int32Ty, V: 0);
      // kmp_task_t* => { ptr }
      Type *TaskPtr = StructType::get(elt1: VoidPtr);
      Value *TaskGEP =
          Builder.CreateInBoundsGEP(Ty: TaskPtr, Ptr: TaskData, IdxList: {Zero, Zero});
      // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
      Type *TaskStructType = StructType::get(
          elt1: VoidPtr, elts: VoidPtr, elts: Builder.getInt32Ty(), elts: VoidPtr, elts: VoidPtr);
      // Index 4 selects data2, the priority slot (see layout comment above).
      Value *PriorityData = Builder.CreateInBoundsGEP(
          Ty: TaskStructType, Ptr: TaskGEP, IdxList: {Zero, ConstantInt::get(Ty: Int32Ty, V: 4)});
      // kmp_cmplrdata_t => { ptr, ptr }
      Type *CmplrStructType = StructType::get(elt1: VoidPtr, elts: VoidPtr);
      Value *CmplrData = Builder.CreateInBoundsGEP(Ty: CmplrStructType,
                                                   Ptr: PriorityData, IdxList: {Zero, Zero});
      Builder.CreateStore(Val: Priority, Ptr: CmplrData);
    }

    // Build the kmp_depend_info array for the depend clauses, if any.
    Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);

    // In the presence of the `if` clause, the following IR is generated:
    // ...
    // %data = call @__kmpc_omp_task_alloc(...)
    // br i1 %if_condition, label %then, label %else
    // then:
    //   call @__kmpc_omp_task(...)
    //   br label %exit
    // else:
    //   ;; Wait for resolution of dependencies, if any, before
    //   ;; beginning the task
    //   call @__kmpc_omp_wait_deps(...)
    //   call @__kmpc_omp_task_begin_if0(...)
    //   call @outlined_fn(...)
    //   call @__kmpc_omp_task_complete_if0(...)
    //   br label %exit
    // exit:
    // ...
    if (IfCondition) {
      // `SplitBlockAndInsertIfThenElse` requires the block to have a
      // terminator.
      splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
      Instruction *IfTerminator =
          Builder.GetInsertPoint()->getParent()->getTerminator();
      Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
      Builder.SetInsertPoint(IfTerminator);
      SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
                                    ElseTerm: &ElseTI);
      Builder.SetInsertPoint(ElseTI);

      if (Dependencies.size()) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
        createRuntimeFunctionCall(
            Callee: TaskWaitFn,
            Args: {Ident, ThreadID, Builder.getInt32(C: Dependencies.size()), DepArray,
            ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
            ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
      }
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
      createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
      CallInst *CI = nullptr;
      if (HasShareds)
        CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
      else
        CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
      // Spawn code below goes into the `then` branch.
      Builder.SetInsertPoint(ThenTI);
    }

    if (Dependencies.size()) {
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
      createRuntimeFunctionCall(
          Callee: TaskFn,
          Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
          DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
          ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});

    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
      createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
    }

    StaleCI->eraseFromParent();

    // Inside the outlined function, rewire uses of the shareds argument to a
    // load of the shareds pointer from the task data.
    Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
    if (HasShareds) {
      LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
      OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
          New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
    }

    // Erase the placeholder instructions in reverse (use-before-def) order.
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };

  addOutlineInfo(OI: std::move(OI));
  Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());

  return Builder.saveIP();
}
2711
2712OpenMPIRBuilder::InsertPointOrErrorTy
2713OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2714 InsertPointTy AllocaIP,
2715 BodyGenCallbackTy BodyGenCB) {
2716 if (!updateToLocation(Loc))
2717 return InsertPointTy();
2718
2719 uint32_t SrcLocStrSize;
2720 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2721 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2722 Value *ThreadID = getOrCreateThreadID(Ident);
2723
2724 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2725 Function *TaskgroupFn =
2726 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2727 createRuntimeFunctionCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2728
2729 BasicBlock *TaskgroupExitBB = splitBB(Builder, CreateBranch: true, Name: "taskgroup.exit");
2730 if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2731 return Err;
2732
2733 Builder.SetInsertPoint(TaskgroupExitBB);
2734 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2735 Function *EndTaskgroupFn =
2736 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2737 createRuntimeFunctionCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2738
2739 return Builder.saveIP();
2740}
2741
2742OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
2743 const LocationDescription &Loc, InsertPointTy AllocaIP,
2744 ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2745 FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2746 assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2747
2748 if (!updateToLocation(Loc))
2749 return Loc.IP;
2750
2751 FinalizationStack.push_back(Elt: {FiniCB, OMPD_sections, IsCancellable});
2752
2753 // Each section is emitted as a switch case
2754 // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2755 // -> OMP.createSection() which generates the IR for each section
2756 // Iterate through all sections and emit a switch construct:
2757 // switch (IV) {
2758 // case 0:
2759 // <SectionStmt[0]>;
2760 // break;
2761 // ...
2762 // case <NumSection> - 1:
2763 // <SectionStmt[<NumSection> - 1]>;
2764 // break;
2765 // }
2766 // ...
2767 // section_loop.after:
2768 // <FiniCB>;
2769 auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
2770 Builder.restoreIP(IP: CodeGenIP);
2771 BasicBlock *Continue =
2772 splitBBWithSuffix(Builder, /*CreateBranch=*/false, Suffix: ".sections.after");
2773 Function *CurFn = Continue->getParent();
2774 SwitchInst *SwitchStmt = Builder.CreateSwitch(V: IndVar, Dest: Continue);
2775
2776 unsigned CaseNumber = 0;
2777 for (auto SectionCB : SectionCBs) {
2778 BasicBlock *CaseBB = BasicBlock::Create(
2779 Context&: M.getContext(), Name: "omp_section_loop.body.case", Parent: CurFn, InsertBefore: Continue);
2780 SwitchStmt->addCase(OnVal: Builder.getInt32(C: CaseNumber), Dest: CaseBB);
2781 Builder.SetInsertPoint(CaseBB);
2782 BranchInst *CaseEndBr = Builder.CreateBr(Dest: Continue);
2783 if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
2784 CaseEndBr->getIterator()}))
2785 return Err;
2786 CaseNumber++;
2787 }
2788 // remove the existing terminator from body BB since there can be no
2789 // terminators after switch/case
2790 return Error::success();
2791 };
2792 // Loop body ends here
2793 // LowerBound, UpperBound, and STride for createCanonicalLoop
2794 Type *I32Ty = Type::getInt32Ty(C&: M.getContext());
2795 Value *LB = ConstantInt::get(Ty: I32Ty, V: 0);
2796 Value *UB = ConstantInt::get(Ty: I32Ty, V: SectionCBs.size());
2797 Value *ST = ConstantInt::get(Ty: I32Ty, V: 1);
2798 Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
2799 Loc, BodyGenCB: LoopBodyGenCB, Start: LB, Stop: UB, Step: ST, IsSigned: true, InclusiveStop: false, ComputeIP: AllocaIP, Name: "section_loop");
2800 if (!LoopInfo)
2801 return LoopInfo.takeError();
2802
2803 InsertPointOrErrorTy WsloopIP =
2804 applyStaticWorkshareLoop(DL: Loc.DL, CLI: *LoopInfo, AllocaIP,
2805 LoopType: WorksharingLoopType::ForStaticLoop, NeedsBarrier: !IsNowait);
2806 if (!WsloopIP)
2807 return WsloopIP.takeError();
2808 InsertPointTy AfterIP = *WsloopIP;
2809
2810 BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
2811 assert(LoopFini && "Bad structure of static workshare loop finalization");
2812
2813 // Apply the finalization callback in LoopAfterBB
2814 auto FiniInfo = FinalizationStack.pop_back_val();
2815 assert(FiniInfo.DK == OMPD_sections &&
2816 "Unexpected finalization stack state!");
2817 if (Error Err = FiniInfo.mergeFiniBB(Builder, OtherFiniBB: LoopFini))
2818 return Err;
2819
2820 return AfterIP;
2821}
2822
/// Generate an OpenMP `section` region as an inlined region; the enclosing
/// `sections` switch is produced separately by createSections.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Wrap the user's finalization callback so that finalization at a
  // terminator-less block (a cancellation block) first gets a branch back to
  // the sections exit block before FiniCB runs.
  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // If the insert point is not at the block's end, the block still has a
    // terminator and the user callback can run unchanged.
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at cancelation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from cancelation
    // to exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    // NOTE(review): this walk assumes the exact CFG shape emitted by
    // createSections (case BB -> switch BB -> condition BB, exit block being
    // successor 1 of the condition's terminator) — confirm if that changes.
    auto *CaseBB = Loc.IP.getBlock();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
    Instruction *I = Builder.CreateBr(Dest: ExitBB);
    // Re-anchor the insert point at the newly created branch for FiniCB.
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  Directive OMPD = Directive::OMPD_sections;
  // Since we are using Finalization Callback here, HasFinalize
  // and IsCancellable have to be true
  return EmitOMPInlinedRegion(OMPD, EntryCall: nullptr, ExitCall: nullptr, BodyGenCB, FiniCB: FiniCBWrapper,
                              /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true,
                              /*IsCancellable*/ true);
}
2857
2858static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2859 BasicBlock::iterator IT(I);
2860 IT++;
2861 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2862}
2863
2864Value *OpenMPIRBuilder::getGPUThreadID() {
2865 return createRuntimeFunctionCall(
2866 Callee: getOrCreateRuntimeFunction(M,
2867 FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block),
2868 Args: {});
2869}
2870
2871Value *OpenMPIRBuilder::getGPUWarpSize() {
2872 return createRuntimeFunctionCall(
2873 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___kmpc_get_warp_size), Args: {});
2874}
2875
2876Value *OpenMPIRBuilder::getNVPTXWarpID() {
2877 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2878 return Builder.CreateAShr(LHS: getGPUThreadID(), RHS: LaneIDBits, Name: "nvptx_warp_id");
2879}
2880
2881Value *OpenMPIRBuilder::getNVPTXLaneID() {
2882 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2883 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2884 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
2885 return Builder.CreateAnd(LHS: getGPUThreadID(), RHS: Builder.getInt32(C: LaneIDMask),
2886 Name: "nvptx_lane_id");
2887}
2888
2889Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2890 Type *ToType) {
2891 Type *FromType = From->getType();
2892 uint64_t FromSize = M.getDataLayout().getTypeStoreSize(Ty: FromType);
2893 uint64_t ToSize = M.getDataLayout().getTypeStoreSize(Ty: ToType);
2894 assert(FromSize > 0 && "From size must be greater than zero");
2895 assert(ToSize > 0 && "To size must be greater than zero");
2896 if (FromType == ToType)
2897 return From;
2898 if (FromSize == ToSize)
2899 return Builder.CreateBitCast(V: From, DestTy: ToType);
2900 if (ToType->isIntegerTy() && FromType->isIntegerTy())
2901 return Builder.CreateIntCast(V: From, DestTy: ToType, /*isSigned*/ true);
2902 InsertPointTy SaveIP = Builder.saveIP();
2903 Builder.restoreIP(IP: AllocaIP);
2904 Value *CastItem = Builder.CreateAlloca(Ty: ToType);
2905 Builder.restoreIP(IP: SaveIP);
2906
2907 Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2908 V: CastItem, DestTy: Builder.getPtrTy(AddrSpace: 0));
2909 Builder.CreateStore(Val: From, Ptr: ValCastItem);
2910 return Builder.CreateLoad(Ty: ToType, Ptr: CastItem);
2911}
2912
2913Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2914 Value *Element,
2915 Type *ElementType,
2916 Value *Offset) {
2917 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElementType);
2918 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2919
2920 // Cast all types to 32- or 64-bit values before calling shuffle routines.
2921 Type *CastTy = Builder.getIntNTy(N: Size <= 4 ? 32 : 64);
2922 Value *ElemCast = castValueToType(AllocaIP, From: Element, ToType: CastTy);
2923 Value *WarpSize =
2924 Builder.CreateIntCast(V: getGPUWarpSize(), DestTy: Builder.getInt16Ty(), isSigned: true);
2925 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2926 FnID: Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2927 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2928 Value *WarpSizeCast =
2929 Builder.CreateIntCast(V: WarpSize, DestTy: Builder.getInt16Ty(), /*isSigned=*/true);
2930 Value *ShuffleCall =
2931 createRuntimeFunctionCall(Callee: ShuffleFunc, Args: {ElemCast, Offset, WarpSizeCast});
2932 return castValueToType(AllocaIP, From: ShuffleCall, ToType: CastTy);
2933}
2934
/// Copy a reduction element of type \p ElemType from \p SrcAddr to
/// \p DstAddr, shuffling each piece in from a remote lane via
/// createRuntimeShuffleFunction. The element is processed in integer chunks
/// of 8, 4, 2 and 1 bytes (largest first); widths needing more than one
/// chunk are handled with an emitted runtime loop.
/// NOTE(review): \p ReductionArrayTy and \p IsByRefElem are not referenced
/// in this routine — confirm whether they are still needed by callers.
void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
                                      Value *DstAddr, Type *ElemType,
                                      Value *Offset, Type *ReductionArrayTy,
                                      bool IsByRefElem) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElemType);
  // Create the loop over the big sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Ptr/ElemPtr walk forward through the source/destination as chunks are
  // copied; they are rebased (address-space cast) on each chunk width.
  Value *ElemPtr = DstAddr;
  Value *Ptr = SrcAddr;
  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
    // Skip chunk widths larger than what remains to be copied.
    if (Size < IntSize)
      continue;
    Type *IntType = Builder.getIntNTy(N: IntSize * 8);
    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: Ptr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: Ptr->getName() + ".ascast");
    // One-past-the-end of the element, used as the loop bound below.
    Value *SrcAddrGEP =
        Builder.CreateGEP(Ty: ElemType, Ptr: SrcAddr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: ElemPtr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: ElemPtr->getName() + ".ascast");

    Function *CurFunc = Builder.GetInsertBlock()->getParent();
    if ((Size / IntSize) > 1) {
      // Multiple chunks of this width: emit a pre-condition/then/exit loop
      // whose PHIs advance Ptr/ElemPtr until fewer than IntSize bytes remain.
      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: SrcAddrGEP, DestTy: Builder.getPtrTy());
      BasicBlock *PreCondBB =
          BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.pre_cond");
      BasicBlock *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.then");
      BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.exit");
      BasicBlock *CurrentBB = Builder.GetInsertBlock();
      emitBlock(BB: PreCondBB, CurFn: CurFunc);
      PHINode *PhiSrc =
          Builder.CreatePHI(Ty: Ptr->getType(), /*NumReservedValues=*/2);
      PhiSrc->addIncoming(V: Ptr, BB: CurrentBB);
      PHINode *PhiDest =
          Builder.CreatePHI(Ty: ElemPtr->getType(), /*NumReservedValues=*/2);
      PhiDest->addIncoming(V: ElemPtr, BB: CurrentBB);
      Ptr = PhiSrc;
      ElemPtr = PhiDest;
      // Loop while at least IntSize bytes remain before PtrEnd.
      Value *PtrDiff = Builder.CreatePtrDiff(
          ElemTy: Builder.getInt8Ty(), LHS: PtrEnd,
          RHS: Builder.CreatePointerBitCastOrAddrSpaceCast(V: Ptr, DestTy: Builder.getPtrTy()));
      Builder.CreateCondBr(
          Cond: Builder.CreateICmpSGT(LHS: PtrDiff, RHS: Builder.getInt64(C: IntSize - 1)), True: ThenBB,
          False: ExitBB);
      emitBlock(BB: ThenBB, CurFn: CurFunc);
      // Shuffle one chunk in from the remote lane and store it.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP,
          Element: Builder.CreateAlignedLoad(
              Ty: IntType, Ptr, Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType)),
          ElementType: IntType, Offset);
      Builder.CreateAlignedStore(Val: Res, Ptr: ElemPtr,
                                 Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType));
      // Advance both cursors by one chunk and loop back.
      Value *LocalPtr =
          Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      Value *LocalElemPtr =
          Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      PhiSrc->addIncoming(V: LocalPtr, BB: ThenBB);
      PhiDest->addIncoming(V: LocalElemPtr, BB: ThenBB);
      emitBranch(Target: PreCondBB);
      emitBlock(BB: ExitBB, CurFn: CurFunc);
    } else {
      // Exactly one chunk of this width: shuffle, truncate if the element is
      // a narrower integer, store, and advance the cursors.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP, Element: Builder.CreateLoad(Ty: IntType, Ptr), ElementType: IntType, Offset);
      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
                                         Res->getType()->getScalarSizeInBits())
        Res = Builder.CreateTrunc(V: Res, DestTy: ElemType);
      Builder.CreateStore(Val: Res, Ptr: ElemPtr);
      Ptr = Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      ElemPtr =
          Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
    }
    // Bytes still to copy after the chunks of this width have been handled.
    Size = Size % IntSize;
  }
}
3019
/// Copy, element by element, the reduce list at \p SrcBase into the reduce
/// list at \p DestBase.
///
/// \param AllocaIP        Insertion point at which any stack temporaries
///                        (private element copies, by-ref staging buffers)
///                        are allocated.
/// \param Action          RemoteLaneToThread: each destination element is a
///                        freshly allocated private copy filled by shuffling
///                        the value in from a remote lane, and the dest list
///                        pointer is updated to point at that copy.
///                        ThreadCopy: plain element-wise copy between two
///                        lists already owned by the current thread.
/// \param ReductionArrayTy Array type used to index the reduce lists
///                        (an array of pointers, one per reduction element).
/// \param ReductionInfos  Per-element reduction descriptions (element type,
///                        evaluation kind, by-ref metadata/callbacks).
/// \param SrcBase         Pointer to the source reduce list.
/// \param DestBase        Pointer to the destination reduce list.
/// \param IsByRef         Per-element flags; when set, the element is a
///                        by-ref reduction whose list slot holds a
///                        descriptor rather than the data itself. May be
///                        empty, meaning no element is by-ref.
/// \param CopyOptions     Carries the remote-lane offset used when values
///                        are shuffled in from another lane.
/// \returns Error::success(), or any error produced by the by-ref
///          DataPtrPtrGen callbacks / descriptor generation.
Error OpenMPIRBuilder::emitReductionListCopy(
    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
    ArrayRef<bool> IsByRef, CopyOptionsTy CopyOptions) {
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;

  // Iterates, element-by-element, through the source Reduce list and
  // make a copy.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *SrcElementAddr = nullptr;
    AllocaInst *DestAlloca = nullptr;
    Value *DestElementAddr = nullptr;
    Value *DestElementPtrAddr = nullptr;
    // Should we shuffle in an element from a remote lane?
    bool ShuffleInElement = false;
    // Set to true to update the pointer in the dest Reduce list to a
    // newly created element.
    bool UpdateDestListPtr = false;

    // Step 1.1: Get the address for the src element in the Reduce list.
    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
        Ty: ReductionArrayTy, Ptr: SrcBase,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    SrcElementAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrAddr);

    // Step 1.2: Create a temporary to store the element in the destination
    // Reduce list.
    DestElementPtrAddr = Builder.CreateInBoundsGEP(
        Ty: ReductionArrayTy, Ptr: DestBase,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    bool IsByRefElem = (!IsByRef.empty() && IsByRef[En.index()]);
    switch (Action) {
    case CopyAction::RemoteLaneToThread: {
      // Allocate a private destination element at AllocaIP; the shuffled-in
      // remote value will be stored here, and the dest list slot repointed
      // to it below (UpdateDestListPtr).
      InsertPointTy CurIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      // For by-ref elements the private copy holds the whole allocated
      // object (descriptor), not just the element data.
      Type *DestAllocaType =
          IsByRefElem ? RI.ByRefAllocatedType : RI.ElementType;
      DestAlloca = Builder.CreateAlloca(Ty: DestAllocaType, ArraySize: nullptr,
                                        Name: ".omp.reduction.element");
      DestAlloca->setAlignment(
          M.getDataLayout().getPrefTypeAlign(Ty: DestAllocaType));
      DestElementAddr = DestAlloca;
      DestElementAddr =
          Builder.CreateAddrSpaceCast(V: DestElementAddr, DestTy: Builder.getPtrTy(),
                                      Name: DestElementAddr->getName() + ".ascast");
      Builder.restoreIP(IP: CurIP);
      ShuffleInElement = true;
      UpdateDestListPtr = true;
      break;
    }
    case CopyAction::ThreadCopy: {
      // Destination element already exists: load its address from the dest
      // Reduce list and copy into it in place.
      DestElementAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DestElementPtrAddr);
      break;
    }
    }

    // Now that all active lanes have read the element in the
    // Reduce list, shuffle over the value from the remote lane.
    if (ShuffleInElement) {
      Type *ShuffleType = RI.ElementType;
      Value *ShuffleSrcAddr = SrcElementAddr;
      Value *ShuffleDestAddr = DestElementAddr;
      AllocaInst *LocalStorage = nullptr;

      if (IsByRefElem) {
        assert(RI.ByRefElementType && "Expected by-ref element type to be set");
        assert(RI.ByRefAllocatedType &&
               "Expected by-ref allocated type to be set");
        // For by-ref reductions, we need to copy from the remote lane the
        // actual value of the partial reduction computed by that remote lane;
        // rather than, for example, a pointer to that data or, even worse, a
        // pointer to the descriptor of the by-ref reduction element.
        ShuffleType = RI.ByRefElementType;

        // DataPtrPtrGen rewrites ShuffleSrcAddr in place to the address of
        // the descriptor's data-pointer field.
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ShuffleSrcAddr, ShuffleSrcAddr);

        if (!GenResult)
          return GenResult.takeError();

        // Dereference the data-pointer field to reach the actual data.
        ShuffleSrcAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ShuffleSrcAddr);

        {
          // Stage the shuffled-in data in a private buffer; the destination
          // descriptor's base pointer is redirected to it afterwards.
          InsertPointTy OldIP = Builder.saveIP();
          Builder.restoreIP(IP: AllocaIP);

          LocalStorage = Builder.CreateAlloca(Ty: ShuffleType);
          Builder.restoreIP(IP: OldIP);
          ShuffleDestAddr = LocalStorage;
        }
      }

      shuffleAndStore(AllocaIP, SrcAddr: ShuffleSrcAddr, DstAddr: ShuffleDestAddr, ElemType: ShuffleType,
                      Offset: RemoteLaneOffset, ReductionArrayTy, IsByRefElem);

      if (IsByRefElem) {
        // Copy descriptor from source and update base_ptr to shuffled data
        Value *DestDescriptorAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
            V: DestAlloca, DestTy: Builder.getPtrTy(), Name: ".ascast");

        InsertPointOrErrorTy GenResult = generateReductionDescriptor(
            DescriptorAddr: DestDescriptorAddr, DataPtr: LocalStorage, SrcDescriptorAddr: SrcElementAddr,
            DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);

        if (!GenResult)
          return GenResult.takeError();
      }
    } else {
      // Plain same-thread copy; the strategy depends on how the element is
      // evaluated (scalar load/store, component-wise complex copy, memcpy
      // for aggregates).
      switch (RI.EvaluationKind) {
      case EvalKind::Scalar: {
        Value *Elem = Builder.CreateLoad(Ty: RI.ElementType, Ptr: SrcElementAddr);
        // Store the source element value to the dest element address.
        Builder.CreateStore(Val: Elem, Ptr: DestElementAddr);
        break;
      }
      case EvalKind::Complex: {
        // Copy real and imaginary parts separately (struct fields 0 and 1).
        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
        Value *SrcReal = Builder.CreateLoad(
            Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
        Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
        Value *SrcImg = Builder.CreateLoad(
            Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");

        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
        Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
        Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
        Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
        break;
      }
      case EvalKind::Aggregate: {
        Value *SizeVal = Builder.getInt64(
            C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
        Builder.CreateMemCpy(
            Dst: DestElementAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
            Src: SrcElementAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
            Size: SizeVal, isVolatile: false);
        break;
      }
      };
    }

    // Step 3.1: Modify reference in dest Reduce list as needed.
    // Modifying the reference in Reduce list to point to the newly
    // created element. The element is live in the current function
    // scope and that of functions it invokes (i.e., reduce_function).
    // RemoteReduceData[i] = (void*)&RemoteElem
    if (UpdateDestListPtr) {
      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: DestElementAddr, DestTy: Builder.getPtrTy(),
          Name: DestElementAddr->getName() + ".ascast");
      Builder.CreateStore(Val: CastDestAddr, Ptr: DestElementPtrAddr);
    }
  }

  return Error::success();
}
3185
/// Emit the "_omp_reduction_inter_warp_copy_func" helper.
///
/// The generated function transfers, one reduce element at a time, partially
/// reduced values from lane 0 of every active warp into a shared-memory
/// staging array, and then lets the first NumWarps threads of warp 0 read
/// the staged values back into their own reduce lists, so that the final
/// cross-warp step of the reduction can be performed within a single warp.
/// Elements are moved in 4-, then 2-, then 1-byte chunks until the whole
/// element has been transferred; barriers separate the write and read phases
/// of each chunk.
///
/// \param Loc            Source-location information used for the barriers.
/// \param ReductionInfos Per-element reduction descriptions.
/// \param FuncAttrs      Attributes to apply to the generated function.
/// \param IsByRef        Per-element by-ref flags (may be empty); by-ref
///                       elements are transferred via the data pointed to by
///                       the descriptor's base pointer.
/// \returns The generated function, or any error raised by barrier emission
///          or the by-ref DataPtrPtrGen callbacks.
Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  InsertPointTy SavedIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // void fn(void *ReduceList, i32 NumWarps)
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(), Params: {Builder.getPtrTy(), Builder.getInt32Ty()},
      /* IsVarArg */ isVarArg: false);
  Function *WcFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_inter_warp_copy_func", M: &M);
  WcFunc->setAttributes(FuncAttrs);
  WcFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  WcFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: WcFunc);
  Builder.SetInsertPoint(EntryBB);

  // ReduceList: thread local Reduce list.
  // At the stage of the computation when this function is called, partially
  // aggregated values reside in the first lane of every active warp.
  Argument *ReduceListArg = WcFunc->getArg(i: 0);
  // NumWarps: number of warps active in the parallel region. This could
  // be smaller than 32 (max warps in a CTA) for partial block reduction.
  Argument *NumWarpsArg = WcFunc->getArg(i: 1);

  // This array is used as a medium to transfer, one reduce element at a time,
  // the data from the first lane of every warp to lanes in the first warp
  // in order to perform the final step of a reduction in a parallel region
  // (reduction across warps). The array is placed in NVPTX __shared__ memory
  // for reduced latency, as well as to have a distinct copy for concurrently
  // executing target regions. The array is declared with common linkage so
  // as to be shared across compilation units.
  StringRef TransferMediumName =
      "__openmp_nvptx_data_transfer_temporary_storage";
  GlobalVariable *TransferMedium = M.getGlobalVariable(Name: TransferMediumName);
  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
  ArrayType *ArrayTy = ArrayType::get(ElementType: Builder.getInt32Ty(), NumElements: WarpSize);
  if (!TransferMedium) {
    TransferMedium = new GlobalVariable(
        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
        UndefValue::get(T: ArrayTy), TransferMediumName,
        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
        /*AddressSpace=*/3);
  }

  // Get the CUDA thread id of the current OpenMP thread on the GPU.
  Value *GPUThreadID = getGPUThreadID();
  // nvptx_lane_id = nvptx_id % warpsize
  Value *LaneID = getNVPTXLaneID();
  // nvptx_warp_id = nvptx_id / warpsize
  Value *WarpID = getNVPTXWarpID();

  // Spill the two arguments to (address-space-cast) allocas so they can be
  // reloaded as ordinary pointers/ints below.
  InsertPointTy AllocaIP =
      InsertPointTy(Builder.GetInsertBlock(),
                    Builder.GetInsertBlock()->getFirstInsertionPt());
  Type *Arg0Type = ReduceListArg->getType();
  Type *Arg1Type = NumWarpsArg->getType();
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
      Ty: Arg0Type, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  AllocaInst *NumWarpsAlloca =
      Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: NumWarpsArg->getName() + ".addr");
  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListAlloca, DestTy: Arg0Type, Name: ReduceListAlloca->getName() + ".ascast");
  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: NumWarpsAlloca, DestTy: Builder.getPtrTy(AddrSpace: 0),
      Name: NumWarpsAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
  Builder.CreateStore(Val: NumWarpsArg, Ptr: NumWarpsAddrCast);
  // Keep later allocas grouped right after the argument spills.
  AllocaIP = getInsertPointAfterInstr(I: NumWarpsAlloca);
  InsertPointTy CodeGenIP =
      getInsertPointAfterInstr(I: &Builder.GetInsertBlock()->back());
  Builder.restoreIP(IP: CodeGenIP);

  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListAddrCast);

  for (auto En : enumerate(First&: ReductionInfos)) {
    //
    // Warp master copies reduce element to transfer medium in __shared__
    // memory.
    //
    const ReductionInfo &RI = En.value();
    bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
    // For by-ref elements the data transferred is the pointee value, so size
    // by the by-ref element type instead of the list-slot type.
    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(
        Ty: IsByRefElem ? RI.ByRefElementType : RI.ElementType);
    // Transfer the element in progressively smaller chunks: all 4-byte
    // pieces first, then a 2-byte piece, then a 1-byte piece, until the
    // remaining size is zero.
    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
      Type *CType = Builder.getIntNTy(N: TySize * 8);

      unsigned NumIters = RealTySize / TySize;
      if (NumIters == 0)
        continue;
      Value *Cnt = nullptr;
      Value *CntAddr = nullptr;
      BasicBlock *PrecondBB = nullptr;
      BasicBlock *ExitBB = nullptr;
      if (NumIters > 1) {
        // More than one chunk of this size: emit a counted loop
        // (precond/body/exit) driven by a stack counter.
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(IP: AllocaIP);
        CntAddr =
            Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr, Name: ".cnt.addr");

        CntAddr = Builder.CreateAddrSpaceCast(V: CntAddr, DestTy: Builder.getPtrTy(),
                                              Name: CntAddr->getName() + ".ascast");
        Builder.restoreIP(IP: CodeGenIP);
        Builder.CreateStore(Val: Constant::getNullValue(Ty: Builder.getInt32Ty()),
                            Ptr: CntAddr,
                            /*Volatile=*/isVolatile: false);
        PrecondBB = BasicBlock::Create(Context&: Ctx, Name: "precond");
        ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit");
        BasicBlock *BodyBB = BasicBlock::Create(Context&: Ctx, Name: "body");
        emitBlock(BB: PrecondBB, CurFn: Builder.GetInsertBlock()->getParent());
        Cnt = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: CntAddr,
                                 /*Volatile=*/isVolatile: false);
        Value *Cmp = Builder.CreateICmpULT(
            LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), V: NumIters));
        Builder.CreateCondBr(Cond: Cmp, True: BodyBB, False: ExitBB);
        emitBlock(BB: BodyBB, CurFn: Builder.GetInsertBlock()->getParent());
      }

      // kmpc_barrier.
      // Ensure all lanes have produced their partial values before warp
      // masters publish them to the transfer medium.
      InsertPointOrErrorTy BarrierIP1 =
          createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        Kind: omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP1)
        return BarrierIP1.takeError();
      BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
      BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
      BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

      // if (lane_id == 0)
      Value *IsWarpMaster = Builder.CreateIsNull(Arg: LaneID, Name: "warp_master");
      Builder.CreateCondBr(Cond: IsWarpMaster, True: ThenBB, False: ElseBB);
      emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());

      // Reduce element = LocalReduceList[i]
      auto *RedListArrayTy =
          ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
      Type *IndexTy = Builder.getIndexTy(
          DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
      Value *ElemPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      // elemptr = ((CopyType*)(elemptrptr)) + I
      Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);

      if (IsByRefElem) {
        // Follow the descriptor's data-pointer field to the actual data.
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
      }

      if (NumIters > 1)
        ElemPtr = Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: ElemPtr, IdxList: Cnt);

      // Get pointer to location in transfer medium.
      // MediumPtr = &medium[warp_id]
      Value *MediumPtr = Builder.CreateInBoundsGEP(
          Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), WarpID});
      // elem = *elemptr
      //*MediumPtr = elem
      Value *Elem = Builder.CreateLoad(Ty: CType, Ptr: ElemPtr);
      // Store the source element value to the dest element address.
      // Volatile so the cross-thread store is not optimized away.
      Builder.CreateStore(Val: Elem, Ptr: MediumPtr,
                          /*IsVolatile*/ isVolatile: true);
      Builder.CreateBr(Dest: MergeBB);

      // else
      emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(Dest: MergeBB);

      // endif
      emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
      // Second barrier: make the staged values visible before warp 0 reads
      // them back.
      InsertPointOrErrorTy BarrierIP2 =
          createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        Kind: omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP2)
        return BarrierIP2.takeError();

      // Warp 0 copies reduce element from transfer medium
      BasicBlock *W0ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
      BasicBlock *W0ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
      BasicBlock *W0MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

      Value *NumWarpsVal =
          Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: NumWarpsAddrCast);
      // Up to 32 threads in warp 0 are active.
      Value *IsActiveThread =
          Builder.CreateICmpULT(LHS: GPUThreadID, RHS: NumWarpsVal, Name: "is_active_thread");
      Builder.CreateCondBr(Cond: IsActiveThread, True: W0ThenBB, False: W0ElseBB);

      emitBlock(BB: W0ThenBB, CurFn: Builder.GetInsertBlock()->getParent());

      // SecMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
          Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), GPUThreadID});
      // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
      Value *TargetElemPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *TargetElemPtrVal =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtrPtr);
      Value *TargetElemPtr = TargetElemPtrVal;

      if (IsByRefElem) {
        // Same indirection as on the write side: store into the data the
        // descriptor's base pointer refers to.
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), TargetElemPtr, TargetElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        TargetElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtr);
      }

      if (NumIters > 1)
        TargetElemPtr =
            Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: TargetElemPtr, IdxList: Cnt);

      // *TargetElemPtr = SrcMediumVal;
      Value *SrcMediumValue =
          Builder.CreateLoad(Ty: CType, Ptr: SrcMediumPtrVal, /*IsVolatile*/ isVolatile: true);
      Builder.CreateStore(Val: SrcMediumValue, Ptr: TargetElemPtr);
      Builder.CreateBr(Dest: W0MergeBB);

      emitBlock(BB: W0ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(Dest: W0MergeBB);

      emitBlock(BB: W0MergeBB, CurFn: Builder.GetInsertBlock()->getParent());

      if (NumIters > 1) {
        // Increment the chunk counter and loop back to the precondition.
        Cnt = Builder.CreateNSWAdd(
            LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), /*V=*/1));
        Builder.CreateStore(Val: Cnt, Ptr: CntAddr, /*Volatile=*/isVolatile: false);

        auto *CurFn = Builder.GetInsertBlock()->getParent();
        emitBranch(Target: PrecondBB);
        emitBlock(BB: ExitBB, CurFn);
      }
      RealTySize %= TySize;
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: SavedIP);

  return WcFunc;
}
3444
/// Emit the "_omp_reduction_shuffle_and_reduce_func" helper.
///
/// The generated function copies, element by element, the reduce list of a
/// remote lane (at \p RemoteLaneOffset from the current lane) into a private
/// remote reduce list, conditionally reduces it into the local list by
/// calling \p ReduceFn (the condition depends on the algorithm version and
/// the lane id — see the in-body comment), and, for algorithm version 1,
/// copies the remote list back over the local one for lanes past the offset.
///
/// Signature of the generated function:
///   void fn(void *ReduceList, i16 LaneId, i16 RemoteLaneOffset, i16 AlgoVer)
///
/// \param ReductionInfos Per-element reduction descriptions.
/// \param ReduceFn       Function performing the pairwise reduction of two
///                       reduce lists.
/// \param FuncAttrs      Attributes to apply to the generated function.
/// \param IsByRef        Per-element by-ref flags (may be empty).
/// \returns The generated function, or any error raised while emitting the
///          reduction-list copies.
Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy =
      FunctionType::get(Result: Builder.getVoidTy(),
                        Params: {Builder.getPtrTy(), Builder.getInt16Ty(),
                         Builder.getInt16Ty(), Builder.getInt16Ty()},
                        /* IsVarArg */ isVarArg: false);
  Function *SarFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_shuffle_and_reduce_func", M: &M);
  SarFunc->setAttributes(FuncAttrs);
  SarFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
  // The i16 arguments are sign-extended per the target ABI.
  SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::SExt);
  SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::SExt);
  SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::SExt);
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: SarFunc);
  Builder.SetInsertPoint(EntryBB);

  // Thread local Reduce list used to host the values of data to be reduced.
  Argument *ReduceListArg = SarFunc->getArg(i: 0);
  // Current lane id; could be logical.
  Argument *LaneIDArg = SarFunc->getArg(i: 1);
  // Offset of the remote source lane relative to the current lane.
  Argument *RemoteLaneOffsetArg = SarFunc->getArg(i: 2);
  // Algorithm version. This is expected to be known at compile time.
  Argument *AlgoVerArg = SarFunc->getArg(i: 3);

  // Spill all four arguments to allocas (address-space cast to generic) and
  // reload them, matching the usual device codegen pattern.
  Type *ReduceListArgType = ReduceListArg->getType();
  Type *LaneIDArgType = LaneIDArg->getType();
  Type *LaneIDArgPtrType = Builder.getPtrTy(AddrSpace: 0);
  Value *ReduceListAlloca = Builder.CreateAlloca(
      Ty: ReduceListArgType, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  Value *LaneIdAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
                                             Name: LaneIDArg->getName() + ".addr");
  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
      Ty: LaneIDArgType, ArraySize: nullptr, Name: RemoteLaneOffsetArg->getName() + ".addr");
  Value *AlgoVerAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
                                              Name: AlgoVerArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // Create a local thread-private variable to host the Reduce list
  // from a remote lane.
  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
      Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.remote_reduce_list");

  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListAlloca, DestTy: ReduceListArgType,
      Name: ReduceListAlloca->getName() + ".ascast");
  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LaneIdAlloca, DestTy: LaneIDArgPtrType, Name: LaneIdAlloca->getName() + ".ascast");
  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteLaneOffsetAlloca, DestTy: LaneIDArgPtrType,
      Name: RemoteLaneOffsetAlloca->getName() + ".ascast");
  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: AlgoVerAlloca, DestTy: LaneIDArgPtrType, Name: AlgoVerAlloca->getName() + ".ascast");
  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteReductionListAlloca, DestTy: Builder.getPtrTy(),
      Name: RemoteReductionListAlloca->getName() + ".ascast");

  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
  Builder.CreateStore(Val: LaneIDArg, Ptr: LaneIdAddrCast);
  Builder.CreateStore(Val: RemoteLaneOffsetArg, Ptr: RemoteLaneOffsetAddrCast);
  Builder.CreateStore(Val: AlgoVerArg, Ptr: AlgoVerAddrCast);

  Value *ReduceList = Builder.CreateLoad(Ty: ReduceListArgType, Ptr: ReduceListAddrCast);
  Value *LaneId = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: LaneIdAddrCast);
  Value *RemoteLaneOffset =
      Builder.CreateLoad(Ty: LaneIDArgType, Ptr: RemoteLaneOffsetAddrCast);
  Value *AlgoVer = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: AlgoVerAddrCast);

  InsertPointTy AllocaIP = getInsertPointAfterInstr(I: RemoteReductionListAlloca);

  // This loop iterates through the list of reduce elements and copies,
  // element by element, from a remote lane in the warp to RemoteReduceList,
  // hosted on the thread's stack.
  Error EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, Action: CopyAction::RemoteLaneToThread, ReductionArrayTy: RedListArrayTy, ReductionInfos,
      SrcBase: ReduceList, DestBase: RemoteListAddrCast, IsByRef,
      CopyOptions: {.RemoteLaneOffset: RemoteLaneOffset, .ScratchpadIndex: nullptr, .ScratchpadWidth: nullptr});

  // Note: an Error converts to true when it holds a failure.
  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  // The actions to be performed on the Remote Reduce list is dependent
  // on the algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread local variable Reduce list is mutated in place to host the
  // reduced data, which is the aggregated value produced from local and
  // remote lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first conjunction evaluates to true, making
  // the entire predicate true during compile time.
  // When AlgoVer==1, the second conjunction has only the second part to be
  // evaluated during runtime. Other conjunctions evaluates to false
  // during compile time.
  // When AlgoVer==2, the third conjunction has only the second part to be
  // evaluated during runtime. Other conjunctions evaluates to false
  // during compile time.
  Value *CondAlgo0 = Builder.CreateIsNull(Arg: AlgoVer);
  Value *Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
  Value *LaneComp = Builder.CreateICmpULT(LHS: LaneId, RHS: RemoteLaneOffset);
  Value *CondAlgo1 = Builder.CreateAnd(LHS: Algo1, RHS: LaneComp);
  Value *Algo2 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 2));
  // LaneId % 2 == 0, computed as (LaneId & 1) == 0.
  Value *LaneIdAnd1 = Builder.CreateAnd(LHS: LaneId, RHS: Builder.getInt16(C: 1));
  Value *LaneIdComp = Builder.CreateIsNull(Arg: LaneIdAnd1);
  Value *Algo2AndLaneIdComp = Builder.CreateAnd(LHS: Algo2, RHS: LaneIdComp);
  Value *RemoteOffsetComp =
      Builder.CreateICmpSGT(LHS: RemoteLaneOffset, RHS: Builder.getInt16(C: 0));
  Value *CondAlgo2 = Builder.CreateAnd(LHS: Algo2AndLaneIdComp, RHS: RemoteOffsetComp);
  Value *CA0OrCA1 = Builder.CreateOr(LHS: CondAlgo0, RHS: CondAlgo1);
  Value *CondReduce = Builder.CreateOr(LHS: CA0OrCA1, RHS: CondAlgo2);

  BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
  BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
  BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

  Builder.CreateCondBr(Cond: CondReduce, True: ThenBB, False: ElseBB);
  emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
  // ReduceFn(LocalReduceList, RemoteReduceList) aggregates the remote
  // partial values into the local list in place.
  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceList, DestTy: Builder.getPtrTy());
  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteListAddrCast, DestTy: Builder.getPtrTy());
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListPtr, RemoteReduceListPtr})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateBr(Dest: MergeBB);

  emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(Dest: MergeBB);

  emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());

  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
  // Reduce list.
  Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LHS: LaneId, RHS: RemoteLaneOffset);
  Value *CondCopy = Builder.CreateAnd(LHS: Algo1, RHS: LaneIdGtOffset);

  BasicBlock *CpyThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
  BasicBlock *CpyElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
  BasicBlock *CpyMergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
  Builder.CreateCondBr(Cond: CondCopy, True: CpyThenBB, False: CpyElseBB);

  emitBlock(BB: CpyThenBB, CurFn: Builder.GetInsertBlock()->getParent());

  EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, Action: CopyAction::ThreadCopy, ReductionArrayTy: RedListArrayTy, ReductionInfos,
      SrcBase: RemoteListAddrCast, DestBase: ReduceList, IsByRef);

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  Builder.CreateBr(Dest: CpyMergeBB);

  emitBlock(BB: CpyElseBB, CurFn: Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(Dest: CpyMergeBB);

  emitBlock(BB: CpyMergeBB, CurFn: Builder.GetInsertBlock()->getParent());

  Builder.CreateRetVoid();

  return SarFunc;
}
3619
/// Build a by-ref reduction descriptor at \p DescriptorAddr that mirrors the
/// descriptor at \p SrcDescriptorAddr but whose base (data) pointer refers
/// to \p DataPtr instead.
///
/// \param DescriptorAddr    Destination storage for the new descriptor.
/// \param DataPtr           Data the new descriptor's base pointer should
///                          reference (e.g. locally staged shuffled data).
/// \param SrcDescriptorAddr Descriptor to clone the metadata from.
/// \param DescriptorType    IR type of the descriptor; determines the number
///                          of bytes copied and the alignment used.
/// \param DataPtrPtrGen     Callback that, given the descriptor's address,
///                          yields the address of its data-pointer field.
/// \returns The insertion point after the emitted code, or the callback's
///          error.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::generateReductionDescriptor(
    Value *DescriptorAddr, Value *DataPtr, Value *SrcDescriptorAddr,
    Type *DescriptorType,
    function_ref<InsertPointOrErrorTy(InsertPointTy, Value *, Value *&)>
        DataPtrPtrGen) {

  // Copy the source descriptor to preserve all metadata (rank, extents,
  // strides, etc.)
  Value *DescriptorSize =
      Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: DescriptorType));
  Builder.CreateMemCpy(
      Dst: DescriptorAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
      Src: SrcDescriptorAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
      Size: DescriptorSize);

  // Update the base pointer field to point to the local shuffled data
  Value *DataPtrField;
  InsertPointOrErrorTy GenResult =
      DataPtrPtrGen(Builder.saveIP(), DescriptorAddr, DataPtrField);

  if (!GenResult)
    return GenResult.takeError();

  // Cast to a generic pointer before storing so the field's address space
  // matches regardless of where DataPtr was allocated.
  Builder.CreateStore(Val: Builder.CreatePointerBitCastOrAddrSpaceCast(
                          V: DataPtr, DestTy: Builder.getPtrTy(), Name: ".ascast"),
                      Ptr: DataPtrField);

  return Builder.saveIP();
}
3650
3651Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
3652 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3653 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3654 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3655 LLVMContext &Ctx = M.getContext();
3656 FunctionType *FuncTy = FunctionType::get(
3657 Result: Builder.getVoidTy(),
3658 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3659 /* IsVarArg */ isVarArg: false);
3660 Function *LtGCFunc =
3661 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3662 N: "_omp_reduction_list_to_global_copy_func", M: &M);
3663 LtGCFunc->setAttributes(FuncAttrs);
3664 LtGCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3665 LtGCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3666 LtGCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3667
3668 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGCFunc);
3669 Builder.SetInsertPoint(EntryBlock);
3670
3671 // Buffer: global reduction buffer.
3672 Argument *BufferArg = LtGCFunc->getArg(i: 0);
3673 // Idx: index of the buffer.
3674 Argument *IdxArg = LtGCFunc->getArg(i: 1);
3675 // ReduceList: thread local Reduce list.
3676 Argument *ReduceListArg = LtGCFunc->getArg(i: 2);
3677
3678 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3679 Name: BufferArg->getName() + ".addr");
3680 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3681 Name: IdxArg->getName() + ".addr");
3682 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3683 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3684 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3685 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3686 Name: BufferArgAlloca->getName() + ".ascast");
3687 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3688 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3689 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3690 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3691 Name: ReduceListArgAlloca->getName() + ".ascast");
3692
3693 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3694 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3695 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3696
3697 Value *LocalReduceList =
3698 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3699 Value *BufferArgVal =
3700 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3701 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3702 Type *IndexTy = Builder.getIndexTy(
3703 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3704 for (auto En : enumerate(First&: ReductionInfos)) {
3705 const ReductionInfo &RI = En.value();
3706 auto *RedListArrayTy =
3707 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3708 // Reduce element = LocalReduceList[i]
3709 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3710 Ty: RedListArrayTy, Ptr: LocalReduceList,
3711 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3712 // elemptr = ((CopyType*)(elemptrptr)) + I
3713 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3714
3715 // Global = Buffer.VD[Idx];
3716 Value *BufferVD =
3717 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferArgVal, IdxList: Idxs);
3718 Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3719 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3720
3721 switch (RI.EvaluationKind) {
3722 case EvalKind::Scalar: {
3723 Value *TargetElement;
3724
3725 if (IsByRef.empty() || !IsByRef[En.index()]) {
3726 TargetElement = Builder.CreateLoad(Ty: RI.ElementType, Ptr: ElemPtr);
3727 } else {
3728 InsertPointOrErrorTy GenResult =
3729 RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
3730
3731 if (!GenResult)
3732 return GenResult.takeError();
3733
3734 ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
3735 TargetElement = Builder.CreateLoad(Ty: RI.ByRefElementType, Ptr: ElemPtr);
3736 }
3737
3738 Builder.CreateStore(Val: TargetElement, Ptr: GlobVal);
3739 break;
3740 }
3741 case EvalKind::Complex: {
3742 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3743 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
3744 Value *SrcReal = Builder.CreateLoad(
3745 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
3746 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3747 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
3748 Value *SrcImg = Builder.CreateLoad(
3749 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
3750
3751 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3752 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 0, Name: ".realp");
3753 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3754 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 1, Name: ".imagp");
3755 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
3756 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
3757 break;
3758 }
3759 case EvalKind::Aggregate: {
3760 Value *SizeVal =
3761 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
3762 Builder.CreateMemCpy(
3763 Dst: GlobVal, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Src: ElemPtr,
3764 SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Size: SizeVal, isVolatile: false);
3765 break;
3766 }
3767 }
3768 }
3769
3770 Builder.CreateRetVoid();
3771 Builder.restoreIP(IP: OldIP);
3772 return LtGCFunc;
3773}
3774
/// Emit the `_omp_reduction_list_to_global_reduce_func(Buffer, Idx,
/// ReduceList)` helper used for teams reductions: it builds a local list of
/// pointers into element `Idx` of the global reduction buffer (or, for
/// by-ref reductions, into freshly built descriptors pointing there) and
/// then calls `ReduceFn(GlobalReduceList, ReduceList)` so the thread-local
/// values are reduced into the global buffer.
Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // All IR below is emitted into the new helper; the caller's insertion
  // point is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // Signature: void(ptr Buffer, i32 Idx, ptr ReduceList).
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *LtGRFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_list_to_global_reduce_func", M: &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  LtGRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  LtGRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(i: 2);

  // Spill each argument into a stack slot so it can be reloaded below.
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  auto *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");

  // By-ref descriptor allocas (in the loop below) are emitted at the start
  // of the entry block so they remain static allocations.
  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  // Address-space-cast every stack slot to the generic pointer type.
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LocalReduceList, DestTy: Builder.getPtrTy(),
      Name: LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Populate LocalReduceList[i] with a pointer to field i of Buffer[Idx],
  // or with a descriptor whose data pointer refers into the global buffer
  // when the reduction variable is by-ref.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Emit the descriptor alloca in the entry block, then resume emitting
      // at the current point.
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(IP: OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: LocalReduceListAddrCast,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    // Global = Buffer.VD[Idx];
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Get source descriptor from the reduce list argument
      Value *ReduceList =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
      Value *SrcElementPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *SrcDescriptorAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);

      // Copy descriptor from source and update base_ptr to global buffer data
      InsertPointOrErrorTy GenResult =
          generateReductionDescriptor(DescriptorAddr: ByRefAlloc, DataPtr: GlobValPtr, SrcDescriptorAddr,
                                      DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);

      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
    } else {
      Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
    }
  }

  // Call reduce_function(GlobalReduceList, ReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListAddrCast, ReduceList})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return LtGRFunc;
}
3896
/// Emit the `_omp_reduction_global_to_list_copy_func(Buffer, Idx,
/// ReduceList)` helper used for teams reductions: it copies each reduction
/// element of global buffer slot `Buffer[Idx]` back into the thread-local
/// reduce list, handling scalar, complex, and aggregate elements according
/// to each ReductionInfo's EvaluationKind.
Expected<Function *> OpenMPIRBuilder::emitGlobalToListCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // All IR below is emitted into the new helper; the caller's insertion
  // point is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // Signature: void(ptr Buffer, i32 Idx, ptr ReduceList).
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *GtLCFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_global_to_list_copy_func", M: &M);
  GtLCFunc->setAttributes(FuncAttrs);
  GtLCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  GtLCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  GtLCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLCFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLCFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLCFunc->getArg(i: 2);

  // Spill each argument into a stack slot and address-space-cast the slots
  // to the generic pointer type.
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Copy each element of Buffer[Idx] into the corresponding thread-local
  // reduce-list slot.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: LocalReduceList,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Type *ElemType = RI.ElementType;

      if (!IsByRef.empty() && IsByRef[En.index()]) {
        // By-ref element: follow the descriptor's data pointer to reach the
        // actual storage before copying.
        ElemType = RI.ByRefElementType;
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenResult)
          return GenResult.takeError();

        ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
      }

      Value *TargetElement = Builder.CreateLoad(Ty: ElemType, Ptr: GlobValPtr);
      Builder.CreateStore(Val: TargetElement, Ptr: ElemPtr);
      break;
    }
    case EvalKind::Complex: {
      // Copy the real and imaginary parts individually.
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 0, Name: ".realp");
      Value *SrcReal = Builder.CreateLoad(
          Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
      Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
      Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      // Byte-wise copy of the whole aggregate.
      Value *SizeVal =
          Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
      Builder.CreateMemCpy(
          Dst: ElemPtr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
          Src: GlobValPtr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
          Size: SizeVal, isVolatile: false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return GtLCFunc;
}
4017
/// Emit the `_omp_reduction_global_to_list_reduce_func(Buffer, Idx,
/// ReduceList)` helper used for teams reductions: it builds a local list of
/// pointers into global buffer slot `Buffer[Idx]` (or into by-ref
/// descriptors pointing there) and calls
/// `ReduceFn(ReduceList, GlobalReduceList)` to reduce the global values
/// into the thread-local list. Mirror of emitListToGlobalReduceFunction
/// with the reduce-function operands swapped.
Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // All IR below is emitted into the new helper; the caller's insertion
  // point is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // Signature: void(ptr Buffer, i32 Idx, ptr ReduceList).
  auto *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *GtLRFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_global_to_list_reduce_func", M: &M);
  GtLRFunc->setAttributes(FuncAttrs);
  GtLRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  GtLRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  GtLRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLRFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLRFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLRFunc->getArg(i: 2);

  // Spill each argument into a stack slot so it can be reloaded below.
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");

  // By-ref descriptor allocas (in the loop below) are emitted at the start
  // of the entry block so they remain static allocations.
  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  // Address-space-cast every stack slot to the generic pointer type.
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LocalReduceList, DestTy: Builder.getPtrTy(),
      Name: LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Populate ReductionList[i] with a pointer to field i of Buffer[Idx], or
  // with a descriptor whose data pointer refers into the global buffer when
  // the reduction variable is by-ref.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Emit the descriptor alloca in the entry block, then resume emitting
      // at the current point.
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(IP: OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: ReductionList,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Get source descriptor from the reduce list
      Value *ReduceListVal =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
      Value *SrcElementPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceListVal,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *SrcDescriptorAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);

      // Copy descriptor from source and update base_ptr to global buffer data
      InsertPointOrErrorTy GenResult =
          generateReductionDescriptor(DescriptorAddr: ByRefAlloc, DataPtr: GlobValPtr, SrcDescriptorAddr,
                                      DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);
      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
    } else {
      Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
    }
  }

  // Call reduce_function(ReduceList, GlobalReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {ReduceList, ReductionList})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return GtLRFunc;
}
4138
4139std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
4140 std::string Suffix =
4141 createPlatformSpecificName(Parts: {"omp", "reduction", "reduction_func"});
4142 return (Name + Suffix).str();
4143}
4144
/// Create the interprocedural reduction function
/// `void <Name suffix>(ptr LHSArray, ptr RHSArray)` that reduces each
/// element of the RHS pointer array into the corresponding LHS element,
/// using either the Clang-style fixup callback (ReductionGenClang) or the
/// MLIR-style per-element callback (ReductionGen) from each ReductionInfo.
Expected<Function *> OpenMPIRBuilder::createReductionFunction(
    StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
    ArrayRef<bool> IsByRef, ReductionGenCBKind ReductionGenCBKind,
    AttributeList FuncAttrs) {
  // Signature: void(ptr LHSArray, ptr RHSArray).
  auto *FuncTy = FunctionType::get(Result: Builder.getVoidTy(),
                                   Params: {Builder.getPtrTy(), Builder.getPtrTy()},
                                   /* IsVarArg */ isVarArg: false);
  std::string Name = getReductionFuncName(Name: ReducerName);
  Function *ReductionFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage, N: Name, M: &M);
  ReductionFunc->setAttributes(FuncAttrs);
  ReductionFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  ReductionFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  BasicBlock *EntryBB =
      BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: ReductionFunc);
  Builder.SetInsertPoint(EntryBB);

  // Need to alloca memory here and deal with the pointers before getting
  // LHS/RHS pointers out
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  Argument *Arg0 = ReductionFunc->getArg(i: 0);
  Argument *Arg1 = ReductionFunc->getArg(i: 1);
  Type *Arg0Type = Arg0->getType();
  Type *Arg1Type = Arg1->getType();

  // Spill both arguments to stack slots, cast the slots to the argument
  // pointer type, and reload the array pointers.
  Value *LHSAlloca =
      Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
  Value *RHSAlloca =
      Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
  Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LHSAlloca, DestTy: Arg0Type, Name: LHSAlloca->getName() + ".ascast");
  Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RHSAlloca, DestTy: Arg1Type, Name: RHSAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
  Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
  LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
  RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);

  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  SmallVector<Value *> LHSPtrs, RHSPtrs;
  // For each reduction element, pull the LHS/RHS pointers out of the arrays
  // and either collect them (Clang mode, reduced later via the fixup loop
  // below) or reduce them immediately via RI.ReductionGen.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedArrayTy, Ptr: RHSArrayPtr,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType(),
        Name: RHSI8Ptr->getName() + ".ascast");

    Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedArrayTy, Ptr: LHSArrayPtr,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: LHSI8Ptr, DestTy: RI.Variable->getType(), Name: LHSI8Ptr->getName() + ".ascast");

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      LHSPtrs.emplace_back(Args&: LHSPtr);
      RHSPtrs.emplace_back(Args&: RHSPtr);
    } else {
      Value *LHS = LHSPtr;
      Value *RHS = RHSPtr;

      // By-value reductions operate on loaded values; by-ref ones are
      // handed the pointers directly.
      if (!IsByRef.empty() && !IsByRef[En.index()]) {
        LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
        RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
      }

      Value *Reduced;
      InsertPointOrErrorTy AfterIP =
          RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
      if (!AfterIP)
        return AfterIP.takeError();
      // The callback may clear the insertion point to signal that nothing
      // more should be emitted; return the function as-is in that case.
      if (!Builder.GetInsertBlock())
        return ReductionFunc;

      Builder.restoreIP(IP: *AfterIP);

      if (!IsByRef.empty() && !IsByRef[En.index()])
        Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
    }
  }

  // Clang mode: run the per-element codegen callbacks now, then rewrite
  // their placeholder LHS/RHS pointers (only within this function) to the
  // pointers computed above.
  if (ReductionGenCBKind == ReductionGenCBKind::Clang)
    for (auto En : enumerate(First&: ReductionInfos)) {
      unsigned Index = En.index();
      const ReductionInfo &RI = En.value();
      Value *LHSFixupPtr, *RHSFixupPtr;
      Builder.restoreIP(IP: RI.ReductionGenClang(
          Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));

      // Fix the CallBack code genereated to use the correct Values for the LHS
      // and RHS
      LHSFixupPtr->replaceUsesWithIf(
          New: LHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
            return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
      RHSFixupPtr->replaceUsesWithIf(
          New: RHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
            return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
    }

  Builder.CreateRetVoid();
  // Compiling with `-O0`, `alloca`s emitted in non-entry blocks are not hoisted
  // to the entry block (this is done for higher opt levels by later passes in
  // the pipeline). This has caused issues because non-entry `alloca`s force the
  // function to use dynamic stack allocations and we might run out of scratch
  // memory.
  hoistNonEntryAllocasToEntryBlock(Func: ReductionFunc);

  return ReductionFunc;
}
4264
4265static void
4266checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4267 bool IsGPU) {
4268 for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
4269 (void)RI;
4270 assert(RI.Variable && "expected non-null variable");
4271 assert(RI.PrivateVariable && "expected non-null private variable");
4272 assert((RI.ReductionGen || RI.ReductionGenClang) &&
4273 "expected non-null reduction generator callback");
4274 if (!IsGPU) {
4275 assert(
4276 RI.Variable->getType() == RI.PrivateVariable->getType() &&
4277 "expected variables and their private equivalents to have the same "
4278 "type");
4279 }
4280 assert(RI.Variable->getType()->isPointerTy() &&
4281 "expected variables to be pointers");
4282 }
4283}
4284
4285OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
4286 const LocationDescription &Loc, InsertPointTy AllocaIP,
4287 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
4288 ArrayRef<bool> IsByRef, bool IsNoWait, bool IsTeamsReduction,
4289 ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
4290 unsigned ReductionBufNum, Value *SrcLocInfo) {
4291 if (!updateToLocation(Loc))
4292 return InsertPointTy();
4293 Builder.restoreIP(IP: CodeGenIP);
4294 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
4295 LLVMContext &Ctx = M.getContext();
4296
4297 // Source location for the ident struct
4298 if (!SrcLocInfo) {
4299 uint32_t SrcLocStrSize;
4300 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4301 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4302 }
4303
4304 if (ReductionInfos.size() == 0)
4305 return Builder.saveIP();
4306
4307 BasicBlock *ContinuationBlock = nullptr;
4308 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
4309 // Copied code from createReductions
4310 BasicBlock *InsertBlock = Loc.IP.getBlock();
4311 ContinuationBlock =
4312 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4313 InsertBlock->getTerminator()->eraseFromParent();
4314 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4315 }
4316
4317 Function *CurFunc = Builder.GetInsertBlock()->getParent();
4318 AttributeList FuncAttrs;
4319 AttrBuilder AttrBldr(Ctx);
4320 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
4321 AttrBldr.addAttribute(A: Attr);
4322 AttrBldr.removeAttribute(Val: Attribute::OptimizeNone);
4323 FuncAttrs = FuncAttrs.addFnAttributes(C&: Ctx, B: AttrBldr);
4324
4325 CodeGenIP = Builder.saveIP();
4326 Expected<Function *> ReductionResult = createReductionFunction(
4327 ReducerName: Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos, IsByRef,
4328 ReductionGenCBKind, FuncAttrs);
4329 if (!ReductionResult)
4330 return ReductionResult.takeError();
4331 Function *ReductionFunc = *ReductionResult;
4332 Builder.restoreIP(IP: CodeGenIP);
4333
4334 // Set the grid value in the config needed for lowering later on
4335 if (GridValue.has_value())
4336 Config.setGridValue(GridValue.value());
4337 else
4338 Config.setGridValue(getGridValue(T, Kernel: ReductionFunc));
4339
4340 // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
4341 // RedList, shuffle_reduce_func, interwarp_copy_func);
4342 // or
4343 // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
4344 Value *Res;
4345
4346 // 1. Build a list of reduction variables.
4347 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4348 auto Size = ReductionInfos.size();
4349 Type *PtrTy = PointerType::get(C&: Ctx, AddressSpace: Config.getDefaultTargetAS());
4350 Type *FuncPtrTy =
4351 Builder.getPtrTy(AddrSpace: M.getDataLayout().getProgramAddressSpace());
4352 Type *RedArrayTy = ArrayType::get(ElementType: PtrTy, NumElements: Size);
4353 CodeGenIP = Builder.saveIP();
4354 Builder.restoreIP(IP: AllocaIP);
4355 Value *ReductionListAlloca =
4356 Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4357 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
4358 V: ReductionListAlloca, DestTy: PtrTy, Name: ReductionListAlloca->getName() + ".ascast");
4359 Builder.restoreIP(IP: CodeGenIP);
4360 Type *IndexTy = Builder.getIndexTy(
4361 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4362 for (auto En : enumerate(First&: ReductionInfos)) {
4363 const ReductionInfo &RI = En.value();
4364 Value *ElemPtr = Builder.CreateInBoundsGEP(
4365 Ty: RedArrayTy, Ptr: ReductionList,
4366 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4367
4368 Value *PrivateVar = RI.PrivateVariable;
4369 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
4370 if (IsByRefElem)
4371 PrivateVar = Builder.CreateLoad(Ty: RI.ElementType, Ptr: PrivateVar);
4372
4373 Value *CastElem =
4374 Builder.CreatePointerBitCastOrAddrSpaceCast(V: PrivateVar, DestTy: PtrTy);
4375 Builder.CreateStore(Val: CastElem, Ptr: ElemPtr);
4376 }
4377 CodeGenIP = Builder.saveIP();
4378 Expected<Function *> SarFunc = emitShuffleAndReduceFunction(
4379 ReductionInfos, ReduceFn: ReductionFunc, FuncAttrs, IsByRef);
4380
4381 if (!SarFunc)
4382 return SarFunc.takeError();
4383
4384 Expected<Function *> CopyResult =
4385 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs, IsByRef);
4386 if (!CopyResult)
4387 return CopyResult.takeError();
4388 Function *WcFunc = *CopyResult;
4389 Builder.restoreIP(IP: CodeGenIP);
4390
4391 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(V: ReductionList, DestTy: PtrTy);
4392
4393 unsigned MaxDataSize = 0;
4394 SmallVector<Type *> ReductionTypeArgs;
4395 for (auto En : enumerate(First&: ReductionInfos)) {
4396 auto Size = M.getDataLayout().getTypeStoreSize(Ty: En.value().ElementType);
4397 if (Size > MaxDataSize)
4398 MaxDataSize = Size;
4399 Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
4400 ? En.value().ByRefElementType
4401 : En.value().ElementType;
4402 ReductionTypeArgs.emplace_back(Args&: RedTypeArg);
4403 }
4404 Value *ReductionDataSize =
4405 Builder.getInt64(C: MaxDataSize * ReductionInfos.size());
4406 if (!IsTeamsReduction) {
4407 Value *SarFuncCast =
4408 Builder.CreatePointerBitCastOrAddrSpaceCast(V: *SarFunc, DestTy: FuncPtrTy);
4409 Value *WcFuncCast =
4410 Builder.CreatePointerBitCastOrAddrSpaceCast(V: WcFunc, DestTy: FuncPtrTy);
4411 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
4412 WcFuncCast};
4413 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
4414 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
4415 Res = createRuntimeFunctionCall(Callee: Pv2Ptr, Args);
4416 } else {
4417 CodeGenIP = Builder.saveIP();
4418 StructType *ReductionsBufferTy = StructType::create(
4419 Context&: Ctx, Elements: ReductionTypeArgs, Name: "struct._globalized_locals_ty");
4420 Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
4421 FnID: RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
4422
4423 Expected<Function *> LtGCFunc = emitListToGlobalCopyFunction(
4424 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4425 if (!LtGCFunc)
4426 return LtGCFunc.takeError();
4427
4428 Expected<Function *> LtGRFunc = emitListToGlobalReduceFunction(
4429 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4430 if (!LtGRFunc)
4431 return LtGRFunc.takeError();
4432
4433 Expected<Function *> GtLCFunc = emitGlobalToListCopyFunction(
4434 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4435 if (!GtLCFunc)
4436 return GtLCFunc.takeError();
4437
4438 Expected<Function *> GtLRFunc = emitGlobalToListReduceFunction(
4439 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4440 if (!GtLRFunc)
4441 return GtLRFunc.takeError();
4442
4443 Builder.restoreIP(IP: CodeGenIP);
4444
4445 Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
4446 Callee: RedFixedBufferFn, Args: {}, Name: "_openmp_teams_reductions_buffer_$_$ptr");
4447
4448 Value *Args3[] = {SrcLocInfo,
4449 KernelTeamsReductionPtr,
4450 Builder.getInt32(C: ReductionBufNum),
4451 ReductionDataSize,
4452 RL,
4453 *SarFunc,
4454 WcFunc,
4455 *LtGCFunc,
4456 *LtGRFunc,
4457 *GtLCFunc,
4458 *GtLRFunc};
4459
4460 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
4461 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
4462 Res = createRuntimeFunctionCall(Callee: TeamsReduceFn, Args: Args3);
4463 }
4464
4465 // 5. Build if (res == 1)
4466 BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.done");
4467 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.then");
4468 Value *Cond = Builder.CreateICmpEQ(LHS: Res, RHS: Builder.getInt32(C: 1));
4469 Builder.CreateCondBr(Cond, True: ThenBB, False: ExitBB);
4470
4471 // 6. Build then branch: where we have reduced values in the master
4472 // thread in each team.
4473 // __kmpc_end_reduce{_nowait}(<gtid>);
4474 // break;
4475 emitBlock(BB: ThenBB, CurFn: CurFunc);
4476
4477 // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
4478 for (auto En : enumerate(First&: ReductionInfos)) {
4479 const ReductionInfo &RI = En.value();
4480 Type *ValueType = RI.ElementType;
4481 Value *RedValue = RI.Variable;
4482 Value *RHS =
4483 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
4484
4485 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4486 Value *LHSPtr, *RHSPtr;
4487 Builder.restoreIP(IP: RI.ReductionGenClang(Builder.saveIP(), En.index(),
4488 &LHSPtr, &RHSPtr, CurFunc));
4489
      // Fix the CallBack code generated to use the correct Values for the LHS
4491 // and RHS
4492 LHSPtr->replaceUsesWithIf(New: RedValue, ShouldReplace: [ReductionFunc](const Use &U) {
4493 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4494 ReductionFunc;
4495 });
4496 RHSPtr->replaceUsesWithIf(New: RHS, ShouldReplace: [ReductionFunc](const Use &U) {
4497 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4498 ReductionFunc;
4499 });
4500 } else {
4501 if (IsByRef.empty() || !IsByRef[En.index()]) {
4502 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
4503 Name: "red.value." + Twine(En.index()));
4504 }
4505 Value *PrivateRedValue = Builder.CreateLoad(
4506 Ty: ValueType, Ptr: RHS, Name: "red.private.value" + Twine(En.index()));
4507 Value *Reduced;
4508 InsertPointOrErrorTy AfterIP =
4509 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
4510 if (!AfterIP)
4511 return AfterIP.takeError();
4512 Builder.restoreIP(IP: *AfterIP);
4513
4514 if (!IsByRef.empty() && !IsByRef[En.index()])
4515 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
4516 }
4517 }
4518 emitBlock(BB: ExitBB, CurFn: CurFunc);
4519 if (ContinuationBlock) {
4520 Builder.CreateBr(Dest: ContinuationBlock);
4521 Builder.SetInsertPoint(ContinuationBlock);
4522 }
4523 Config.setEmitLLVMUsed();
4524
4525 return Builder.saveIP();
4526}
4527
4528static Function *getFreshReductionFunc(Module &M) {
4529 Type *VoidTy = Type::getVoidTy(C&: M.getContext());
4530 Type *Int8PtrTy = PointerType::getUnqual(C&: M.getContext());
4531 auto *FuncTy =
4532 FunctionType::get(Result: VoidTy, Params: {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ isVarArg: false);
4533 return Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4534 N: ".omp.reduction.func", M: &M);
4535}
4536
/// Emit the body of \p ReductionFunc, the type-erased elementwise reduction
/// callback handed to the __kmpc_reduce* runtime entry points.
///
/// Both arguments point at arrays of type-erased pointers: element i of the
/// first array refers to the i-th original ("LHS") reduction variable, and
/// element i of the second to its private ("RHS") copy. For each reduction
/// this loads both values, invokes the user-supplied ReductionGen callback
/// to combine them, and stores the result back through the LHS pointer --
/// except for by-ref reductions, where the store is emitted inside the
/// reduction region itself.
static Error populateReductionFunction(
    Function *ReductionFunc,
    ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
  Module *Module = ReductionFunc->getParent();
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "", Parent: ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  if (IsGPU) {
    // Need to alloca memory here and deal with the pointers before getting
    // LHS/RHS pointers out
    //
    // Spill the incoming arguments to allocas and reload them through
    // addrspace-cast pointers so the arrays can be addressed uniformly
    // below regardless of the target's alloca address space.
    Argument *Arg0 = ReductionFunc->getArg(i: 0);
    Argument *Arg1 = ReductionFunc->getArg(i: 1);
    Type *Arg0Type = Arg0->getType();
    Type *Arg1Type = Arg1->getType();

    Value *LHSAlloca =
        Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
    Value *RHSAlloca =
        Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
    Value *LHSAddrCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(V: LHSAlloca, DestTy: Arg0Type);
    Value *RHSAddrCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHSAlloca, DestTy: Arg1Type);
    Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
    Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
    LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
    RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
  } else {
    // On the host the incoming pointers can be used directly.
    LHSArrayPtr = ReductionFunc->getArg(i: 0);
    RHSArrayPtr = ReductionFunc->getArg(i: 1);
  }

  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);

  for (auto En : enumerate(First&: ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    // Fetch the i-th LHS element pointer and load the value it refers to.
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: LHSArrayPtr, Idx0: 0, Idx1: En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: LHSI8Ptr, DestTy: RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
    // Same for the i-th RHS (private copy) element.
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: RHSArrayPtr, Idx0: 0, Idx1: En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
    // Combine LHS and RHS with the user-provided reduction operation.
    Value *Reduced;
    OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();

    Builder.restoreIP(IP: *AfterIP);
    // TODO: Consider flagging an error.
    // A cleared insertion point signals the callback wants no further code
    // emitted; bail out without the final return.
    if (!Builder.GetInsertBlock())
      return Error::success();

    // store is inside of the reduction region when using by-ref
    if (!IsByRef[En.index()])
      Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
  }
  Builder.CreateRetVoid();
  return Error::success();
}
4608
/// Emit the host-side reduction sequence: pack pointers to the private
/// reduction copies into a type-erased array, call __kmpc_reduce{_nowait},
/// and dispatch on its return value to either the non-atomic elementwise
/// combine path (closed by __kmpc_end_reduce{_nowait}) or the atomic path.
/// GPU configurations are delegated to createReductionsGPU.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
    bool IsNoWait, bool IsTeamsReduction) {
  assert(ReductionInfos.size() == IsByRef.size());
  if (Config.isGPU())
    return createReductionsGPU(Loc, AllocaIP, CodeGenIP: Builder.saveIP(), ReductionInfos,
                               IsByRef, IsNoWait, IsTeamsReduction);

  checkReductionInfos(ReductionInfos, /*IsGPU*/ false);

  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Nothing to reduce.
  if (ReductionInfos.size() == 0)
    return Builder.saveIP();

  // Split the current block: everything after the insertion point becomes
  // the continuation that runs after the reduction is finalized.
  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
  Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
  Value *RedArray = Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: "red.array");

  Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());

  for (auto En : enumerate(First&: ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: RedArray, Idx0: 0, Idx1: Index, Name: "red.array.elem." + Twine(Index));
    Builder.CreateStore(Val: RI.PrivateVariable, Ptr: RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // The atomic fallback is only available when every reduction supplies an
  // atomic variant of its combiner; advertise that through the ident flags.
  bool CanGenerateAtomic = all_of(Range&: ReductionInfos, P: [](const ReductionInfo &RI) {
    return RI.AtomicReductionGen;
  });
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                  LocFlags: CanGenerateAtomic
                                      ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                      : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(C: NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(Ty: RedArrayTy);
  Constant *RedArraySize = ConstantInt::get(Ty: IndexTy, V: RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(M&: *Module);
  Value *Lock = getOMPCriticalRegionLock(CriticalName: ".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      createRuntimeFunctionCall(Callee: ReduceFunc,
                                Args: {Ident, ThreadId, NumVariables, RedArraySize,
                                 RedArray, ReductionFunc, Lock},
                                Name: "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  // Return value 1 selects the non-atomic path, 2 the atomic path; any
  // other value falls through to the continuation.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.nonatomic", Parent: Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.atomic", Parent: Func);
  SwitchInst *Switch =
      Builder.CreateSwitch(V: ReduceCall, Dest: ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(OnVal: Builder.getInt32(C: 1), Dest: NonAtomicRedBlock);
  Switch->addCase(OnVal: Builder.getInt32(C: 2), Dest: AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    // We have one less load for by-ref case because that load is now inside of
    // the reduction region
    Value *RedValue = RI.Variable;
    if (!IsByRef[En.index()]) {
      RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
                                    Name: "red.value." + Twine(En.index()));
    }
    Value *PrivateRedValue =
        Builder.CreateLoad(Ty: ValueType, Ptr: RI.PrivateVariable,
                           Name: "red.private.value." + Twine(En.index()));
    Value *Reduced;
    InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);

    // A cleared insertion point signals the callback wants no further code
    // emitted.
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // for by-ref case, the load is inside of the reduction region
    if (!IsByRef[En.index()])
      Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  createRuntimeFunctionCall(Callee: EndReduceFunc, Args: {Ident, ThreadId, Lock});
  Builder.CreateBr(Dest: ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic && llvm::none_of(Range&: IsByRef, P: [](bool P) { return P; })) {
    for (const ReductionInfo &RI : ReductionInfos) {
      InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
          Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(Dest: ContinuationBlock);
  } else {
    // The atomic flag was not advertised, so the runtime should never
    // select this path.
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
                                        IsByRef, /*isGPU=*/IsGPU: false);
  if (Err)
    return Err;

  if (!Builder.GetInsertBlock())
    return InsertPointTy();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}
4760
4761OpenMPIRBuilder::InsertPointOrErrorTy
4762OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
4763 BodyGenCallbackTy BodyGenCB,
4764 FinalizeCallbackTy FiniCB) {
4765 if (!updateToLocation(Loc))
4766 return Loc.IP;
4767
4768 Directive OMPD = Directive::OMPD_master;
4769 uint32_t SrcLocStrSize;
4770 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4771 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4772 Value *ThreadId = getOrCreateThreadID(Ident);
4773 Value *Args[] = {Ident, ThreadId};
4774
4775 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_master);
4776 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4777
4778 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_master);
4779 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
4780
4781 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4782 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4783}
4784
4785OpenMPIRBuilder::InsertPointOrErrorTy
4786OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
4787 BodyGenCallbackTy BodyGenCB,
4788 FinalizeCallbackTy FiniCB, Value *Filter) {
4789 if (!updateToLocation(Loc))
4790 return Loc.IP;
4791
4792 Directive OMPD = Directive::OMPD_masked;
4793 uint32_t SrcLocStrSize;
4794 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4795 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4796 Value *ThreadId = getOrCreateThreadID(Ident);
4797 Value *Args[] = {Ident, ThreadId, Filter};
4798 Value *ArgsEnd[] = {Ident, ThreadId};
4799
4800 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_masked);
4801 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4802
4803 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_masked);
4804 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args: ArgsEnd);
4805
4806 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4807 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4808}
4809
4810static llvm::CallInst *emitNoUnwindRuntimeCall(IRBuilder<> &Builder,
4811 llvm::FunctionCallee Callee,
4812 ArrayRef<llvm::Value *> Args,
4813 const llvm::Twine &Name) {
4814 llvm::CallInst *Call = Builder.CreateCall(
4815 Callee, Args, OpBundles: SmallVector<llvm::OperandBundleDef, 1>(), Name);
4816 Call->setDoesNotThrow();
4817 return Call;
4818}
4819
4820// Expects input basic block is dominated by BeforeScanBB.
4821// Once Scan directive is encountered, the code after scan directive should be
4822// dominated by AfterScanBB. Scan directive splits the code sequence to
4823// scan and input phase. Based on whether inclusive or exclusive
4824// clause is used in the scan directive and whether input loop or scan loop
4825// is lowered, it adds jumps to input and scan phase. First Scan loop is the
4826// input loop and second is the scan loop. The code generated handles only
4827// inclusive scans now.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
    bool IsInclusive, ScanInfo *ScanRedInfo) {
  // The temporary scan buffers only need to be created once, while lowering
  // the first (input) loop.
  if (ScanRedInfo->OMPFirstScanLoop) {
    llvm::Error Err = emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars,
                                                    ScanVarsType, ScanRedInfo);
    if (Err)
      return Err;
  }
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Loop induction variable of the enclosing canonical loop, used to index
  // the per-iteration buffer slots.
  llvm::Value *IV = ScanRedInfo->IV;

  if (ScanRedInfo->OMPFirstScanLoop) {
    // Emit buffer[i] = red; at the end of the input phase.
    for (size_t i = 0; i < ScanVars.size(); i++) {
      // ScanBuffPtrs maps each scan variable to the alloca that holds its
      // buffer pointer.
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = ScanVarsType[i];
      Value *Val = Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: ScanVars[i]);

      Builder.CreateStore(Val: Src, Ptr: Val);
    }
  }
  // Close the phase just emitted and continue in the dispatch block that
  // selects which phase runs next.
  Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
  emitBlock(BB: ScanRedInfo->OMPScanDispatch,
            CurFn: Builder.GetInsertBlock()->getParent());

  if (!ScanRedInfo->OMPFirstScanLoop) {
    IV = ScanRedInfo->IV;
    // Emit red = buffer[i]; at the entrance to the scan phase.
    // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = ScanVarsType[i];
      Value *SrcPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: SrcPtr);
      Builder.CreateStore(Val: Src, Ptr: ScanVars[i]);
    }
  }

  // TODO: Update it to CreateBr and remove dead blocks
  // The condition is constant-true; which successor is taken first depends
  // on the phase (input vs scan) and on inclusive vs exclusive.
  llvm::Value *CmpI = Builder.getInt1(V: true);
  if (ScanRedInfo->OMPFirstScanLoop == IsInclusive) {
    Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPBeforeScanBlock,
                         False: ScanRedInfo->OMPAfterScanBlock);
  } else {
    Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPAfterScanBlock,
                         False: ScanRedInfo->OMPBeforeScanBlock);
  }
  emitBlock(BB: ScanRedInfo->OMPAfterScanBlock,
            CurFn: Builder.GetInsertBlock()->getParent());
  Builder.SetInsertPoint(ScanRedInfo->OMPAfterScanBlock);
  return Builder.saveIP();
}
4888
/// Emit the allocations needed by a scan-based directive: one pointer
/// alloca per scan variable (at \p AllocaIP) and, inside a masked region
/// executed by thread 0 only, a malloc'ed buffer of Span+1 elements per
/// variable whose address is stored into the corresponding alloca. All
/// threads synchronize on a barrier afterwards so the buffers are visible
/// before the input loop starts.
Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
    InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
    ArrayRef<Type *> ScanVarsType, ScanInfo *ScanRedInfo) {

  Builder.restoreIP(IP: AllocaIP);
  // Create the shared pointer at alloca IP.
  for (size_t i = 0; i < ScanVars.size(); i++) {
    llvm::Value *BuffPtr =
        Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: "vla");
    (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]] = BuffPtr;
  }

  // Allocate temporary buffer by master thread
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    // One extra slot beyond Span (used later to hold the final value).
    Value *AllocSpan =
        Builder.CreateAdd(LHS: ScanRedInfo->Span, RHS: Builder.getInt32(C: 1));
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Type *IntPtrTy = Builder.getInt32Ty();
      Constant *Allocsize = ConstantExpr::getSizeOf(Ty: ScanVarsType[i]);
      Allocsize = ConstantExpr::getTruncOrBitCast(C: Allocsize, Ty: IntPtrTy);
      Value *Buff = Builder.CreateMalloc(IntPtrTy, AllocTy: ScanVarsType[i], AllocSize: Allocsize,
                                         ArraySize: AllocSpan, MallocF: nullptr, Name: "arr");
      Builder.CreateStore(Val: Buff, Ptr: (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]]);
    }
    return Error::success();
  };
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Filter value 0 restricts the masked region to thread 0.
  Builder.SetInsertPoint(ScanRedInfo->OMPScanInit->getTerminator());
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  BasicBlock *InputBB = Builder.GetInsertBlock();
  // Emit the barrier before any existing terminator of the current block.
  if (InputBB->getTerminator())
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);

  return Error::success();
}
4939
/// Emit the finalization step of a scan-based directive: inside a masked
/// region (thread 0 only) copy buffer[Span] -- the extra slot allocated in
/// emitScanBasedDirectiveDeclsIR -- of every reduction back into the
/// original variable and free the buffer, then synchronize all threads on
/// a barrier.
Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
    ArrayRef<ReductionInfo> ReductionInfos, ScanInfo *ScanRedInfo) {
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    for (ReductionInfo RedInfo : ReductionInfos) {
      Value *PrivateVar = RedInfo.PrivateVariable;
      Value *OrigVar = RedInfo.Variable;
      // ScanBuffPtrs is keyed by the private copy of the variable.
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[PrivateVar];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);

      Type *SrcTy = RedInfo.ElementType;
      Value *Val = Builder.CreateInBoundsGEP(Ty: SrcTy, Ptr: Buff, IdxList: ScanRedInfo->Span,
                                             Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: SrcTy, Ptr: Val);

      Builder.CreateStore(Val: Src, Ptr: OrigVar);
      // Release the temporary buffer malloc'ed during setup.
      Builder.CreateFree(Source: Buff);
    }
    return Error::success();
  };
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Insert before the terminator of the scan-finish block if it already has
  // one, otherwise append at its end.
  if (ScanRedInfo->OMPScanFinish->getTerminator())
    Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish->getTerminator());
  else
    Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish);

  // Filter value 0 restricts the masked region to thread 0.
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  BasicBlock *InputBB = Builder.GetInsertBlock();
  if (InputBB->getTerminator())
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  return Error::success();
}
4986
/// Emit the cross-iteration combine step of a scan-based directive: inside
/// a masked region (thread 0 only) run ceil(log2(Span)) rounds over the
/// temporary buffers, combining tmp[i] with tmp[i-pow2k] in each round,
/// then barrier and emit the finalization IR that writes results back and
/// frees the buffers.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
    const LocationDescription &Loc,
    ArrayRef<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    ScanInfo *ScanRedInfo) {

  if (!updateToLocation(Loc))
    return Loc.IP;
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    Function *CurFn = Builder.GetInsertBlock()->getParent();
    // for (int k = 0; k <= ceil(log2(n)); ++k)
    llvm::BasicBlock *LoopBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.outer.log.scan.body");
    llvm::BasicBlock *ExitBB =
        splitBB(Builder, CreateBranch: false, Name: "omp.outer.log.scan.exit");
    // Compute the outer trip count ceil(log2(Span)) via the double-precision
    // log2 and ceil intrinsics, then truncate back to i32.
    llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
        M: Builder.GetInsertBlock()->getModule(),
        id: (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Tys: Builder.getDoubleTy());
    llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
    llvm::Value *Arg =
        Builder.CreateUIToFP(V: ScanRedInfo->Span, DestTy: Builder.getDoubleTy());
    llvm::Value *LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: Arg, Name: "");
    F = llvm::Intrinsic::getOrInsertDeclaration(
        M: Builder.GetInsertBlock()->getModule(),
        id: (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Tys: Builder.getDoubleTy());
    LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: LogVal, Name: "");
    LogVal = Builder.CreateFPToUI(V: LogVal, DestTy: Builder.getInt32Ty());
    llvm::Value *NMin1 = Builder.CreateNUWSub(
        LHS: ScanRedInfo->Span,
        RHS: llvm::ConstantInt::get(Ty: ScanRedInfo->Span->getType(), V: 1));
    Builder.SetInsertPoint(InputBB);
    Builder.CreateBr(Dest: LoopBB);
    emitBlock(BB: LoopBB, CurFn);
    Builder.SetInsertPoint(LoopBB);

    // Outer-loop induction variable k and the stride pow2k = 2^k.
    PHINode *Counter = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    // size pow2k = 1;
    PHINode *Pow2K = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    Counter->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
                         BB: InputBB);
    Pow2K->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1),
                       BB: InputBB);
    // for (size i = n - 1; i >= 2 ^ k; --i)
    //   tmp[i] op= tmp[i-pow2k];
    llvm::BasicBlock *InnerLoopBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.body");
    llvm::BasicBlock *InnerExitBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.exit");
    llvm::Value *CmpI = Builder.CreateICmpUGE(LHS: NMin1, RHS: Pow2K);
    Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
    emitBlock(BB: InnerLoopBB, CurFn);
    Builder.SetInsertPoint(InnerLoopBB);
    // Inner-loop induction variable i, counting down from n-1.
    PHINode *IVal = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    IVal->addIncoming(V: NMin1, BB: LoopBB);
    for (ReductionInfo RedInfo : ReductionInfos) {
      Value *ReductionVal = RedInfo.PrivateVariable;
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ReductionVal];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = RedInfo.ElementType;
      // The buffer is addressed at i+1 here (the buffers hold Span+1
      // elements) -- presumably a 1-based indexing convention; confirm
      // against the buffer writers in createScan.
      Value *IV = Builder.CreateAdd(LHS: IVal, RHS: Builder.getInt32(C: 1));
      Value *LHSPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *OffsetIval = Builder.CreateNUWSub(LHS: IV, RHS: Pow2K);
      Value *RHSPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: OffsetIval, Name: "arrayOffset");
      Value *LHS = Builder.CreateLoad(Ty: DestTy, Ptr: LHSPtr);
      Value *RHS = Builder.CreateLoad(Ty: DestTy, Ptr: RHSPtr);
      llvm::Value *Result;
      InsertPointOrErrorTy AfterIP =
          RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
      if (!AfterIP)
        return AfterIP.takeError();
      // NOTE(review): unlike other ReductionGen call sites in this file,
      // *AfterIP is not restored into the builder here -- confirm the
      // callback never moves the insertion point on this path.
      Builder.CreateStore(Val: Result, Ptr: LHSPtr);
    }
    llvm::Value *NextIVal = Builder.CreateNUWSub(
        LHS: IVal, RHS: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1));
    IVal->addIncoming(V: NextIVal, BB: Builder.GetInsertBlock());
    CmpI = Builder.CreateICmpUGE(LHS: NextIVal, RHS: Pow2K);
    Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
    emitBlock(BB: InnerExitBB, CurFn);
    llvm::Value *Next = Builder.CreateNUWAdd(
        LHS: Counter, RHS: llvm::ConstantInt::get(Ty: Counter->getType(), V: 1));
    Counter->addIncoming(V: Next, BB: Builder.GetInsertBlock());
    // pow2k <<= 1;
    llvm::Value *NextPow2K = Builder.CreateShl(LHS: Pow2K, RHS: 1, Name: "", /*HasNUW=*/true);
    Pow2K->addIncoming(V: NextPow2K, BB: Builder.GetInsertBlock());
    llvm::Value *Cmp = Builder.CreateICmpNE(LHS: Next, RHS: LogVal);
    Builder.CreateCondBr(Cond: Cmp, True: LoopBB, False: ExitBB);
    Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
    return Error::success();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Thread 0 performs the combine; everyone else waits at the barrier.
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  // Copy final values back to the original variables and free the buffers.
  Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos, ScanRedInfo);
  if (Err)
    return Err;

  return AfterIP;
}
5102
5103Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
5104 llvm::function_ref<Error()> InputLoopGen,
5105 llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen,
5106 ScanInfo *ScanRedInfo) {
5107
5108 {
5109 // Emit loop with input phase:
5110 // for (i: 0..<num_iters>) {
5111 // <input phase>;
5112 // buffer[i] = red;
5113 // }
5114 ScanRedInfo->OMPFirstScanLoop = true;
5115 Error Err = InputLoopGen();
5116 if (Err)
5117 return Err;
5118 }
5119 {
5120 // Emit loop with scan phase:
5121 // for (i: 0..<num_iters>) {
5122 // red = buffer[i];
5123 // <scan phase>;
5124 // }
5125 ScanRedInfo->OMPFirstScanLoop = false;
5126 Error Err = ScanLoopGen(Builder.saveIP());
5127 if (Err)
5128 return Err;
5129 }
5130 return Error::success();
5131}
5132
5133void OpenMPIRBuilder::createScanBBs(ScanInfo *ScanRedInfo) {
5134 Function *Fun = Builder.GetInsertBlock()->getParent();
5135 ScanRedInfo->OMPScanDispatch =
5136 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.inscan.dispatch");
5137 ScanRedInfo->OMPAfterScanBlock =
5138 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.after.scan.bb");
5139 ScanRedInfo->OMPBeforeScanBlock =
5140 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.before.scan.bb");
5141 ScanRedInfo->OMPScanLoopExit =
5142 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.scan.loop.exit");
5143}
/// Build the CFG skeleton of a canonical loop:
///   preheader -> header -> cond -> body -> latch -> header (backedge)
///                            \-> exit -> after
/// together with the induction-variable PHI (starting at 0, incremented by 1)
/// and the `iv < TripCount` comparison. The body block is left empty for the
/// caller to populate. Blocks up to and including the body are inserted
/// before \p PreInsertBefore; latch/exit/after before \p PostInsertBefore.
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
    DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
    BasicBlock *PostInsertBefore, const Twine &Name) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  // The induction variable uses the same integer type as the trip count.
  Type *IndVarTy = TripCount->getType();

  // Create the basic block structure.
  BasicBlock *Preheader =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".preheader", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Header =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".header", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Cond =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".cond", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Body =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".body", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Latch =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".inc", Parent: F, InsertBefore: PostInsertBefore);
  BasicBlock *Exit =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".exit", Parent: F, InsertBefore: PostInsertBefore);
  BasicBlock *After =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".after", Parent: F, InsertBefore: PostInsertBefore);

  // Use specified DebugLoc for new instructions.
  Builder.SetCurrentDebugLocation(DL);

  Builder.SetInsertPoint(Preheader);
  Builder.CreateBr(Dest: Header);

  // Header hosts the IV PHI; the IV enters as 0 from the preheader.
  Builder.SetInsertPoint(Header);
  PHINode *IndVarPHI = Builder.CreatePHI(Ty: IndVarTy, NumReservedValues: 2, Name: "omp_" + Name + ".iv");
  IndVarPHI->addIncoming(V: ConstantInt::get(Ty: IndVarTy, V: 0), BB: Preheader);
  Builder.CreateBr(Dest: Cond);

  // Loop continues while iv u< TripCount (canonical loops are unsigned).
  Builder.SetInsertPoint(Cond);
  Value *Cmp =
      Builder.CreateICmpULT(LHS: IndVarPHI, RHS: TripCount, Name: "omp_" + Name + ".cmp");
  Builder.CreateCondBr(Cond: Cmp, True: Body, False: Exit);

  Builder.SetInsertPoint(Body);
  Builder.CreateBr(Dest: Latch);

  // Latch increments the IV by 1; the increment cannot wrap since iv stays
  // below the trip count, hence NUW.
  Builder.SetInsertPoint(Latch);
  Value *Next = Builder.CreateAdd(LHS: IndVarPHI, RHS: ConstantInt::get(Ty: IndVarTy, V: 1),
                                  Name: "omp_" + Name + ".next", /*HasNUW=*/true);
  Builder.CreateBr(Dest: Header);
  IndVarPHI->addIncoming(V: Next, BB: Latch);

  Builder.SetInsertPoint(Exit);
  Builder.CreateBr(Dest: After);

  // Remember and return the canonical control flow. LoopInfos owns the
  // CanonicalLoopInfo objects; std::forward_list keeps the address stable.
  LoopInfos.emplace_front();
  CanonicalLoopInfo *CL = &LoopInfos.front();

  // Only these four blocks are stored; the remaining blocks are presumably
  // recovered from the CFG by CanonicalLoopInfo's getters — confirm in the
  // header.
  CL->Header = Header;
  CL->Cond = Cond;
  CL->Latch = Latch;
  CL->Exit = Exit;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
5209
/// Create a canonical loop iterating from 0 to \p TripCount (exclusive) with
/// step 1, materialize the body via \p BodyGenCB, and — if \p Loc is set —
/// splice the loop into the CFG at the insertion point.
Expected<CanonicalLoopInfo *>
OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
                                     LoopBodyGenCallbackTy BodyGenCB,
                                     Value *TripCount, const Twine &Name) {
  BasicBlock *BB = Loc.IP.getBlock();
  BasicBlock *NextBB = BB->getNextNode();

  // Insert the skeleton's blocks right after the insertion point's block.
  CanonicalLoopInfo *CL = createLoopSkeleton(DL: Loc.DL, TripCount, F: BB->getParent(),
                                             PreInsertBefore: NextBB, PostInsertBefore: NextBB, Name);
  BasicBlock *After = CL->getAfter();

  // If location is not set, don't connect the loop.
  if (updateToLocation(Loc)) {
    // Split the loop at the insertion point: Branch to the preheader and move
    // every following instruction to after the loop (the After BB). Also, the
    // new successor is the loop's after block.
    spliceBB(Builder, New: After, /*CreateBranch=*/false);
    Builder.CreateBr(Dest: CL->getPreheader());
  }

  // Emit the body content. We do it after connecting the loop to the CFG to
  // avoid that the callback encounters degenerate BBs.
  if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
    return Err;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
5240
5241Expected<ScanInfo *> OpenMPIRBuilder::scanInfoInitialize() {
5242 ScanInfos.emplace_front();
5243 ScanInfo *Result = &ScanInfos.front();
5244 return Result;
5245}
5246
/// Emit the two canonical loops required to lower a scan directive: an input
/// loop followed by a scan loop (see emitScanBasedDirectiveIR). Both loops
/// share \p BodyGenCB; the dispatch blocks created by createScanBBs are wired
/// around the body so the callback can branch per phase. Returns the
/// CanonicalLoopInfo of both loops (input first, then scan).
Expected<SmallVector<llvm::CanonicalLoopInfo *>>
OpenMPIRBuilder::createCanonicalScanLoops(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name, ScanInfo *ScanRedInfo) {
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
  updateToLocation(Loc: ComputeLoc);

  SmallVector<CanonicalLoopInfo *> Result;

  // Both loops iterate the same range, so the trip count is computed once.
  Value *TripCount = calculateCanonicalLoopTripCount(
      Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
  ScanRedInfo->Span = TripCount;
  ScanRedInfo->OMPScanInit = splitBB(Builder, CreateBranch: true, Name: "scan.init");
  Builder.SetInsertPoint(ScanRedInfo->OMPScanInit);

  // Shared body generator: reroute the loop body through the scan dispatch
  // blocks, then invoke the user callback from the before-scan block.
  // NOTE: captures by value on purpose — it is invoked once per loop.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(IP: CodeGenIP);
    ScanRedInfo->IV = IV;
    createScanBBs(ScanRedInfo);
    BasicBlock *InputBlock = Builder.GetInsertBlock();
    Instruction *Terminator = InputBlock->getTerminator();
    assert(Terminator->getNumSuccessors() == 1);
    BasicBlock *ContinueBlock = Terminator->getSuccessor(Idx: 0);
    // Detour the fallthrough edge through the dispatch block.
    Terminator->setSuccessor(Idx: 0, BB: ScanRedInfo->OMPScanDispatch);
    emitBlock(BB: ScanRedInfo->OMPBeforeScanBlock,
              CurFn: Builder.GetInsertBlock()->getParent());
    Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
    emitBlock(BB: ScanRedInfo->OMPScanLoopExit,
              CurFn: Builder.GetInsertBlock()->getParent());
    Builder.CreateBr(Dest: ContinueBlock);
    // Generate the user body inside the before-scan block.
    Builder.SetInsertPoint(
        ScanRedInfo->OMPBeforeScanBlock->getFirstInsertionPt());
    return BodyGenCB(Builder.saveIP(), IV);
  };

  // Generator for the first (input-phase) loop.
  const auto &&InputLoopGen = [&]() -> Error {
    Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
        Loc: Builder.saveIP(), BodyGenCB: BodyGen, Start, Stop, Step, IsSigned, InclusiveStop,
        ComputeIP, Name, InScan: true, ScanRedInfo);
    if (!LoopInfo)
      return LoopInfo.takeError();
    Result.push_back(Elt: *LoopInfo);
    Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
    return Error::success();
  };
  // Generator for the second (scan-phase) loop; also records where code after
  // the whole scan construct continues.
  const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
    Expected<CanonicalLoopInfo *> LoopInfo =
        createCanonicalLoop(Loc, BodyGenCB: BodyGen, Start, Stop, Step, IsSigned,
                            InclusiveStop, ComputeIP, Name, InScan: true, ScanRedInfo);
    if (!LoopInfo)
      return LoopInfo.takeError();
    Result.push_back(Elt: *LoopInfo);
    Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
    ScanRedInfo->OMPScanFinish = Builder.GetInsertBlock();
    return Error::success();
  };
  Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen, ScanRedInfo);
  if (Err)
    return Err;
  return Result;
}
5310
/// Compute the trip count of a loop iterating from \p Start to \p Stop by
/// \p Step, i.e. how many times a canonical 0..N loop must run to cover the
/// same range. Handles negative steps (signed case) and avoids overflow when
/// stepping past \p Stop.
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
    bool IsSigned, bool InclusiveStop, const Twine &Name) {

  // Consider the following difficulties (assuming 8-bit signed integers):
  // * Adding \p Step to the loop counter which passes \p Stop may overflow:
  //     DO I = 1, 100, 50
  // * A \p Step of INT_MIN cannot be normalized to a positive direction:
  //     DO I = 100, 0, -128

  // Start, Stop and Step must be of the same integer type.
  auto *IndVarTy = cast<IntegerType>(Val: Start->getType());
  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
  assert(IndVarTy == Step->getType() && "Step type mismatch");

  updateToLocation(Loc);

  ConstantInt *Zero = ConstantInt::get(Ty: IndVarTy, V: 0);
  ConstantInt *One = ConstantInt::get(Ty: IndVarTy, V: 1);

  // Like Step, but always positive.
  Value *Incr = Step;

  // Distance between Start and Stop; always positive.
  Value *Span;

  // Condition whether there are no iterations are executed at all, e.g. because
  // UB < LB.
  Value *ZeroCmp;

  if (IsSigned) {
    // Ensure that increment is positive. If not, negate and invert LB and UB.
    Value *IsNeg = Builder.CreateICmpSLT(LHS: Step, RHS: Zero);
    Incr = Builder.CreateSelect(C: IsNeg, True: Builder.CreateNeg(V: Step), False: Step);
    Value *LB = Builder.CreateSelect(C: IsNeg, True: Stop, False: Start);
    Value *UB = Builder.CreateSelect(C: IsNeg, True: Start, False: Stop);
    Span = Builder.CreateSub(LHS: UB, RHS: LB, Name: "", HasNUW: false, HasNSW: true);
    ZeroCmp = Builder.CreateICmp(
        P: InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, LHS: UB, RHS: LB);
  } else {
    Span = Builder.CreateSub(LHS: Stop, RHS: Start, Name: "", HasNUW: true);
    ZeroCmp = Builder.CreateICmp(
        P: InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, LHS: Stop, RHS: Start);
  }

  // Trip count assuming at least one iteration happens.
  Value *CountIfLooping;
  if (InclusiveStop) {
    // Inclusive: ceil-free form, Span/Incr + 1.
    CountIfLooping = Builder.CreateAdd(LHS: Builder.CreateUDiv(LHS: Span, RHS: Incr), RHS: One);
  } else {
    // Avoid incrementing past stop since it could overflow.
    Value *CountIfTwo = Builder.CreateAdd(
        LHS: Builder.CreateUDiv(LHS: Builder.CreateSub(LHS: Span, RHS: One), RHS: Incr), RHS: One);
    Value *OneCmp = Builder.CreateICmp(P: CmpInst::ICMP_ULE, LHS: Span, RHS: Incr);
    CountIfLooping = Builder.CreateSelect(C: OneCmp, True: One, False: CountIfTwo);
  }

  // Zero iterations if the range is empty, otherwise the computed count.
  return Builder.CreateSelect(C: ZeroCmp, True: Zero, False: CountIfLooping,
                              Name: "omp_" + Name + ".tripcount");
}
5370
/// Create a canonical loop over the range [Start, Stop) (or [Start, Stop] if
/// \p InclusiveStop) with the given \p Step. The loop is normalized to a
/// 0..tripcount form; the user-visible induction variable Start + iv*Step is
/// reconstructed inside the body before calling \p BodyGenCB.
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name, bool InScan,
    ScanInfo *ScanRedInfo) {
  // Trip-count computation may be requested at a separate insertion point.
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

  Value *TripCount = calculateCanonicalLoopTripCount(
      Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

  // Map the canonical 0..tripcount IV back to the user's iteration space.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(IP: CodeGenIP);
    Value *Span = Builder.CreateMul(LHS: IV, RHS: Step);
    Value *IndVar = Builder.CreateAdd(LHS: Span, RHS: Start);
    if (InScan)
      ScanRedInfo->IV = IndVar;
    return BodyGenCB(Builder.saveIP(), IndVar);
  };
  // If the trip count was emitted at ComputeIP, the loop itself still goes at
  // Loc; otherwise continue right after the trip-count code.
  LocationDescription LoopLoc =
      ComputeIP.isSet()
          ? Loc
          : LocationDescription(Builder.saveIP(),
                                Builder.getCurrentDebugLocation());
  return createCanonicalLoop(Loc: LoopLoc, BodyGenCB: BodyGen, TripCount, Name);
}
5397
5398// Returns an LLVM function to call for initializing loop bounds using OpenMP
5399// static scheduling for composite `distribute parallel for` depending on
5400// `type`. Only i32 and i64 are supported by the runtime. Always interpret
5401// integers as unsigned similarly to CanonicalLoopInfo.
5402static FunctionCallee
5403getKmpcDistForStaticInitForType(Type *Ty, Module &M,
5404 OpenMPIRBuilder &OMPBuilder) {
5405 unsigned Bitwidth = Ty->getIntegerBitWidth();
5406 if (Bitwidth == 32)
5407 return OMPBuilder.getOrCreateRuntimeFunction(
5408 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
5409 if (Bitwidth == 64)
5410 return OMPBuilder.getOrCreateRuntimeFunction(
5411 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
5412 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5413}
5414
5415// Returns an LLVM function to call for initializing loop bounds using OpenMP
5416// static scheduling depending on `type`. Only i32 and i64 are supported by the
5417// runtime. Always interpret integers as unsigned similarly to
5418// CanonicalLoopInfo.
5419static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
5420 OpenMPIRBuilder &OMPBuilder) {
5421 unsigned Bitwidth = Ty->getIntegerBitWidth();
5422 if (Bitwidth == 32)
5423 return OMPBuilder.getOrCreateRuntimeFunction(
5424 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
5425 if (Bitwidth == 64)
5426 return OMPBuilder.getOrCreateRuntimeFunction(
5427 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
5428 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5429}
5430
/// Lower a canonical loop to an OpenMP statically-scheduled workshare loop:
/// allocate the bound variables, call the runtime "init" function to compute
/// this thread's chunk, shift the induction variable by the lower bound,
/// call "fini" on exit, and optionally emit a barrier. The CanonicalLoopInfo
/// is invalidated on return.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
    OMPScheduleType DistScheduleSchedType) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit =
      LoopType == WorksharingLoopType::DistributeForStaticLoop
          ? getKmpcDistForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this)
          : getKmpcForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());

  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(Ty: IVTy, V: 0);
  Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
  Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
  Value *UpperBound = Builder.CreateSub(LHS: CLI->getTripCount(), RHS: One);
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);

  OMPScheduleType SchedType =
      (LoopType == WorksharingLoopType::DistributeStaticLoop)
          ? OMPScheduleType::OrderedDistribute
          : OMPScheduleType::UnorderedStatic;
  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
                        PUpperBound, IVTy, PStride, One, Zero, StaticInit,
                        this](Value *SchedulingType, auto &Builder) {
    SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
                                   PLowerBound, PUpperBound});
    // The "dist" init variant additionally takes a distribution upper bound.
    if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
      Value *PDistUpperBound =
          Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
      Args.push_back(Elt: PDistUpperBound);
    }
    Args.append(IL: {PStride, One, Zero});
    createRuntimeFunctionCall(Callee: StaticInit, Args);
  };
  BuildInitCall(SchedulingType, Builder);
  if (HasDistSchedule &&
      LoopType != WorksharingLoopType::DistributeStaticLoop) {
    // NOTE(review): this local shadows the DistScheduleSchedType parameter,
    // which is therefore never read — the second init call always uses
    // OrderedDistribute. Confirm this is intentional.
    Constant *DistScheduleSchedType = ConstantInt::get(
        Ty: I32Type, V: static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
    // We want to emit a second init function call for the dist_schedule clause
    // to the Distribute construct. This should only be done however if a
    // Workshare Loop is nested within a Distribute Construct
    BuildInitCall(DistScheduleSchedType, Builder);
  }
  // The runtime wrote back this thread's bounds; shrink the loop accordingly.
  Value *LowerBound = Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(LHS: InclusiveUpperBound, RHS: LowerBound);
  Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar(Updater: [&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(TheBB: CLI->getBody(),
                           IP: CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(LHS: OldIV, RHS: LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(TheBB: CLI->getExit(),
                         IP: CLI->getExit()->getTerminator()->getIterator());
  createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy BarrierIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
                      Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  // The loop no longer satisfies the canonical invariants (shifted IV,
  // runtime calls), so invalidate the CanonicalLoopInfo.
  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
5550
5551static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
5552 LoopInfo &LI);
5553static void addLoopMetadata(CanonicalLoopInfo *Loop,
5554 ArrayRef<Metadata *> Properties);
5555
5556static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
5557 LLVMContext &Ctx, Loop *Loop,
5558 LoopInfo &LoopInfo,
5559 SmallVector<Metadata *> &LoopMDList) {
5560 SmallSet<BasicBlock *, 8> Reachable;
5561
5562 // Get the basic blocks from the loop in which memref instructions
5563 // can be found.
5564 // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
5565 // preferably without running any passes.
5566 for (BasicBlock *Block : Loop->getBlocks()) {
5567 if (Block == CLI->getCond() || Block == CLI->getHeader())
5568 continue;
5569 Reachable.insert(Ptr: Block);
5570 }
5571
5572 // Add access group metadata to memory-access instructions.
5573 MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
5574 for (BasicBlock *BB : Reachable)
5575 addAccessGroupMetadata(Block: BB, AccessGroup, LI&: LoopInfo);
5576 // TODO: If the loop has existing parallel access metadata, have
5577 // to combine two lists.
5578 LoopMDList.push_back(Elt: MDNode::get(
5579 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
5580}
5581
/// Lower a canonical loop to a statically-chunked OpenMP workshare loop: the
/// runtime "init" call yields the first chunk's bounds and the inter-chunk
/// stride, an outer "dispatch" loop enumerates the chunks, and the original
/// loop becomes the inner per-chunk loop. Returns the insertion point after
/// the dispatch loop.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  // Chunk bookkeeping is done in i32 or i64 to match the runtime entry points.
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(C&: Ctx)
                                                        : Type::getInt64Ty(C&: Ctx);
  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Constant *Zero = ConstantInt::get(Ty: InternalIVTy, V: 0);
  Constant *One = ConstantInt::get(Ty: InternalIVTy, V: 1);

  Function *F = CLI->getFunction();
  // Blocks must have terminators.
  // FIXME: Don't run analyses on incomplete/invalid IR.
  SmallVector<Instruction *> UIs;
  for (BasicBlock &BB : *F)
    if (!BB.getTerminator())
      UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));
  // Run LoopAnalysis on the (temporarily completed) function to find the
  // loop structure needed for the parallel-accesses metadata below.
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
  // Remove the placeholder terminators again.
  for (Instruction *I : UIs)
    I->eraseFromParent();
  Loop *L = LI.getLoopFor(BB: CLI->getHeader());
  SmallVector<Metadata *> LoopMDList;
  if (ChunkSize || DistScheduleChunkSize)
    applyParallelAccessesMetadata(CLI, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
  addLoopMetadata(Loop: CLI, Properties: LoopMDList);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(Ty: InternalIVTy, M, OMPBuilder&: *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(IP: AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
      V: ChunkSize ? ChunkSize : Zero, DestTy: InternalIVTy, Name: "chunksize");
  Value *CastedDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
      V: DistScheduleChunkSize ? DistScheduleChunkSize : Zero, DestTy: InternalIVTy,
      Name: "distschedulechunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(V: OrigTripCount, DestTy: InternalIVTy, Name: "tripcount");

  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
  Constant *DistSchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(DistScheduleSchedType));
  // "init" expects an inclusive upper bound; clamp it to zero for an empty
  // loop to avoid wrapping below zero.
  Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(LHS: CastedTripCount, RHS: One);
  Value *IsTripCountZero = Builder.CreateICmpEQ(LHS: CastedTripCount, RHS: Zero);
  Value *UpperBound =
      Builder.CreateSelect(C: IsTripCountZero, True: Zero, False: OrigUpperBound);
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
  auto BuildInitCall = [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound,
                        PUpperBound, PStride, One,
                        this](Value *SchedulingType, Value *ChunkSize,
                              auto &Builder) {
    createRuntimeFunctionCall(
        Callee: StaticInit, Args: {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                           /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                           /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                           /*pstride=*/PStride, /*incr=*/One,
                           /*chunk=*/ChunkSize});
  };
  BuildInitCall(SchedulingType, CastedChunkSize, Builder);
  if (DistScheduleSchedType != OMPScheduleType::None &&
      SchedType != OMPScheduleType::OrderedDistributeChunked &&
      SchedType != OMPScheduleType::OrderedDistribute) {
    // We want to emit a second init function call for the dist_schedule clause
    // to the Distribute construct. This should only be done however if a
    // Workshare Loop is nested within a Distribute Construct
    BuildInitCall(DistSchedulingType, CastedDistScheduleChunkSize, Builder);
  }

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PLowerBound, Name: "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PUpperBound, Name: "omp_firstchunk.ub");
  // "init" upper bound is inclusive; chunk range is an exclusive span.
  Value *FirstChunkEnd = Builder.CreateAdd(LHS: FirstChunkStop, RHS: One);
  Value *ChunkRange =
      Builder.CreateSub(LHS: FirstChunkEnd, RHS: FirstChunkStart, Name: "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PStride, Name: "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, CreateBranch: true);
  Value *DispatchCounter;

  // It is safe to assume this didn't return an error because the callback
  // passed into createCanonicalLoop is the only possible error source, and it
  // always returns success.
  CanonicalLoopInfo *DispatchCLI = cantFail(ValOrErr: createCanonicalLoop(
      Loc: {Builder.saveIP(), DL},
      BodyGenCB: [&](InsertPointTy BodyIP, Value *Counter) {
        DispatchCounter = Counter;
        return Error::success();
      },
      Start: FirstChunkStart, Stop: CastedTripCount, Step: NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      Name: "dispatch"));

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
  // not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(Source: DispatchAfter, Target: CLI->getAfter(), DL);
  redirectTo(Source: CLI->getExit(), Target: DispatchLatch, DL);
  redirectTo(Source: DispatchBody, Target: DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(LHS: DispatchCounter, RHS: ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(LHS: ChunkEnd, RHS: CastedTripCount, Name: "omp_chunk.is_last");
  // The final chunk may be shorter than the regular chunk range.
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(LHS: CastedTripCount, RHS: DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      C: IsLastChunk, True: CountUntilOrigTripCount, False: ChunkRange, Name: "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(V: ChunkTripCount, DestTy: IVTy, Name: "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(V: DispatchCounter, DestTy: IVTy, Name: "omp_dispatch.iv.trunc");
  CLI->mapIndVar(Updater: [&](Instruction *) -> Value * {
    Builder.restoreIP(IP: CLI->getBodyIP());
    return Builder.CreateAdd(LHS: IV, RHS: BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(TheBB: DispatchExit, IP: DispatchExit->getFirstInsertionPt());
  createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL), Kind: OMPD_for,
                      /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
    if (!AfterIP)
      return AfterIP.takeError();
  }

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
}
5780
// Returns an LLVM function to call for executing an OpenMP static worksharing
// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
                            WorksharingLoopType LoopType) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  Module &M = OMPBuilder->M;
  // Pick the runtime entry point by loop kind first, then by bitwidth; each
  // case falls through to the unreachable handling below on an unsupported
  // bitwidth.
  switch (LoopType) {
  case WorksharingLoopType::ForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
    break;
  }
  // Reaching here via a `break` means the bitwidth check failed; reaching it
  // with a valid bitwidth means LoopType held an out-of-range value.
  if (Bitwidth != 32 && Bitwidth != 64) {
    llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
  }
  llvm_unreachable("Unknown type of OpenMP worksharing loop");
}
5820
// Inserts a call to proper OpenMP Device RTL function which handles
// loop worksharing. The RTL entry point is selected from `LoopType` and the
// bitwidth of `TripCount`; the call is emitted right before the terminator of
// `InsertBlock`.
static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
                                          WorksharingLoopType LoopType,
                                          BasicBlock *InsertBlock, Value *Ident,
                                          Value *LoopBodyArg, Value *TripCount,
                                          Function &LoopBodyFn, bool NoLoop) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(Ty: TripCountTy, OMPBuilder, LoopType);
  // Arguments common to all variants: source-location ident, the outlined
  // loop body function, its argument struct, and the trip count.
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Elt: Ident);
  RealArgs.push_back(Elt: &LoopBodyFn);
  RealArgs.push_back(Elt: LoopBodyArg);
  RealArgs.push_back(Elt: TripCount);
  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
    // Distribute-only loops do not query the thread count; append the
    // remaining (zero) arguments and emit the call before the terminator.
    RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
    RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
    Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
    OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
    return;
  }
  // The "for" variants additionally receive the number of threads in the
  // team, queried from the OpenMP runtime at this point.
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, FnID: omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
  Value *NumThreads = OMPBuilder->createRuntimeFunctionCall(Callee: RTLNumThreads, Args: {});

  // omp_get_num_threads returns i32; widen/narrow it to the trip-count type.
  RealArgs.push_back(
      Elt: Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: TripCountTy, Name: "num.threads.cast"));
  RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
    // Combined distribute-for takes one extra argument and forwards the
    // NoLoop hint to the runtime.
    RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
    RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: NoLoop));
  } else {
    RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
  }

  OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
}
5862
/// Post-outlining callback for device workshare loops: after the loop body
/// has been extracted into \p OutlinedFn, hoists the body-argument setup into
/// the preheader, deletes the now-dead canonical loop control flow, and
/// replaces the direct body call with a call to the device RTL worksharing
/// entry point that drives \p OutlinedFn itself.
static void workshareLoopTargetCallback(
    OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
    Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
    WorksharingLoopType LoopType, bool NoLoop) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup
  // of the loop body argument structure and the call to the outlined
  // loop body function. Firstly, we need to move setup of loop body args
  // into the loop preheader.
  Preheader->splice(ToIt: std::prev(x: Preheader->end()), FromBB: CLI->getBody(),
                    FromBeginIt: CLI->getBody()->begin(), FromEndIt: std::prev(x: CLI->getBody()->end()));

  // The next step is to remove the whole loop. We do not need it anymore.
  // That's why make an unconditional branch from loop preheader to loop
  // exit block
  Builder.restoreIP(IP: {Preheader, Preheader->end()});
  Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(Dest: CLI->getExit());

  // Delete dead loop blocks. OutlineInfo::collectBlocks is reused here purely
  // to gather the region between the (old) header and exit.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(BlockSet&: RegionBlockSet, BlockVector&: BlocksToBeRemoved);
  DeleteDeadBlocks(BBs: BlocksToBeRemoved);

  // Find the instruction which corresponds to loop body argument structure
  // and remove the call to loop body function instruction.
  Value *LoopBodyArg;
  // The only remaining (undroppable) user of the outlined function is the
  // call that was spliced into the preheader above.
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(Val: OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(i: 1);
  else
    LoopBodyArg = Constant::getNullValue(Ty: Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  // Emit the device RTL call that runs the outlined body across the team.
  createTargetLoopWorkshareCall(OMPBuilder: OMPIRBuilder, LoopType, InsertBlock: Preheader, Ident,
                                LoopBodyArg, TripCount, LoopBodyFn&: OutlinedFn, NoLoop);

  // Drop helper instructions queued by the caller, then invalidate the CLI
  // since the canonical loop structure no longer exists.
  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
5919
5920OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
5921 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5922 WorksharingLoopType LoopType, bool NoLoop) {
5923 uint32_t SrcLocStrSize;
5924 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
5925 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5926
5927 OutlineInfo OI;
5928 OI.OuterAllocaBB = CLI->getPreheader();
5929 Function *OuterFn = CLI->getPreheader()->getParent();
5930
5931 // Instructions which need to be deleted at the end of code generation
5932 SmallVector<Instruction *, 4> ToBeDeleted;
5933
5934 OI.OuterAllocaBB = AllocaIP.getBlock();
5935
5936 // Mark the body loop as region which needs to be extracted
5937 OI.EntryBB = CLI->getBody();
5938 OI.ExitBB = CLI->getLatch()->splitBasicBlockBefore(I: CLI->getLatch()->begin(),
5939 BBName: "omp.prelatch");
5940
5941 // Prepare loop body for extraction
5942 Builder.restoreIP(IP: {CLI->getPreheader(), CLI->getPreheader()->begin()});
5943
5944 // Insert new loop counter variable which will be used only in loop
5945 // body.
5946 AllocaInst *NewLoopCnt = Builder.CreateAlloca(Ty: CLI->getIndVarType(), ArraySize: 0, Name: "");
5947 Instruction *NewLoopCntLoad =
5948 Builder.CreateLoad(Ty: CLI->getIndVarType(), Ptr: NewLoopCnt);
5949 // New loop counter instructions are redundant in the loop preheader when
5950 // code generation for workshare loop is finshed. That's why mark them as
5951 // ready for deletion.
5952 ToBeDeleted.push_back(Elt: NewLoopCntLoad);
5953 ToBeDeleted.push_back(Elt: NewLoopCnt);
5954
5955 // Analyse loop body region. Find all input variables which are used inside
5956 // loop body region.
5957 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
5958 SmallVector<BasicBlock *, 32> Blocks;
5959 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
5960
5961 CodeExtractorAnalysisCache CEAC(*OuterFn);
5962 CodeExtractor Extractor(Blocks,
5963 /* DominatorTree */ nullptr,
5964 /* AggregateArgs */ true,
5965 /* BlockFrequencyInfo */ nullptr,
5966 /* BranchProbabilityInfo */ nullptr,
5967 /* AssumptionCache */ nullptr,
5968 /* AllowVarArgs */ true,
5969 /* AllowAlloca */ true,
5970 /* AllocationBlock */ CLI->getPreheader(),
5971 /* Suffix */ ".omp_wsloop",
5972 /* AggrArgsIn0AddrSpace */ true);
5973
5974 BasicBlock *CommonExit = nullptr;
5975 SetVector<Value *> SinkingCands, HoistingCands;
5976
5977 // Find allocas outside the loop body region which are used inside loop
5978 // body
5979 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
5980
5981 // We need to model loop body region as the function f(cnt, loop_arg).
5982 // That's why we replace loop induction variable by the new counter
5983 // which will be one of loop body function argument
5984 SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
5985 CLI->getIndVar()->user_end());
5986 for (auto Use : Users) {
5987 if (Instruction *Inst = dyn_cast<Instruction>(Val: Use)) {
5988 if (ParallelRegionBlockSet.count(Ptr: Inst->getParent())) {
5989 Inst->replaceUsesOfWith(From: CLI->getIndVar(), To: NewLoopCntLoad);
5990 }
5991 }
5992 }
5993 // Make sure that loop counter variable is not merged into loop body
5994 // function argument structure and it is passed as separate variable
5995 OI.ExcludeArgsFromAggregate.push_back(Elt: NewLoopCntLoad);
5996
5997 // PostOutline CB is invoked when loop body function is outlined and
5998 // loop body is replaced by call to outlined function. We need to add
5999 // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
6000 // function will handle loop control logic.
6001 //
6002 OI.PostOutlineCB = [=, ToBeDeletedVec =
6003 std::move(ToBeDeleted)](Function &OutlinedFn) {
6004 workshareLoopTargetCallback(OMPIRBuilder: this, CLI, Ident, OutlinedFn, ToBeDeleted: ToBeDeletedVec,
6005 LoopType, NoLoop);
6006 };
6007 addOutlineInfo(OI: std::move(OI));
6008 return CLI->getAfterIP();
6009}
6010
/// Dispatch a canonical loop to the matching worksharing lowering strategy
/// (target-device RTL, static, static-chunked, or dynamic dispatch) based on
/// the effective schedule computed from the given clauses.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
    bool HasSimdModifier, bool HasMonotonicModifier,
    bool HasNonmonotonicModifier, bool HasOrderedClause,
    WorksharingLoopType LoopType, bool NoLoop, bool HasDistSchedule,
    Value *DistScheduleChunkSize) {
  // On the device, worksharing is handled entirely by the device RTL.
  if (Config.isTargetDevice())
    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop);
  // Fold the schedule clause and its modifiers into one effective type.
  OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
      ClauseKind: SchedKind, HasChunks: ChunkSize, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause, HasDistScheduleChunks: DistScheduleChunkSize);

  bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                   OMPScheduleType::ModifierOrdered;
  // dist_schedule maps to an ordered distribute schedule; chunked iff a
  // dist_schedule chunk size was supplied.
  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
  if (HasDistSchedule) {
    DistScheduleSchedType = DistScheduleChunkSize
                                ? OMPScheduleType::OrderedDistributeChunked
                                : OMPScheduleType::OrderedDistribute;
  }
  switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
  case OMPScheduleType::BaseStatic:
  case OMPScheduleType::BaseDistribute:
    assert((!ChunkSize || !DistScheduleChunkSize) &&
           "No chunk size with static-chunked schedule");
    // Ordered static (without dist_schedule) is lowered via the dynamic
    // dispatch runtime so the ordering protocol is honored.
    if (IsOrdered && !HasDistSchedule)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
                                       NeedsBarrier, Chunk: ChunkSize);
    // FIXME: Monotonicity ignored?
    if (DistScheduleChunkSize)
      return applyStaticChunkedWorkshareLoop(
          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
          DistScheduleChunkSize, DistScheduleSchedType);
    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
                                    HasDistSchedule);

  case OMPScheduleType::BaseStaticChunked:
  case OMPScheduleType::BaseDistributeChunked:
    if (IsOrdered && !HasDistSchedule)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
                                       NeedsBarrier, Chunk: ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticChunkedWorkshareLoop(
        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
        DistScheduleChunkSize, DistScheduleSchedType);

  // All remaining schedules are lowered through the dynamic dispatch
  // init/next/fini runtime calls.
  case OMPScheduleType::BaseRuntime:
  case OMPScheduleType::BaseAuto:
  case OMPScheduleType::BaseGreedy:
  case OMPScheduleType::BaseBalanced:
  case OMPScheduleType::BaseSteal:
  case OMPScheduleType::BaseRuntimeSimd:
    assert(!ChunkSize &&
           "schedule type does not support user-defined chunk sizes");
    [[fallthrough]];
  case OMPScheduleType::BaseGuidedSimd:
  case OMPScheduleType::BaseDynamicChunked:
  case OMPScheduleType::BaseGuidedChunked:
  case OMPScheduleType::BaseGuidedIterativeChunked:
  case OMPScheduleType::BaseGuidedAnalyticalChunked:
  case OMPScheduleType::BaseStaticBalancedChunked:
    return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
                                     NeedsBarrier, Chunk: ChunkSize);

  default:
    llvm_unreachable("Unknown/unimplemented schedule kind");
  }
}
6080
6081/// Returns an LLVM function to call for initializing loop bounds using OpenMP
6082/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6083/// the runtime. Always interpret integers as unsigned similarly to
6084/// CanonicalLoopInfo.
6085static FunctionCallee
6086getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6087 unsigned Bitwidth = Ty->getIntegerBitWidth();
6088 if (Bitwidth == 32)
6089 return OMPBuilder.getOrCreateRuntimeFunction(
6090 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
6091 if (Bitwidth == 64)
6092 return OMPBuilder.getOrCreateRuntimeFunction(
6093 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
6094 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6095}
6096
6097/// Returns an LLVM function to call for updating the next loop using OpenMP
6098/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6099/// the runtime. Always interpret integers as unsigned similarly to
6100/// CanonicalLoopInfo.
6101static FunctionCallee
6102getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6103 unsigned Bitwidth = Ty->getIntegerBitWidth();
6104 if (Bitwidth == 32)
6105 return OMPBuilder.getOrCreateRuntimeFunction(
6106 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
6107 if (Bitwidth == 64)
6108 return OMPBuilder.getOrCreateRuntimeFunction(
6109 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
6110 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6111}
6112
6113/// Returns an LLVM function to call for finalizing the dynamic loop using
6114/// depending on `type`. Only i32 and i64 are supported by the runtime. Always
6115/// interpret integers as unsigned similarly to CanonicalLoopInfo.
6116static FunctionCallee
6117getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6118 unsigned Bitwidth = Ty->getIntegerBitWidth();
6119 if (Bitwidth == 32)
6120 return OMPBuilder.getOrCreateRuntimeFunction(
6121 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
6122 if (Bitwidth == 64)
6123 return OMPBuilder.getOrCreateRuntimeFunction(
6124 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
6125 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6126}
6127
/// Lower a canonical loop to a dynamically scheduled worksharing loop: wraps
/// the loop in an outer dispatch loop that repeatedly calls
/// __kmpc_dispatch_next for a new [lb, ub] range until the runtime reports no
/// more work. The CLI is invalidated since the loop is no longer canonical.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                           InsertPointTy AllocaIP,
                                           OMPScheduleType SchedType,
                                           bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(Ty: IVTy, M, OMPBuilder&: *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(Ty: IVTy, M, OMPBuilder&: *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
  // Record the last-iteration flag alloca on the CLI for later consumers.
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
  Builder.CreateStore(Val: One, Ptr: PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));

  // Call the "init" function.
  createRuntimeFunctionCall(Callee: DynamicInit, Args: {SrcLoc, ThreadNum, SchedulingType,
                                                /* LowerBound */ One, UpperBound,
                                                /* step */ One, Chunk});

  // An outer loop around the existing one: each iteration asks the runtime
  // for the next range of iterations to execute.
  BasicBlock *OuterCond = BasicBlock::Create(
      Context&: PreHeader->getContext(), Name: Twine(PreHeader->getName()) + ".outer.cond",
      Parent: PreHeader->getParent());
  // This needs to be 32-bit always, so can't use the IVTy Zero above.
  // (dispatch_next returns an i32 "more work" flag.)
  Builder.SetInsertPoint(TheBB: OuterCond, IP: OuterCond->getFirstInsertionPt());
  Value *Res = createRuntimeFunctionCall(
      Callee: DynamicNext,
      Args: {SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound, PStride});
  Constant *Zero32 = ConstantInt::get(Ty: I32Type, V: 0);
  Value *MoreWork = Builder.CreateCmp(Pred: CmpInst::ICMP_NE, LHS: Res, RHS: Zero32);
  // The runtime works on 1-based inclusive bounds (lb=1 stored above);
  // subtract one to map back into the 0-based IV space.
  Value *LowerBound =
      Builder.CreateSub(LHS: Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound), RHS: One, Name: "lb");
  Builder.CreateCondBr(Cond: MoreWork, True: Header, False: Exit);

  // Change PHI-node in loop header to use outer cond rather than preheader,
  // and set IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Val: Phi);
  PI->setIncomingBlock(i: 0, BB: OuterCond);
  PI->setIncomingValue(i: 0, V: LowerBound);

  // Then set the pre-header to jump to the OuterCond
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Val: Term);
  Br->setSuccessor(Idx: 0, BB: OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * jump to the loop outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(TheBB: Cond, IP: Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound, Name: "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Val: Comp);
  CI->setOperand(i_nocapture: 1, Val_nocapture: UpperBound);
  // Redirect the inner exit to branch to outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Val: Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(Idx: 1, BB: OuterCond);

  // Call the "fini" function if "ordered" is present in wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(Ty: IVTy, M, OMPBuilder&: *this);
    createRuntimeFunctionCall(Callee: DynamicFini, Args: {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    InsertPointOrErrorTy BarrierIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
                      Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  CLI->invalidate();
  return AfterIP;
}
6260
6261/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
6262/// after this \p OldTarget will be orphaned.
6263static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
6264 BasicBlock *NewTarget, DebugLoc DL) {
6265 for (BasicBlock *Pred : make_early_inc_range(Range: predecessors(BB: OldTarget)))
6266 redirectTo(Source: Pred, Target: NewTarget, DL);
6267}
6268
6269/// Determine which blocks in \p BBs are reachable from outside and remove the
6270/// ones that are not reachable from the function.
6271static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
6272 SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
6273 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
6274 for (Use &U : BB->uses()) {
6275 auto *UseInst = dyn_cast<Instruction>(Val: U.getUser());
6276 if (!UseInst)
6277 continue;
6278 if (BBsToErase.count(Ptr: UseInst->getParent()))
6279 continue;
6280 return true;
6281 }
6282 return false;
6283 };
6284
6285 while (BBsToErase.remove_if(P: HasRemainingUses)) {
6286 // Try again if anything was removed.
6287 }
6288
6289 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
6290 DeleteDeadBlocks(BBs: BBVec);
6291}
6292
/// Collapse a perfect (or sunk-imperfect) loop nest into a single canonical
/// loop whose trip count is the product of the input trip counts; the
/// original induction variables are rederived via div/mod and the old loop
/// control flow is deleted. Returns the new CanonicalLoopInfo; all input
/// loops are invalidated.
CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(N: 6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(BBs&: OldControlBBs);

  // Setup the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(IP: ComputeIP);
  else
    Builder.restoreIP(IP: Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count as the product of all inputs.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount =
        Builder.CreateNUWMul(LHS: CollapsedTripCount, RHS: OrigTripCount);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, TripCount: CollapsedTripCount, F,
                         PreInsertBefore: OrigPreheader->getNextNode(), PostInsertBefore: OrigAfter, Name: "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(IP: Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(N: NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(LHS: Leftover, RHS: OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(LHS: Leftover, RHS: OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following in direction of
  // the control flow, from the leading in-between code, the loop nest body, the
  // trailing in-between code, and rejoining the collapsed loop's latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
  // the ContinueBlock is set, continue with that block. If ContinuePred, use
  // its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(Source: ContinueBlock, Target: Dest, DL);
    else
      redirectAllPredecessorsTo(OldTarget: ContinuePred, NewTarget: Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than the original loop. More sophisticated schemes could keep track of what
  // the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Source: Outermost->getPreheader(), Target: Result->getPreheader(), DL);
  redirectTo(Source: Result->getAfter(), Target: Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(V: NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(BBs: OldControlBBs);

  // The original loops no longer exist as canonical loops.
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
6420
6421std::vector<CanonicalLoopInfo *>
6422OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6423 ArrayRef<Value *> TileSizes) {
6424 assert(TileSizes.size() == Loops.size() &&
6425 "Must pass as many tile sizes as there are loops");
6426 int NumLoops = Loops.size();
6427 assert(NumLoops >= 1 && "At least one loop to tile required");
6428
6429 CanonicalLoopInfo *OutermostLoop = Loops.front();
6430 CanonicalLoopInfo *InnermostLoop = Loops.back();
6431 Function *F = OutermostLoop->getBody()->getParent();
6432 BasicBlock *InnerEnter = InnermostLoop->getBody();
6433 BasicBlock *InnerLatch = InnermostLoop->getLatch();
6434
6435 // Loop control blocks that may become orphaned later.
6436 SmallVector<BasicBlock *, 12> OldControlBBs;
6437 OldControlBBs.reserve(N: 6 * Loops.size());
6438 for (CanonicalLoopInfo *Loop : Loops)
6439 Loop->collectControlBlocks(BBs&: OldControlBBs);
6440
6441 // Collect original trip counts and induction variable to be accessible by
6442 // index. Also, the structure of the original loops is not preserved during
6443 // the construction of the tiled loops, so do it before we scavenge the BBs of
6444 // any original CanonicalLoopInfo.
6445 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
6446 for (CanonicalLoopInfo *L : Loops) {
6447 assert(L->isValid() && "All input loops must be valid canonical loops");
6448 OrigTripCounts.push_back(Elt: L->getTripCount());
6449 OrigIndVars.push_back(Elt: L->getIndVar());
6450 }
6451
6452 // Collect the code between loop headers. These may contain SSA definitions
6453 // that are used in the loop nest body. To be usable with in the innermost
6454 // body, these BasicBlocks will be sunk into the loop nest body. That is,
6455 // these instructions may be executed more often than before the tiling.
6456 // TODO: It would be sufficient to only sink them into body of the
6457 // corresponding tile loop.
6458 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
6459 for (int i = 0; i < NumLoops - 1; ++i) {
6460 CanonicalLoopInfo *Surrounding = Loops[i];
6461 CanonicalLoopInfo *Nested = Loops[i + 1];
6462
6463 BasicBlock *EnterBB = Surrounding->getBody();
6464 BasicBlock *ExitBB = Nested->getHeader();
6465 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
6466 }
6467
6468 // Compute the trip counts of the floor loops.
6469 Builder.SetCurrentDebugLocation(DL);
6470 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
6471 SmallVector<Value *, 4> FloorCompleteCount, FloorCount, FloorRems;
6472 for (int i = 0; i < NumLoops; ++i) {
6473 Value *TileSize = TileSizes[i];
6474 Value *OrigTripCount = OrigTripCounts[i];
6475 Type *IVType = OrigTripCount->getType();
6476
6477 Value *FloorCompleteTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
6478 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
6479
6480 // 0 if tripcount divides the tilesize, 1 otherwise.
6481 // 1 means we need an additional iteration for a partial tile.
6482 //
6483 // Unfortunately we cannot just use the roundup-formula
6484 // (tripcount + tilesize - 1)/tilesize
6485 // because the summation might overflow. We do not want introduce undefined
6486 // behavior when the untiled loop nest did not.
6487 Value *FloorTripOverflow =
6488 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
6489
6490 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
6491 Value *FloorTripCount =
6492 Builder.CreateAdd(LHS: FloorCompleteTripCount, RHS: FloorTripOverflow,
6493 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
6494
6495 // Remember some values for later use.
6496 FloorCompleteCount.push_back(Elt: FloorCompleteTripCount);
6497 FloorCount.push_back(Elt: FloorTripCount);
6498 FloorRems.push_back(Elt: FloorTripRem);
6499 }
6500
6501 // Generate the new loop nest, from the outermost to the innermost.
6502 std::vector<CanonicalLoopInfo *> Result;
6503 Result.reserve(n: NumLoops * 2);
6504
6505 // The basic block of the surrounding loop that enters the nest generated
6506 // loop.
6507 BasicBlock *Enter = OutermostLoop->getPreheader();
6508
6509 // The basic block of the surrounding loop where the inner code should
6510 // continue.
6511 BasicBlock *Continue = OutermostLoop->getAfter();
6512
6513 // Where the next loop basic block should be inserted.
6514 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
6515
6516 auto EmbeddNewLoop =
6517 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
6518 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
6519 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
6520 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
6521 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
6522 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
6523
6524 // Setup the position where the next embedded loop connects to this loop.
6525 Enter = EmbeddedLoop->getBody();
6526 Continue = EmbeddedLoop->getLatch();
6527 OutroInsertBefore = EmbeddedLoop->getLatch();
6528 return EmbeddedLoop;
6529 };
6530
6531 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
6532 const Twine &NameBase) {
6533 for (auto P : enumerate(First&: TripCounts)) {
6534 CanonicalLoopInfo *EmbeddedLoop =
6535 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
6536 Result.push_back(x: EmbeddedLoop);
6537 }
6538 };
6539
6540 EmbeddNewLoops(FloorCount, "floor");
6541
6542 // Within the innermost floor loop, emit the code that computes the tile
6543 // sizes.
6544 Builder.SetInsertPoint(Enter->getTerminator());
6545 SmallVector<Value *, 4> TileCounts;
6546 for (int i = 0; i < NumLoops; ++i) {
6547 CanonicalLoopInfo *FloorLoop = Result[i];
6548 Value *TileSize = TileSizes[i];
6549
6550 Value *FloorIsEpilogue =
6551 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorCompleteCount[i]);
6552 Value *TileTripCount =
6553 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
6554
6555 TileCounts.push_back(Elt: TileTripCount);
6556 }
6557
6558 // Create the tile loops.
6559 EmbeddNewLoops(TileCounts, "tile");
6560
6561 // Insert the inbetween code into the body.
6562 BasicBlock *BodyEnter = Enter;
6563 BasicBlock *BodyEntered = nullptr;
6564 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
6565 BasicBlock *EnterBB = P.first;
6566 BasicBlock *ExitBB = P.second;
6567
6568 if (BodyEnter)
6569 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
6570 else
6571 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
6572
6573 BodyEnter = nullptr;
6574 BodyEntered = ExitBB;
6575 }
6576
6577 // Append the original loop nest body into the generated loop nest body.
6578 if (BodyEnter)
6579 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
6580 else
6581 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
6582 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
6583
6584 // Replace the original induction variable with an induction variable computed
6585 // from the tile and floor induction variables.
6586 Builder.restoreIP(IP: Result.back()->getBodyIP());
6587 for (int i = 0; i < NumLoops; ++i) {
6588 CanonicalLoopInfo *FloorLoop = Result[i];
6589 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
6590 Value *OrigIndVar = OrigIndVars[i];
6591 Value *Size = TileSizes[i];
6592
6593 Value *Scale =
6594 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6595 Value *Shift =
6596 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6597 OrigIndVar->replaceAllUsesWith(V: Shift);
6598 }
6599
6600 // Remove unused parts of the original loops.
6601 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6602
6603 for (CanonicalLoopInfo *L : Loops)
6604 L->invalidate();
6605
6606#ifndef NDEBUG
6607 for (CanonicalLoopInfo *GenL : Result)
6608 GenL->assertOK();
6609#endif
6610 return Result;
6611}
6612
6613/// Attach metadata \p Properties to the basic block described by \p BB. If the
6614/// basic block already has metadata, the basic block properties are appended.
6615static void addBasicBlockMetadata(BasicBlock *BB,
6616 ArrayRef<Metadata *> Properties) {
6617 // Nothing to do if no property to attach.
6618 if (Properties.empty())
6619 return;
6620
6621 LLVMContext &Ctx = BB->getContext();
6622 SmallVector<Metadata *> NewProperties;
6623 NewProperties.push_back(Elt: nullptr);
6624
6625 // If the basic block already has metadata, prepend it to the new metadata.
6626 MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
6627 if (Existing)
6628 append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));
6629
6630 append_range(C&: NewProperties, R&: Properties);
6631 MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
6632 BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);
6633
6634 BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
6635}
6636
6637/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
6638/// loop already has metadata, the loop properties are appended.
6639static void addLoopMetadata(CanonicalLoopInfo *Loop,
6640 ArrayRef<Metadata *> Properties) {
6641 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
6642
6643 // Attach metadata to the loop's latch
6644 BasicBlock *Latch = Loop->getLatch();
6645 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
6646 addBasicBlockMetadata(BB: Latch, Properties);
6647}
6648
6649/// Attach llvm.access.group metadata to the memref instructions of \p Block
6650static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
6651 LoopInfo &LI) {
6652 for (Instruction &I : *Block) {
6653 if (I.mayReadOrWriteMemory()) {
6654 // TODO: This instruction may already have access group from
6655 // other pragmas e.g. #pragma clang loop vectorize. Append
6656 // so that the existing metadata is not overwritten.
6657 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
6658 }
6659 }
6660}
6661
6662CanonicalLoopInfo *
6663OpenMPIRBuilder::fuseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops) {
6664 CanonicalLoopInfo *firstLoop = Loops.front();
6665 CanonicalLoopInfo *lastLoop = Loops.back();
6666 Function *F = firstLoop->getPreheader()->getParent();
6667
6668 // Loop control blocks that will become orphaned later
6669 SmallVector<BasicBlock *> oldControlBBs;
6670 for (CanonicalLoopInfo *Loop : Loops)
6671 Loop->collectControlBlocks(BBs&: oldControlBBs);
6672
6673 // Collect original trip counts
6674 SmallVector<Value *> origTripCounts;
6675 for (CanonicalLoopInfo *L : Loops) {
6676 assert(L->isValid() && "All input loops must be valid canonical loops");
6677 origTripCounts.push_back(Elt: L->getTripCount());
6678 }
6679
6680 Builder.SetCurrentDebugLocation(DL);
6681
6682 // Compute max trip count.
6683 // The fused loop will be from 0 to max(origTripCounts)
6684 BasicBlock *TCBlock = BasicBlock::Create(Context&: F->getContext(), Name: "omp.fuse.comp.tc",
6685 Parent: F, InsertBefore: firstLoop->getHeader());
6686 Builder.SetInsertPoint(TCBlock);
6687 Value *fusedTripCount = nullptr;
6688 for (CanonicalLoopInfo *L : Loops) {
6689 assert(L->isValid() && "All loops to fuse must be valid canonical loops");
6690 Value *origTripCount = L->getTripCount();
6691 if (!fusedTripCount) {
6692 fusedTripCount = origTripCount;
6693 continue;
6694 }
6695 Value *condTP = Builder.CreateICmpSGT(LHS: fusedTripCount, RHS: origTripCount);
6696 fusedTripCount = Builder.CreateSelect(C: condTP, True: fusedTripCount, False: origTripCount,
6697 Name: ".omp.fuse.tc");
6698 }
6699
6700 // Generate new loop
6701 CanonicalLoopInfo *fused =
6702 createLoopSkeleton(DL, TripCount: fusedTripCount, F, PreInsertBefore: firstLoop->getBody(),
6703 PostInsertBefore: lastLoop->getLatch(), Name: "fused");
6704
6705 // Replace original loops with the fused loop
6706 // Preheader and After are not considered inside the CLI.
6707 // These are used to compute the individual TCs of the loops
6708 // so they have to be put before the resulting fused loop.
6709 // Moving them up for readability.
6710 for (size_t i = 0; i < Loops.size() - 1; ++i) {
6711 Loops[i]->getPreheader()->moveBefore(MovePos: TCBlock);
6712 Loops[i]->getAfter()->moveBefore(MovePos: TCBlock);
6713 }
6714 lastLoop->getPreheader()->moveBefore(MovePos: TCBlock);
6715
6716 for (size_t i = 0; i < Loops.size() - 1; ++i) {
6717 redirectTo(Source: Loops[i]->getPreheader(), Target: Loops[i]->getAfter(), DL);
6718 redirectTo(Source: Loops[i]->getAfter(), Target: Loops[i + 1]->getPreheader(), DL);
6719 }
6720 redirectTo(Source: lastLoop->getPreheader(), Target: TCBlock, DL);
6721 redirectTo(Source: TCBlock, Target: fused->getPreheader(), DL);
6722 redirectTo(Source: fused->getAfter(), Target: lastLoop->getAfter(), DL);
6723
6724 // Build the fused body
6725 // Create new Blocks with conditions that jump to the original loop bodies
6726 SmallVector<BasicBlock *> condBBs;
6727 SmallVector<Value *> condValues;
6728 for (size_t i = 0; i < Loops.size(); ++i) {
6729 BasicBlock *condBlock = BasicBlock::Create(
6730 Context&: F->getContext(), Name: "omp.fused.inner.cond", Parent: F, InsertBefore: Loops[i]->getBody());
6731 Builder.SetInsertPoint(condBlock);
6732 Value *condValue =
6733 Builder.CreateICmpSLT(LHS: fused->getIndVar(), RHS: origTripCounts[i]);
6734 condBBs.push_back(Elt: condBlock);
6735 condValues.push_back(Elt: condValue);
6736 }
6737 // Join the condition blocks with the bodies of the original loops
6738 redirectTo(Source: fused->getBody(), Target: condBBs[0], DL);
6739 for (size_t i = 0; i < Loops.size() - 1; ++i) {
6740 Builder.SetInsertPoint(condBBs[i]);
6741 Builder.CreateCondBr(Cond: condValues[i], True: Loops[i]->getBody(), False: condBBs[i + 1]);
6742 redirectAllPredecessorsTo(OldTarget: Loops[i]->getLatch(), NewTarget: condBBs[i + 1], DL);
6743 // Replace the IV with the fused IV
6744 Loops[i]->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());
6745 }
6746 // Last body jumps to the created end body block
6747 Builder.SetInsertPoint(condBBs.back());
6748 Builder.CreateCondBr(Cond: condValues.back(), True: lastLoop->getBody(),
6749 False: fused->getLatch());
6750 redirectAllPredecessorsTo(OldTarget: lastLoop->getLatch(), NewTarget: fused->getLatch(), DL);
6751 // Replace the IV with the fused IV
6752 lastLoop->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());
6753
6754 // The loop latch must have only one predecessor. Currently it is branched to
6755 // from both the last condition block and the last loop body
6756 fused->getLatch()->splitBasicBlockBefore(I: fused->getLatch()->begin(),
6757 BBName: "omp.fused.pre_latch");
6758
6759 // Remove unused parts
6760 removeUnusedBlocksFromParent(BBs: oldControlBBs);
6761
6762 // Invalidate old CLIs
6763 for (CanonicalLoopInfo *L : Loops)
6764 L->invalidate();
6765
6766#ifndef NDEBUG
6767 fused->assertOK();
6768#endif
6769 return fused;
6770}
6771
6772void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
6773 LLVMContext &Ctx = Builder.getContext();
6774 addLoopMetadata(
6775 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6776 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
6777}
6778
6779void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
6780 LLVMContext &Ctx = Builder.getContext();
6781 addLoopMetadata(
6782 Loop, Properties: {
6783 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6784 });
6785}
6786
6787void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
6788 Value *IfCond, ValueToValueMapTy &VMap,
6789 LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
6790 const Twine &NamePrefix) {
6791 Function *F = CanonicalLoop->getFunction();
6792
6793 // We can't do
6794 // if (cond) {
6795 // simd_loop;
6796 // } else {
6797 // non_simd_loop;
6798 // }
6799 // because then the CanonicalLoopInfo would only point to one of the loops:
6800 // leading to other constructs operating on the same loop to malfunction.
6801 // Instead generate
6802 // while (...) {
6803 // if (cond) {
6804 // simd_body;
6805 // } else {
6806 // not_simd_body;
6807 // }
6808 // }
6809 // At least for simple loops, LLVM seems able to hoist the if out of the loop
6810 // body at -O3
6811
6812 // Define where if branch should be inserted
6813 auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();
6814
6815 // Create additional blocks for the if statement
6816 BasicBlock *Cond = SplitBeforeIt->getParent();
6817 llvm::LLVMContext &C = Cond->getContext();
6818 llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
6819 Context&: C, Name: NamePrefix + ".if.then", Parent: Cond->getParent(), InsertBefore: Cond->getNextNode());
6820 llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
6821 Context&: C, Name: NamePrefix + ".if.else", Parent: Cond->getParent(), InsertBefore: CanonicalLoop->getExit());
6822
6823 // Create if condition branch.
6824 Builder.SetInsertPoint(SplitBeforeIt);
6825 Instruction *BrInstr =
6826 Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
6827 InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
6828 // Then block contains branch to omp loop body which needs to be vectorized
6829 spliceBB(IP, New: ThenBlock, CreateBranch: false, DL: Builder.getCurrentDebugLocation());
6830 ThenBlock->replaceSuccessorsPhiUsesWith(Old: Cond, New: ThenBlock);
6831
6832 Builder.SetInsertPoint(ElseBlock);
6833
6834 // Clone loop for the else branch
6835 SmallVector<BasicBlock *, 8> NewBlocks;
6836
6837 SmallVector<BasicBlock *, 8> ExistingBlocks;
6838 ExistingBlocks.reserve(N: L->getNumBlocks() + 1);
6839 ExistingBlocks.push_back(Elt: ThenBlock);
6840 ExistingBlocks.append(in_start: L->block_begin(), in_end: L->block_end());
6841 // Cond is the block that has the if clause condition
6842 // LoopCond is omp_loop.cond
6843 // LoopHeader is omp_loop.header
6844 BasicBlock *LoopCond = Cond->getUniquePredecessor();
6845 BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
6846 assert(LoopCond && LoopHeader && "Invalid loop structure");
6847 for (BasicBlock *Block : ExistingBlocks) {
6848 if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
6849 Block == LoopHeader || Block == LoopCond || Block == Cond) {
6850 continue;
6851 }
6852 BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);
6853
6854 // fix name not to be omp.if.then
6855 if (Block == ThenBlock)
6856 NewBB->setName(NamePrefix + ".if.else");
6857
6858 NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
6859 VMap[Block] = NewBB;
6860 NewBlocks.push_back(Elt: NewBB);
6861 }
6862 remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
6863 Builder.CreateBr(Dest: NewBlocks.front());
6864
6865 // The loop latch must have only one predecessor. Currently it is branched to
6866 // from both the 'then' and 'else' branches.
6867 L->getLoopLatch()->splitBasicBlockBefore(I: L->getLoopLatch()->begin(),
6868 BBName: NamePrefix + ".pre_latch");
6869
6870 // Ensure that the then block is added to the loop so we add the attributes in
6871 // the next step
6872 L->addBasicBlockToLoop(NewBB: ThenBlock, LI);
6873}
6874
6875unsigned
6876OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
6877 const StringMap<bool> &Features) {
6878 if (TargetTriple.isX86()) {
6879 if (Features.lookup(Key: "avx512f"))
6880 return 512;
6881 else if (Features.lookup(Key: "avx"))
6882 return 256;
6883 return 128;
6884 }
6885 if (TargetTriple.isPPC())
6886 return 128;
6887 if (TargetTriple.isWasm())
6888 return 128;
6889 return 0;
6890}
6891
6892void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
6893 MapVector<Value *, Value *> AlignedVars,
6894 Value *IfCond, OrderKind Order,
6895 ConstantInt *Simdlen, ConstantInt *Safelen) {
6896 LLVMContext &Ctx = Builder.getContext();
6897
6898 Function *F = CanonicalLoop->getFunction();
6899
6900 // Blocks must have terminators.
6901 // FIXME: Don't run analyses on incomplete/invalid IR.
6902 SmallVector<Instruction *> UIs;
6903 for (BasicBlock &BB : *F)
6904 if (!BB.getTerminator())
6905 UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));
6906
6907 // TODO: We should not rely on pass manager. Currently we use pass manager
6908 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
6909 // object. We should have a method which returns all blocks between
6910 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
6911 FunctionAnalysisManager FAM;
6912 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
6913 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
6914 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
6915
6916 LoopAnalysis LIA;
6917 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
6918
6919 for (Instruction *I : UIs)
6920 I->eraseFromParent();
6921
6922 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
6923 if (AlignedVars.size()) {
6924 InsertPointTy IP = Builder.saveIP();
6925 for (auto &AlignedItem : AlignedVars) {
6926 Value *AlignedPtr = AlignedItem.first;
6927 Value *Alignment = AlignedItem.second;
6928 Instruction *loadInst = dyn_cast<Instruction>(Val: AlignedPtr);
6929 Builder.SetInsertPoint(loadInst->getNextNode());
6930 Builder.CreateAlignmentAssumption(DL: F->getDataLayout(), PtrValue: AlignedPtr,
6931 Alignment);
6932 }
6933 Builder.restoreIP(IP);
6934 }
6935
6936 if (IfCond) {
6937 ValueToValueMapTy VMap;
6938 createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, NamePrefix: "simd");
6939 }
6940
6941 SmallPtrSet<BasicBlock *, 8> Reachable;
6942
6943 // Get the basic blocks from the loop in which memref instructions
6944 // can be found.
6945 // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
6946 // preferably without running any passes.
6947 for (BasicBlock *Block : L->getBlocks()) {
6948 if (Block == CanonicalLoop->getCond() ||
6949 Block == CanonicalLoop->getHeader())
6950 continue;
6951 Reachable.insert(Ptr: Block);
6952 }
6953
6954 SmallVector<Metadata *> LoopMDList;
6955
6956 // In presence of finite 'safelen', it may be unsafe to mark all
6957 // the memory instructions parallel, because loop-carried
6958 // dependences of 'safelen' iterations are possible.
6959 // If clause order(concurrent) is specified then the memory instructions
6960 // are marked parallel even if 'safelen' is finite.
6961 if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
6962 applyParallelAccessesMetadata(CLI: CanonicalLoop, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
6963
6964 // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
6965 // versions so we can't add the loop attributes in that case.
6966 if (IfCond) {
6967 // we can still add llvm.loop.parallel_access
6968 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
6969 return;
6970 }
6971
6972 // Use the above access group metadata to create loop level
6973 // metadata, which should be distinct for each loop.
6974 ConstantAsMetadata *BoolConst =
6975 ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
6976 LoopMDList.push_back(Elt: MDNode::get(
6977 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));
6978
6979 if (Simdlen || Safelen) {
6980 // If both simdlen and safelen clauses are specified, the value of the
6981 // simdlen parameter must be less than or equal to the value of the safelen
6982 // parameter. Therefore, use safelen only in the absence of simdlen.
6983 ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
6984 LoopMDList.push_back(
6985 Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
6986 ConstantAsMetadata::get(C: VectorizeWidth)}));
6987 }
6988
6989 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
6990}
6991
6992/// Create the TargetMachine object to query the backend for optimization
6993/// preferences.
6994///
6995/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
6996/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
6997/// needed for the LLVM pass pipline. We use some default options to avoid
6998/// having to pass too many settings from the frontend that probably do not
6999/// matter.
7000///
7001/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
7002/// method. If we are going to use TargetMachine for more purposes, especially
7003/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
7004/// might become be worth requiring front-ends to pass on their TargetMachine,
7005/// or at least cache it between methods. Note that while fontends such as Clang
7006/// have just a single main TargetMachine per translation unit, "target-cpu" and
7007/// "target-features" that determine the TargetMachine are per-function and can
7008/// be overrided using __attribute__((target("OPTIONS"))).
7009static std::unique_ptr<TargetMachine>
7010createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
7011 Module *M = F->getParent();
7012
7013 StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
7014 StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
7015 const llvm::Triple &Triple = M->getTargetTriple();
7016
7017 std::string Error;
7018 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(TheTriple: Triple, Error);
7019 if (!TheTarget)
7020 return {};
7021
7022 llvm::TargetOptions Options;
7023 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
7024 TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
7025 /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
7026}
7027
/// Heuristically determine the best-performant unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
///
/// \returns the suggested unroll factor, or 1 to signal that the loop should
///          not be unrolled.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Blocks must have terminators.
  // FIXME: Don't run analyses on incomplete/invalid IR.
  // Temporarily terminate unfinished blocks with 'unreachable' so the
  // analyses below can run; the placeholders are erased again afterwards.
  SmallVector<Instruction *> UIs;
  for (BasicBlock &BB : *F)
    if (!BB.getTerminator())
      UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));

  // Set up a one-off analysis manager with exactly the analyses that
  // gatherUnrollingPreferences/computeUnrollCount require.
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
  // If a TargetMachine could be created, use its TTI; otherwise fall back to
  // the default-constructed (no-op) TargetIRAnalysis.
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass(PassBuilder: [&]() { return TIRA; });

  // Run the analyses directly rather than through a pass pipeline; binding
  // the results to rvalue references keeps the temporaries alive for the
  // remainder of this function.
  TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(F&: *F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(F&: *F, FAM);
  OptimizationRemarkEmitter ORE{F};

  // Remove the temporary terminators inserted above.
  for (Instruction *I : UIs)
    I->eraseFromParent();

  Loop *L = LI.getLoopFor(BB: CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Query the same target-specific unrolling preferences the LoopUnrollPass
  // would use at this optimization level.
  TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
      L, SE, TTI,
      /*BlockFrequencyInfo=*/BFI: nullptr,
      /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
      /*UserThreshold=*/std::nullopt,
      /*UserCount=*/std::nullopt,
      /*UserAllowPartial=*/true,
      /*UserAllowRuntime=*/UserRuntime: true,
      /*UserUpperBound=*/std::nullopt,
      /*UserFullUnrollMaxCount=*/std::nullopt);

  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      // Only allocas in the entry block qualify as eliminable stack slots.
      if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(Ptr: &I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // Loop is not unrollable if the loop contains certain instructions.
  if (!UCE.canUnroll()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  // Let the LoopUnrollPass machinery pick the unroll count; the result is
  // reported back through UP.Count.
  computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
7166
7167void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
7168 int32_t Factor,
7169 CanonicalLoopInfo **UnrolledCLI) {
7170 assert(Factor >= 0 && "Unroll factor must not be negative");
7171
7172 Function *F = Loop->getFunction();
7173 LLVMContext &Ctx = F->getContext();
7174
7175 // If the unrolled loop is not used for another loop-associated directive, it
7176 // is sufficient to add metadata for the LoopUnrollPass.
7177 if (!UnrolledCLI) {
7178 SmallVector<Metadata *, 2> LoopMetadata;
7179 LoopMetadata.push_back(
7180 Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));
7181
7182 if (Factor >= 1) {
7183 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
7184 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
7185 LoopMetadata.push_back(Elt: MDNode::get(
7186 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
7187 }
7188
7189 addLoopMetadata(Loop, Properties: LoopMetadata);
7190 return;
7191 }
7192
7193 // Heuristically determine the unroll factor.
7194 if (Factor == 0)
7195 Factor = computeHeuristicUnrollFactor(CLI: Loop);
7196
7197 // No change required with unroll factor 1.
7198 if (Factor == 1) {
7199 *UnrolledCLI = Loop;
7200 return;
7201 }
7202
7203 assert(Factor >= 2 &&
7204 "unrolling only makes sense with a factor of 2 or larger");
7205
7206 Type *IndVarTy = Loop->getIndVarType();
7207
7208 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
7209 // unroll the inner loop.
7210 Value *FactorVal =
7211 ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
7212 /*isSigned=*/false));
7213 std::vector<CanonicalLoopInfo *> LoopNest =
7214 tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
7215 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
7216 *UnrolledCLI = LoopNest[0];
7217 CanonicalLoopInfo *InnerLoop = LoopNest[1];
7218
7219 // LoopUnrollPass can only fully unroll loops with constant trip count.
7220 // Unroll by the unroll factor with a fallback epilog for the remainder
7221 // iterations if necessary.
7222 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
7223 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
7224 addLoopMetadata(
7225 Loop: InnerLoop,
7226 Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
7227 MDNode::get(
7228 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});
7229
7230#ifndef NDEBUG
7231 (*UnrolledCLI)->assertOK();
7232#endif
7233}
7234
7235OpenMPIRBuilder::InsertPointTy
7236OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
7237 llvm::Value *BufSize, llvm::Value *CpyBuf,
7238 llvm::Value *CpyFn, llvm::Value *DidIt) {
7239 if (!updateToLocation(Loc))
7240 return Loc.IP;
7241
7242 uint32_t SrcLocStrSize;
7243 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7244 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7245 Value *ThreadId = getOrCreateThreadID(Ident);
7246
7247 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
7248
7249 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
7250
7251 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
7252 createRuntimeFunctionCall(Callee: Fn, Args);
7253
7254 return Builder.saveIP();
7255}
7256
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
    ArrayRef<llvm::Function *> CPFuncs) {
  // Emit an OpenMP 'single' construct: BodyGenCB's code runs on exactly one
  // thread of the team. CPVars/CPFuncs describe copyprivate variables and
  // their per-variable copy helpers; when present, the executing thread
  // broadcasts them to the others via __kmpc_copyprivate.

  if (!updateToLocation(Loc))
    return Loc.IP;

  // If needed allocate and initialize `DidIt` with 0.
  // DidIt: flag variable: 1=single thread; 0=not single thread.
  llvm::Value *DidIt = nullptr;
  if (!CPVars.empty()) {
    DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
    Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
  }

  Directive OMPD = Directive::OMPD_single;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
  Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
  Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);

  // Wrap the user finalization so the DidIt flag is raised from inside the
  // region, i.e. only by the thread that actually executed it.
  auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
    if (Error Err = FiniCB(IP))
      return Err;

    // The thread that executes the single region must set `DidIt` to 1.
    // This is used by __kmpc_copyprivate, to know if the caller is the
    // single thread or not.
    if (DidIt)
      Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);

    return Error::success();
  };

  // generates the following:
  // if (__kmpc_single()) {
  //   .... single region ...
  //   __kmpc_end_single
  // }
  // __kmpc_copyprivate
  // __kmpc_barrier

  InsertPointOrErrorTy AfterIP =
      EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
                           /*Conditional*/ true,
                           /*hasFinalize*/ HasFinalize: true);
  if (!AfterIP)
    return AfterIP.takeError();

  if (DidIt) {
    for (size_t I = 0, E = CPVars.size(); I < E; ++I)
      // NOTE BufSize is currently unused, so just pass 0.
      createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
                        CpyFn: CPFuncs[I], DidIt);
    // NOTE __kmpc_copyprivate already inserts a barrier
  } else if (!IsNowait) {
    // No copyprivate: emit the implicit end-of-single barrier unless nowait.
    InsertPointOrErrorTy AfterIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                      Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!AfterIP)
      return AfterIP.takeError();
  }
  return Builder.saveIP();
}
7331
7332OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
7333 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7334 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
7335
7336 if (!updateToLocation(Loc))
7337 return Loc.IP;
7338
7339 Directive OMPD = Directive::OMPD_critical;
7340 uint32_t SrcLocStrSize;
7341 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7342 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7343 Value *ThreadId = getOrCreateThreadID(Ident);
7344 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
7345 Value *Args[] = {Ident, ThreadId, LockVar};
7346
7347 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
7348 Function *RTFn = nullptr;
7349 if (HintInst) {
7350 // Add Hint to entry Args and create call
7351 EnterArgs.push_back(Elt: HintInst);
7352 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
7353 } else {
7354 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
7355 }
7356 Instruction *EntryCall = createRuntimeFunctionCall(Callee: RTFn, Args: EnterArgs);
7357
7358 Function *ExitRTLFn =
7359 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
7360 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7361
7362 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7363 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7364}
7365
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
                                     InsertPointTy AllocaIP, unsigned NumLoops,
                                     ArrayRef<llvm::Value *> StoreValues,
                                     const Twine &Name, bool IsDependSource) {
  // Emit an 'ordered depend(source)' / 'ordered depend(sink)' construct:
  // materializes the NumLoops-entry iteration vector StoreValues into an i64
  // array (allocated at AllocaIP) and hands its address to
  // __kmpc_doacross_post (source) or __kmpc_doacross_wait (sink).
  assert(
      llvm::all_of(StoreValues,
                   [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
      "OpenMP runtime requires depend vec with i64 type");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Allocate space for vector and generate alloc instruction.
  auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
  ArgsBase->setAlignment(Align(8));
  // Switch back from the alloca insertion point to the directive's location
  // for the stores and the runtime call.
  updateToLocation(Loc);

  // Store the index value with offset in depend vector.
  for (unsigned I = 0; I < NumLoops; ++I) {
    Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
        Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
    StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
    STInst->setAlignment(Align(8));
  }

  // The runtime takes a pointer to the first element of the vector.
  Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
      Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};

  Function *RTLFn = nullptr;
  if (IsDependSource)
    RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
  createRuntimeFunctionCall(Callee: RTLFn, Args);

  return Builder.saveIP();
}
7412
7413OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
7414 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7415 FinalizeCallbackTy FiniCB, bool IsThreads) {
7416 if (!updateToLocation(Loc))
7417 return Loc.IP;
7418
7419 Directive OMPD = Directive::OMPD_ordered;
7420 Instruction *EntryCall = nullptr;
7421 Instruction *ExitCall = nullptr;
7422
7423 if (IsThreads) {
7424 uint32_t SrcLocStrSize;
7425 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7426 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7427 Value *ThreadId = getOrCreateThreadID(Ident);
7428 Value *Args[] = {Ident, ThreadId};
7429
7430 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
7431 EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7432
7433 Function *ExitRTLFn =
7434 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
7435 ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7436 }
7437
7438 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7439 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7440}
7441
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {
  // Emit a directive's region inline at the current insertion point: entry
  // call, body (from BodyGenCB), finalization (FiniCB via the finalization
  // stack), then the exit call. When Conditional is true the body only runs
  // if EntryCall returns non-zero (see emitCommonDirectiveEntry).

  if (HasFinalize)
    FinalizationStack.push_back(Elt: {FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  // If the entry block does not end in an unconditional branch, insert a
  // temporary unreachable terminator so the block can be split; it is erased
  // again at the bottom.
  if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");

  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // generate body
  if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
                            /* CodeGenIP */ Builder.saveIP()))
    return Err;

  // emit exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  InsertPointOrErrorTy AfterIP =
      emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  if (!AfterIP)
    return AfterIP.takeError();

  // If we are skipping the region of a non conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(BB: ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  // Drop the temporary terminator inserted above; a real branch is kept.
  if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
7491
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // For a conditional region, replace the entry block's terminator with
  //   br (EntryCall != 0), omp_region.body, ExitBB
  // moving the original terminator into the new body block, and leave the
  // builder positioned inside the body.

  // if nothing to do, Return current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
  auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
  // Temporary terminator keeping ThenBB well-formed while instructions are
  // spliced in; erased once the real terminator has been moved over.
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit thenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);

  // Move Entry branch to end of ThenBB, and replace with conditional
  // branch (If-stmt)
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(I: EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}
7521
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {
  // Emit the finalization registered on the finalization stack (if any) at
  // FinIP, then move ExitCall to be the last instruction before the
  // finalization block's terminator.

  Builder.restoreIP(IP: FinIP);

  // If there is finalization to do, emit it before the exit call
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    // This directive must own the top of the stack (pushed by
    // EmitOMPInlinedRegion).
    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    if (Error Err = Fi.mergeFiniBB(Builder, OtherFiniBB: FinIP.getBlock()))
      return std::move(Err);

    // Exit condition: insertion point is before the terminator of the new Fini
    // block
    Builder.SetInsertPoint(FinIP.getBlock()->getTerminator());
  }

  if (!ExitCall)
    return Builder.saveIP();

  // place the Exitcall as last instruction before Finalization block terminator
  ExitCall->removeFromParent();
  Builder.Insert(I: ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}
7554
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  // Build the control flow for a 'copyin' clause: the copy code (emitted by
  // the caller at the returned insertion point) only runs on threads whose
  // private copy lives at a different address than the master's copy.
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // creates the following CFG structure
  //	   OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |     copyin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //        |
  //        v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If entry block is terminated, split to preserve the branch to following
  // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
  if (isa_and_nonnull<BranchInst>(Val: OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
                                         BBName: "copyin.not.master.end");
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
  }

  // Compare the two addresses as integers; equal addresses mean this thread
  // IS the master and the copy is skipped.
  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
  Builder.CreateCondBr(Cond: cmp, True: CopyBegin, False: CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));

  return Builder.saveIP();
}
7604
7605CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
7606 Value *Size, Value *Allocator,
7607 std::string Name) {
7608 IRBuilder<>::InsertPointGuard IPG(Builder);
7609 updateToLocation(Loc);
7610
7611 uint32_t SrcLocStrSize;
7612 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7613 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7614 Value *ThreadId = getOrCreateThreadID(Ident);
7615 Value *Args[] = {ThreadId, Size, Allocator};
7616
7617 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
7618
7619 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7620}
7621
7622CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
7623 Value *Addr, Value *Allocator,
7624 std::string Name) {
7625 IRBuilder<>::InsertPointGuard IPG(Builder);
7626 updateToLocation(Loc);
7627
7628 uint32_t SrcLocStrSize;
7629 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7630 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7631 Value *ThreadId = getOrCreateThreadID(Ident);
7632 Value *Args[] = {ThreadId, Addr, Allocator};
7633 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
7634 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7635}
7636
7637CallInst *OpenMPIRBuilder::createOMPInteropInit(
7638 const LocationDescription &Loc, Value *InteropVar,
7639 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
7640 Value *DependenceAddress, bool HaveNowaitClause) {
7641 IRBuilder<>::InsertPointGuard IPG(Builder);
7642 updateToLocation(Loc);
7643
7644 uint32_t SrcLocStrSize;
7645 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7646 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7647 Value *ThreadId = getOrCreateThreadID(Ident);
7648 if (Device == nullptr)
7649 Device = Constant::getAllOnesValue(Ty: Int32);
7650 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
7651 if (NumDependences == nullptr) {
7652 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7653 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7654 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7655 }
7656 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7657 Value *Args[] = {
7658 Ident, ThreadId, InteropVar, InteropTypeVal,
7659 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
7660
7661 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
7662
7663 return createRuntimeFunctionCall(Callee: Fn, Args);
7664}
7665
7666CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
7667 const LocationDescription &Loc, Value *InteropVar, Value *Device,
7668 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
7669 IRBuilder<>::InsertPointGuard IPG(Builder);
7670 updateToLocation(Loc);
7671
7672 uint32_t SrcLocStrSize;
7673 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7674 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7675 Value *ThreadId = getOrCreateThreadID(Ident);
7676 if (Device == nullptr)
7677 Device = Constant::getAllOnesValue(Ty: Int32);
7678 if (NumDependences == nullptr) {
7679 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7680 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7681 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7682 }
7683 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7684 Value *Args[] = {
7685 Ident, ThreadId, InteropVar, Device,
7686 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7687
7688 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
7689
7690 return createRuntimeFunctionCall(Callee: Fn, Args);
7691}
7692
7693CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
7694 Value *InteropVar, Value *Device,
7695 Value *NumDependences,
7696 Value *DependenceAddress,
7697 bool HaveNowaitClause) {
7698 IRBuilder<>::InsertPointGuard IPG(Builder);
7699 updateToLocation(Loc);
7700 uint32_t SrcLocStrSize;
7701 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7702 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7703 Value *ThreadId = getOrCreateThreadID(Ident);
7704 if (Device == nullptr)
7705 Device = Constant::getAllOnesValue(Ty: Int32);
7706 if (NumDependences == nullptr) {
7707 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7708 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7709 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7710 }
7711 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7712 Value *Args[] = {
7713 Ident, ThreadId, InteropVar, Device,
7714 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7715
7716 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
7717
7718 return createRuntimeFunctionCall(Callee: Fn, Args);
7719}
7720
CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
    const LocationDescription &Loc, llvm::Value *Pointer,
    llvm::ConstantInt *Size, const llvm::Twine &Name) {
  // Emit __kmpc_threadprivate_cached(ident, tid, Pointer, Size, &cache),
  // which yields the calling thread's private copy of a threadprivate
  // variable. The cache is an internal global shared by all call sites using
  // the same Name. The caller's insertion point is restored on return.
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *ThreadPrivateCache =
      getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
  llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};

  Function *Fn =
      getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);

  return createRuntimeFunctionCall(Callee: Fn, Args);
}
7740
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
    const LocationDescription &Loc,
    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
  // Emit the device-side prologue of a target kernel: materialize the
  // per-kernel dynamic and kernel environment globals, call
  // __kmpc_target_init, and branch threads whose init result is not -1 to an
  // early return. Returns the insertion point at the start of the user-code
  // block.
  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
         "expected num_threads and num_teams to be specified");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Constant *IsSPMDVal = ConstantInt::getSigned(Ty: Int8, V: Attrs.ExecFlags);
  // The generic state machine is only requested for non-SPMD kernels.
  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
      Ty: Int8, V: Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
  Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
  Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);

  Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
  Function *Kernel = DebugKernelWrapper;

  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix)) {
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
    Kernel = M.getFunction(Name: KernelName);
    assert(Kernel && "Expected the real kernel to exist");
  }

  // Manifest the launch configuration in the metadata matching the kernel
  // environment.
  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
    writeTeamsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinTeams, UB: Attrs.MaxTeams.front());

  // If MaxThreads not set, select the maximum between the default workgroup
  // size and the MinThreads value.
  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
  if (MaxThreadsVal < 0)
    MaxThreadsVal = std::max(
        a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), b: Attrs.MinThreads);

  if (MaxThreadsVal > 0)
    writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinThreads, UB: MaxThreadsVal);

  Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinThreads);
  Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
  Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinTeams);
  Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MaxTeams.front());
  Constant *ReductionDataSize =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionDataSize);
  Constant *ReductionBufferLength =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionBufferLength);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
  const DataLayout &DL = Fn->getDataLayout();

  // Per-kernel mutable state ("dynamic environment") global.
  Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
  Constant *DynamicEnvironmentInitializer =
      ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
  GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
      M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
      DynamicEnvironmentInitializer, DynamicEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  // Cast to the expected pointer type if the global landed in a different
  // address space.
  Constant *DynamicEnvironment =
      DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
          ? DynamicEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
                                           Ty: DynamicEnvironmentPtr);

  // Field order here defines the indices patched later by createTargetDeinit
  // (ReductionDataSize is field 7, ReductionBufferLength field 8).
  Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
      T: ConfigurationEnvironment, V: {
                                 UseGenericStateMachineVal,
                                 MayUseNestedParallelismVal,
                                 IsSPMDVal,
                                 MinThreads,
                                 MaxThreads,
                                 MinTeams,
                                 MaxTeams,
                                 ReductionDataSize,
                                 ReductionBufferLength,
                             });
  Constant *KernelEnvironmentInitializer = ConstantStruct::get(
      T: KernelEnvironment, V: {
                             ConfigurationEnvironmentInitializer,
                             Ident,
                             DynamicEnvironment,
                         });
  std::string KernelEnvironmentName =
      (KernelName + "_kernel_environment").str();
  GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
      M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
      KernelEnvironmentInitializer, KernelEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *KernelEnvironment =
      KernelEnvironmentGV->getType() == KernelEnvironmentPtr
          ? KernelEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
                                           Ty: KernelEnvironmentPtr);
  // The launch environment is passed as the wrapper function's last argument.
  Value *KernelLaunchEnvironment =
      DebugKernelWrapper->getArg(i: DebugKernelWrapper->arg_size() - 1);
  Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(i: 1);
  KernelLaunchEnvironment =
      KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
          ? KernelLaunchEnvironment
          : Builder.CreateAddrSpaceCast(V: KernelLaunchEnvironment,
                                        DestTy: KernelLaunchEnvParamTy);
  CallInst *ThreadKind = createRuntimeFunctionCall(
      Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      LHS: ThreadKind, RHS: Constant::getAllOnesValue(Ty: ThreadKind->getType()),
      Name: "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  // Split off the user-code block behind a temporary unreachable, then wire
  // in the conditional branch to the worker-exit block.
  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
7888
void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
                                         int32_t TeamsReductionDataSize,
                                         int32_t TeamsReductionBufferLength) {
  // Emit the device-side kernel epilogue: a call to __kmpc_target_deinit and,
  // when teams reductions are in use, a patch of the kernel environment
  // global with the reduction buffer geometry.
  if (!updateToLocation(Loc))
    return;

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

  createRuntimeFunctionCall(Callee: Fn, Args: {});

  if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
    return;

  Function *Kernel = Builder.GetInsertBlock()->getParent();
  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix))
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
  // Locate the environment global emitted by createTargetInit.
  auto *KernelEnvironmentGV =
      M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
  assert(KernelEnvironmentGV && "Expected kernel environment global\n");
  auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
  // Indices {0, 7} / {0, 8} address ReductionDataSize / ReductionBufferLength
  // inside the configuration environment (field 0 of the kernel environment),
  // matching the initializer layout built in createTargetInit.
  auto *NewInitializer = ConstantFoldInsertValueInstruction(
      Agg: KernelEnvironmentInitializer,
      Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
  NewInitializer = ConstantFoldInsertValueInstruction(
      Agg: NewInitializer, Val: ConstantInt::get(Ty: Int32, V: TeamsReductionBufferLength),
      Idxs: {0, 8});
  KernelEnvironmentGV->setInitializer(NewInitializer);
}
7921
7922static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
7923 bool Min) {
7924 if (Kernel.hasFnAttribute(Kind: Name)) {
7925 int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Kind: Name);
7926 Value = Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value);
7927 }
7928 Kernel.addFnAttr(Kind: Name, Val: llvm::utostr(X: Value));
7929}
7930
std::pair<int32_t, int32_t>
OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
  // Returns {LB, UB} thread-count bounds for the kernel, combining the
  // generic "omp_target_thread_limit" attribute with target-specific ones
  // ("amdgpu-flat-work-group-size" on AMDGPU, "nvvm.maxntid" on NVPTX).
  // A bound of 0 means "not specified".
  int32_t ThreadLimit =
      Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");

  if (T.isAMDGPU()) {
    const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
    if (!Attr.isValid() || !Attr.isStringAttribute())
      return {0, ThreadLimit};
    // Attribute value has the form "<lb>,<ub>".
    auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
    int32_t LB, UB;
    if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
      return {0, ThreadLimit};
    // The tighter of the attribute UB and the OpenMP thread limit wins.
    UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
    if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
      return {0, UB};
    return {LB, UB};
  }

  if (Kernel.hasFnAttribute(Kind: "nvvm.maxntid")) {
    int32_t UB = Kernel.getFnAttributeAsParsedInteger(Kind: "nvvm.maxntid");
    return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
  }
  return {0, ThreadLimit};
}
7956
void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
                                                 Function &Kernel, int32_t LB,
                                                 int32_t UB) {
  // Mirror of readThreadBoundsForKernel: record the bounds in the generic
  // OpenMP attribute and in the target-specific backend attribute.
  Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));

  if (T.isAMDGPU()) {
    Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
                     Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
    return;
  }

  // On NVPTX, clamp any pre-existing maxntid to the new upper bound.
  updateNVPTXAttr(Kernel, Name: "nvvm.maxntid", Value: UB, Min: true);
}
7970
7971std::pair<int32_t, int32_t>
7972OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
7973 // TODO: Read from backend annotations if available.
7974 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
7975}
7976
7977void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
7978 int32_t LB, int32_t UB) {
7979 if (T.isNVPTX())
7980 if (UB > 0)
7981 Kernel.addFnAttr(Kind: "nvvm.maxclusterrank", Val: llvm::utostr(X: UB));
7982 if (T.isAMDGPU())
7983 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: LB) + ",1,1");
7984
7985 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
7986}
7987
7988void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
7989 Function *OutlinedFn) {
7990 if (Config.isTargetDevice()) {
7991 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
7992 // TODO: Determine if DSO local can be set to true.
7993 OutlinedFn->setDSOLocal(false);
7994 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
7995 if (T.isAMDGCN())
7996 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
7997 else if (T.isNVPTX())
7998 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
7999 else if (T.isSPIRV())
8000 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
8001 }
8002}
8003
8004Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
8005 StringRef EntryFnIDName) {
8006 if (Config.isTargetDevice()) {
8007 assert(OutlinedFn && "The outlined function must exist if embedded");
8008 return OutlinedFn;
8009 }
8010
8011 return new GlobalVariable(
8012 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
8013 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
8014}
8015
8016Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
8017 StringRef EntryFnName) {
8018 if (OutlinedFn)
8019 return OutlinedFn;
8020
8021 assert(!M.getGlobalVariable(EntryFnName, true) &&
8022 "Named kernel already exists?");
8023 return new GlobalVariable(
8024 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
8025 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
8026}
8027
Error OpenMPIRBuilder::emitTargetRegionFunction(
    TargetRegionEntryInfo &EntryInfo,
    FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
    Function *&OutlinedFn, Constant *&OutlinedFnID) {
  // Generates the outlined function for a target region via
  // GenerateFunctionCallback and, when it is an offload entry, registers it
  // so host and device agree on the region ID. Results are returned through
  // OutlinedFn and OutlinedFnID (OutlinedFn may be left null, see below).

  SmallString<64> EntryFnName;
  OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);

  // With mandatory offloading the host fallback body is never executed, so
  // the function is only generated when compiling for the device.
  if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
    Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
    if (!CBResult)
      return CBResult.takeError();
    OutlinedFn = *CBResult;
  } else {
    OutlinedFn = nullptr;
  }

  // If this target outline function is not an offload entry, we don't need to
  // register it. This may be in the case of a false if clause, or if there are
  // no OpenMP targets.
  if (!IsOffloadEntry)
    return Error::success();

  std::string EntryFnIDName =
      Config.isTargetDevice()
          ? std::string(EntryFnName)
          : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});

  OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
                                              EntryFnName, EntryFnIDName);
  return Error::success();
}
8060
8061Constant *OpenMPIRBuilder::registerTargetRegionFunction(
8062 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
8063 StringRef EntryFnName, StringRef EntryFnIDName) {
8064 if (OutlinedFn)
8065 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
8066 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
8067 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
8068 OffloadInfoManager.registerTargetRegionEntryInfo(
8069 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
8070 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
8071 return OutlinedFnID;
8072}
8073
// Emit the begin/end runtime calls (and optional body) for an OpenMP target
// data environment ("target data", "target enter/exit data", "target update").
//
// When \p BodyGenCB is null this is a stand-alone construct and only the
// mapper runtime call named by \p MapperFunc is emitted (possibly wrapped in a
// target task when Info.HasNoWait is set). Otherwise a begin-mapper call, the
// user body (generated up to three times to cover device-pointer
// privatization), and an end-mapper call are emitted. \p IfCond, when present,
// guards each runtime call with an if-clause conditional.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
    TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
    CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
    function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                      BodyGenTy BodyGenType)>
        BodyGenCB,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  // A null body callback means a stand-alone construct (enter/exit data,
  // update) rather than a structured "target data" region.
  bool IsStandAlone = !BodyGenCB;
  MapInfosTy *MapInfo;
  // Generate the code for the opening of the data environment. Capture all the
  // arguments of the runtime call by reference because they are used in the
  // closing of the region.
  auto BeginThenGen = [&](InsertPointTy AllocaIP,
                          InsertPointTy CodeGenIP) -> Error {
    // Materialize the .offload_baseptrs/.offload_ptrs/.offload_sizes arrays
    // from the map clauses.
    MapInfo = &GenMapInfoCB(Builder.saveIP());
    if (Error Err = emitOffloadingArrays(
            AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info, CustomMapperCB,
            /*IsNonContiguous=*/true, DeviceAddrCB))
      return Err;

    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    SmallVector<llvm::Value *, 13> OffloadingArgs = {
        SrcLocInfo,           DeviceID,
        PointerNum,           RTArgs.BasePointersArray,
        RTArgs.PointersArray, RTArgs.SizesArray,
        RTArgs.MapTypesArray, RTArgs.MapNamesArray,
        RTArgs.MappersArray};

    if (IsStandAlone) {
      assert(MapperFunc && "MapperFunc missing for standalone target data");

      // Body of the (possibly deferred) target task: issue the mapper runtime
      // call, appending the extra dependence/event arguments that the nowait
      // entry points expect.
      auto TaskBodyCB = [&](Value *, Value *,
                            IRBuilderBase::InsertPoint) -> Error {
        if (Info.HasNoWait) {
          OffloadingArgs.append(IL: {llvm::Constant::getNullValue(Ty: Int32),
                                llvm::Constant::getNullValue(Ty: VoidPtr),
                                llvm::Constant::getNullValue(Ty: Int32),
                                llvm::Constant::getNullValue(Ty: VoidPtr)});
        }

        createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
                                  Args: OffloadingArgs);

        if (Info.HasNoWait) {
          // Terminate the current block into a fresh continuation block.
          BasicBlock *OffloadContBlock =
              BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
          Function *CurFn = Builder.GetInsertBlock()->getParent();
          emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
          // NOTE(review): restoring the just-saved IP looks like a no-op —
          // confirm whether this was meant to restore a different IP.
          Builder.restoreIP(IP: Builder.saveIP());
        }
        return Error::success();
      };

      // Without nowait, run the task body inline; with nowait, wrap it in an
      // outer target task so the mapper call can be deferred.
      bool RequiresOuterTargetTask = Info.HasNoWait;
      if (!RequiresOuterTargetTask)
        cantFail(Err: TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
                            /*TargetTaskAllocaIP=*/{}));
      else
        cantFail(ValOrErr: emitTargetTask(TaskBodyCB, DeviceID, RTLoc: SrcLocInfo, AllocaIP,
                                /*Dependencies=*/{}, RTArgs, HasNoWait: Info.HasNoWait));
    } else {
      Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
          FnID: omp::OMPRTL___tgt_target_data_begin_mapper);

      createRuntimeFunctionCall(Callee: BeginMapperFunc, Args: OffloadingArgs);

      // Propagate the device addresses produced by the begin call into the
      // privatized use_device_ptr/use_device_addr slots.
      for (auto DeviceMap : Info.DevicePtrInfoMap) {
        if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
          auto *LI =
              Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
          Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
        }
      }

      // If device pointer privatization is required, emit the body of the
      // region here. It will have to be duplicated: with and without
      // privatization.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
    }
    return Error::success();
  };

  // If we need device pointer privatization, we need to emit the body of the
  // region with no privatization in the 'else' branch of the conditional.
  // Otherwise, we don't have to do anything.
  auto BeginElseGen = [&](InsertPointTy AllocaIP,
                          InsertPointTy CodeGenIP) -> Error {
    InsertPointOrErrorTy AfterIP =
        BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    return Error::success();
  };

  // Generate code for the closing of the data region.
  auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    // Reuses the arrays captured by BeginThenGen; only the argument struct is
    // rebuilt, with ForEndCall selecting the end-call variants.
    TargetDataRTArgs RTArgs;
    Info.EmitDebug = !MapInfo->Names.empty();
    emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};
    Function *EndMapperFunc =
        getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);

    createRuntimeFunctionCall(Callee: EndMapperFunc, Args: OffloadingArgs);
    return Error::success();
  };

  // We don't have to do anything to close the region if the if clause evaluates
  // to false.
  auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    return Error::success();
  };

  // Drive the begin/body/end sequence, threading the if-clause through
  // emitIfClause when present.
  Error Err = [&]() -> Error {
    if (BodyGenCB) {
      Error Err = [&]() {
        if (IfCond)
          return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
        return BeginThenGen(AllocaIP, Builder.saveIP());
      }();

      if (Err)
        return Err;

      // If we don't require privatization of device pointers, we emit the body
      // in between the runtime calls. This avoids duplicating the body code.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
      if (!AfterIP)
        return AfterIP.takeError();
      restoreIPandDebugLoc(Builder, IP: *AfterIP);

      if (IfCond)
        return emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
      return EndThenGen(AllocaIP, Builder.saveIP());
    }
    // Stand-alone construct: only the begin side exists.
    if (IfCond)
      return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
    return BeginThenGen(AllocaIP, Builder.saveIP());
  }();

  if (Err)
    return Err;

  return Builder.saveIP();
}
8259
8260FunctionCallee
8261OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
8262 bool IsGPUDistribute) {
8263 assert((IVSize == 32 || IVSize == 64) &&
8264 "IV size is not compatible with the omp runtime");
8265 RuntimeFunction Name;
8266 if (IsGPUDistribute)
8267 Name = IVSize == 32
8268 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
8269 : omp::OMPRTL___kmpc_distribute_static_init_4u)
8270 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
8271 : omp::OMPRTL___kmpc_distribute_static_init_8u);
8272 else
8273 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
8274 : omp::OMPRTL___kmpc_for_static_init_4u)
8275 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
8276 : omp::OMPRTL___kmpc_for_static_init_8u);
8277
8278 return getOrCreateRuntimeFunction(M, FnID: Name);
8279}
8280
8281FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
8282 bool IVSigned) {
8283 assert((IVSize == 32 || IVSize == 64) &&
8284 "IV size is not compatible with the omp runtime");
8285 RuntimeFunction Name = IVSize == 32
8286 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
8287 : omp::OMPRTL___kmpc_dispatch_init_4u)
8288 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
8289 : omp::OMPRTL___kmpc_dispatch_init_8u);
8290
8291 return getOrCreateRuntimeFunction(M, FnID: Name);
8292}
8293
8294FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
8295 bool IVSigned) {
8296 assert((IVSize == 32 || IVSize == 64) &&
8297 "IV size is not compatible with the omp runtime");
8298 RuntimeFunction Name = IVSize == 32
8299 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
8300 : omp::OMPRTL___kmpc_dispatch_next_4u)
8301 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
8302 : omp::OMPRTL___kmpc_dispatch_next_8u);
8303
8304 return getOrCreateRuntimeFunction(M, FnID: Name);
8305}
8306
8307FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
8308 bool IVSigned) {
8309 assert((IVSize == 32 || IVSize == 64) &&
8310 "IV size is not compatible with the omp runtime");
8311 RuntimeFunction Name = IVSize == 32
8312 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
8313 : omp::OMPRTL___kmpc_dispatch_fini_4u)
8314 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
8315 : omp::OMPRTL___kmpc_dispatch_fini_8u);
8316
8317 return getOrCreateRuntimeFunction(M, FnID: Name);
8318}
8319
// Return (creating the declaration if necessary) the __kmpc_dispatch_deinit
// runtime entry; unlike the init/next/fini entries it has no IV-width
// variants.
FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
  return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
}
8323
/// Repair debug-info records inside an outlined target-region function.
///
/// After outlining, debug records still reference values (and argument
/// numbers) of the parent function. Using \p ValueReplacementMap — which maps
/// each original value to its in-function copy and its argument index — this
/// rewrites the record locations, re-creates DILocalVariables with corrected
/// argument numbers, and (on the device) synthesizes a parameter variable for
/// the implicit trailing "dyn_ptr" argument.
static void FixupDebugInfoForOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
    DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {

  // Nothing to fix when the outlined function carries no debug info.
  DISubprogram *NewSP = Func->getSubprogram();
  if (!NewSP)
    return;

  SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;

  // Return a DILocalVariable identical to OldVar but with argument number
  // `arg`, creating and caching it on first use.
  auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
    DILocalVariable *&NewVar = RemappedVariables[OldVar];
    // Only use cached variable if the arg number matches. This is important
    // so that DIVariable created for privatized variables are not discarded.
    if (NewVar && (arg == NewVar->getArg()))
      return NewVar;

    NewVar = llvm::DILocalVariable::get(
        Context&: Builder.getContext(), Scope: OldVar->getScope(), Name: OldVar->getName(),
        File: OldVar->getFile(), Line: OldVar->getLine(), Type: OldVar->getType(), Arg: arg,
        Flags: OldVar->getFlags(), AlignInBits: OldVar->getAlignInBits(), Annotations: OldVar->getAnnotations());
    return NewVar;
  };

  // Point a single debug record at the replacement values; if any location
  // was remapped, also rebind its variable with the new argument number
  // (ArgNo is 1-based, hence the +1).
  auto UpdateDebugRecord = [&](auto *DR) {
    DILocalVariable *OldVar = DR->getVariable();
    unsigned ArgNo = 0;
    for (auto Loc : DR->location_ops()) {
      auto Iter = ValueReplacementMap.find(Loc);
      if (Iter != ValueReplacementMap.end()) {
        DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
        ArgNo = std::get<1>(Iter->second) + 1;
      }
    }
    if (ArgNo != 0)
      DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
  };

  // The location and scope of variable intrinsics and records still point to
  // the parent function of the target region. Update them.
  for (Instruction &I : instructions(F: Func)) {
    assert(!isa<llvm::DbgVariableIntrinsic>(&I) &&
           "Unexpected debug intrinsic");
    for (DbgVariableRecord &DVR : filterDbgVars(R: I.getDbgRecordRange()))
      UpdateDebugRecord(&DVR);
  }
  // An extra argument is passed to the device. Create the debug data for it.
  if (OMPBuilder.Config.isTargetDevice()) {
    DICompileUnit *CU = NewSP->getUnit();
    Module *M = Func->getParent();
    DIBuilder DB(*M, true, CU);
    // dyn_ptr has no source-level type; model it as an unqualified pointer.
    DIType *VoidPtrTy =
        DB.createQualifiedType(Tag: dwarf::DW_TAG_pointer_type, FromTy: nullptr);
    unsigned ArgNo = Func->arg_size();
    DILocalVariable *Var = DB.createParameterVariable(
        Scope: NewSP, Name: "dyn_ptr", ArgNo, File: NewSP->getFile(), /*LineNo=*/0, Ty: VoidPtrTy,
        /*AlwaysPreserve=*/false, Flags: DINode::DIFlags::FlagArtificial);
    auto Loc = DILocation::get(Context&: Func->getContext(), Line: 0, Column: 0, Scope: NewSP, InlinedAt: 0);
    // dyn_ptr is always the last parameter; declare it in the entry block.
    Argument *LastArg = Func->getArg(i: Func->arg_size() - 1);
    DB.insertDeclare(Storage: LastArg, VarInfo: Var, Expr: DB.createExpression(), DL: Loc,
                     InsertAtEnd: &(*Func->begin()));
  }
}
8387
8388static Value *removeASCastIfPresent(Value *V) {
8389 if (Operator::getOpcode(V) == Instruction::AddrSpaceCast)
8390 return cast<Operator>(Val: V)->getOperand(i: 0);
8391 return V;
8392}
8393
/// Create the outlined function for a target region.
///
/// Builds the kernel signature (device parameters become pointers or i64, plus
/// the implicit trailing dyn_ptr), forwards target-cpu/target-features from
/// the parent, emits target init/deinit on the device, runs \p CBFunc to
/// generate the user body, and then rewrites uses of the captured \p Inputs to
/// the per-argument copies produced by \p ArgAccessorFuncCB. Returns the new
/// function or the first error from a callback.
static Expected<Function *> createOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
  SmallVector<Type *> ParameterTypes;
  if (OMPBuilder.Config.isTargetDevice()) {
    // All parameters to target devices are passed as pointers
    // or i64. This assumes 64-bit address spaces/pointers.
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
                                   ? Arg->getType()
                                   : Type::getInt64Ty(C&: Builder.getContext()));
  } else {
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Elt: Arg->getType());
  }

  // The implicit dyn_ptr argument is always the last parameter on both host
  // and device so the argument counts match without runtime manipulation.
  auto *PtrTy = PointerType::getUnqual(C&: Builder.getContext());
  ParameterTypes.push_back(Elt: PtrTy);

  auto BB = Builder.GetInsertBlock();
  auto M = BB->getModule();
  auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
                                    /*isVarArg*/ false);
  auto Func =
      Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);

  // Forward target-cpu and target-features function attributes from the
  // original function to the new outlined function.
  Function *ParentFn = Builder.GetInsertBlock()->getParent();

  auto TargetCpuAttr = ParentFn->getFnAttribute(Kind: "target-cpu");
  if (TargetCpuAttr.isStringAttribute())
    Func->addFnAttr(Attr: TargetCpuAttr);

  auto TargetFeaturesAttr = ParentFn->getFnAttribute(Kind: "target-features");
  if (TargetFeaturesAttr.isStringAttribute())
    Func->addFnAttr(Attr: TargetFeaturesAttr);

  if (OMPBuilder.Config.isTargetDevice()) {
    // Record the kernel's execution mode and keep it alive for the runtime.
    Value *ExecMode =
        OMPBuilder.emitKernelExecutionMode(KernelName: FuncName, Mode: DefaultAttrs.ExecFlags);
    OMPBuilder.emitUsed(Name: "llvm.compiler.used", List: {ExecMode});
  }

  // Save insert point.
  IRBuilder<>::InsertPointGuard IPG(Builder);
  // We will generate the entries in the outlined function but the debug
  // location may still be pointing to the parent function. Reset it now.
  Builder.SetCurrentDebugLocation(llvm::DebugLoc());

  // Generate the region into the function.
  BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
  Builder.SetInsertPoint(EntryBB);

  // Insert target init call in the device compilation pass.
  if (OMPBuilder.Config.isTargetDevice())
    Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, Attrs: DefaultAttrs));

  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();

  // As we embed the user code in the middle of our target region after we
  // generate entry code, we must move what allocas we can into the entry
  // block to avoid possible breaking optimisations for device
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);

  // Insert target deinit call in the device compilation pass.
  BasicBlock *OutlinedBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "outlined.body");
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
      Builder.saveIP(),
      OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.createTargetDeinit(Loc: Builder);

  // Insert return instruction.
  Builder.CreateRetVoid();

  // New Alloca IP at entry point of created device function.
  Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
  auto AllocaIP = Builder.saveIP();

  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());

  // Do not include the artificial dyn_ptr argument.
  const auto &ArgRange = make_range(x: Func->arg_begin(), y: Func->arg_end() - 1);

  // Maps each captured input to its in-function copy and argument index, for
  // later debug-info fixup.
  DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;

  auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
    // Things like GEP's can come in the form of Constants. Constants and
    // ConstantExpr's do not have access to the knowledge of what they're
    // contained in, so we must dig a little to find an instruction so we
    // can tell if they're used inside of the function we're outlining. We
    // also replace the original constant expression with a new instruction
    // equivalent; an instruction as it allows easy modification in the
    // following loop, as we can now know the constant (instruction) is
    // owned by our target function and replaceUsesOfWith can now be invoked
    // on it (cannot do this with constants it seems). A brand new one also
    // allows us to be cautious as it is perhaps possible the old expression
    // was used inside of the function but exists and is used externally
    // (unlikely by the nature of a Constant, but still).
    // NOTE: We cannot remove dead constants that have been rewritten to
    // instructions at this stage, we run the risk of breaking later lowering
    // by doing so as we could still be in the process of lowering the module
    // from MLIR to LLVM-IR and the MLIR lowering may still require the original
    // constants we have created rewritten versions of.
    if (auto *Const = dyn_cast<Constant>(Val: Input))
      convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);

    // Collect users before iterating over them to avoid invalidating the
    // iteration in case a user uses Input more than once (e.g. a call
    // instruction).
    SetVector<User *> Users(Input->users().begin(), Input->users().end());
    // Collect all the instructions
    for (User *User : make_early_inc_range(Range&: Users))
      if (auto *Instr = dyn_cast<Instruction>(Val: User))
        if (Instr->getFunction() == Func)
          Instr->replaceUsesOfWith(From: Input, To: InputCopy);
  };

  SmallVector<std::pair<Value *, Value *>> DeferredReplacement;

  // Rewrite uses of input values to parameters.
  for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
    Value *Input = std::get<0>(t&: InArg);
    Argument &Arg = std::get<1>(t&: InArg);
    Value *InputCopy = nullptr;

    // Let the caller materialize an in-function copy of the input (e.g. a
    // load of the argument), positioned after the alloca block.
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    ValueReplacementMap[Input] = std::make_tuple(args&: InputCopy, args: Arg.getArgNo());

    // In certain cases a Global may be set up for replacement, however, this
    // Global may be used in multiple arguments to the kernel, just segmented
    // apart, for example, if we have a global array, that is sectioned into
    // multiple mappings (technically not legal in OpenMP, but there is a case
    // in Fortran for Common Blocks where this is neccesary), we will end up
    // with GEP's into this array inside the kernel, that refer to the Global
    // but are technically separate arguments to the kernel for all intents and
    // purposes. If we have mapped a segment that requires a GEP into the 0-th
    // index, it will fold into an referal to the Global, if we then encounter
    // this folded GEP during replacement all of the references to the
    // Global in the kernel will be replaced with the argument we have generated
    // that corresponds to it, including any other GEP's that refer to the
    // Global that may be other arguments. This will invalidate all of the other
    // preceding mapped arguments that refer to the same global that may be
    // separate segments. To prevent this, we defer global processing until all
    // other processing has been performed.
    if (llvm::isa<llvm::GlobalValue, llvm::GlobalObject, llvm::GlobalVariable>(
            Val: removeASCastIfPresent(V: Input))) {
      DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
      continue;
    }

    // ConstantData has no per-function uses to rewrite.
    if (isa<ConstantData>(Val: Input))
      continue;

    ReplaceValue(Input, InputCopy, Func);
  }

  // Replace all of our deferred Input values, currently just Globals.
  for (auto Deferred : DeferredReplacement)
    ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);

  FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
                                    ValueReplacementMap);
  return Func;
}
8574/// Given a task descriptor, TaskWithPrivates, return the pointer to the block
8575/// of pointers containing shared data between the parent task and the created
8576/// task.
8577static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
8578 IRBuilderBase &Builder,
8579 Value *TaskWithPrivates,
8580 Type *TaskWithPrivatesTy) {
8581
8582 Type *TaskTy = OMPIRBuilder.Task;
8583 LLVMContext &Ctx = Builder.getContext();
8584 Value *TaskT =
8585 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 0);
8586 Value *Shareds = TaskT;
8587 // TaskWithPrivatesTy can be one of the following
8588 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8589 // %struct.privates }
8590 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
8591 //
8592 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
8593 // its first member has to be the task descriptor. TaskTy is the type of the
8594 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
8595 // first member of TaskT, gives us the pointer to shared data.
8596 if (TaskWithPrivatesTy != TaskTy)
8597 Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
8598 return Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
8599}
/// Create an entry point for a target task with the following.
/// It'll have the following signature
/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
/// This function is called from emitTargetTask once the
/// code to launch the target kernel has been outlined already.
/// NumOffloadingArrays is the number of offloading arrays that we need to copy
/// into the task structure so that the deferred target task can access this
/// data even after the stack frame of the generating task has been rolled
/// back. Offloading arrays contain base pointers, pointers, sizes etc
/// of the data that the target kernel will access. These in effect are the
/// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
///
/// SharedArgsOperandNo is the operand index of the shared-arguments struct in
/// StaleCI (0 means the outlined launch function takes no shared arguments).
static Function *emitTargetTaskProxyFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
    StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
    const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {

  // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
  // This is because PrivatesTy is the type of the structure in which
  // we pass the offloading arrays to the deferred target task.
  assert((!NumOffloadingArrays || PrivatesTy) &&
         "PrivatesTy cannot be nullptr when there are offloadingArrays"
         "to privatize");

  Module &M = OMPBuilder.M;
  // KernelLaunchFunction is the target launch function, i.e.
  // the function that sets up kernel arguments and calls
  // __tgt_target_kernel to launch the kernel on the device.
  //
  Function *KernelLaunchFunction = StaleCI->getCalledFunction();

  // StaleCI is the CallInst which is the call to the outlined
  // target kernel launch function. If there are local live-in values
  // that the outlined function uses then these are aggregated into a structure
  // which is passed as the second argument. If there are no local live-in
  // values or if all values used by the outlined kernel are global variables,
  // then there's only one argument, the threadID. So, StaleCI can be
  //
  // %structArg = alloca { ptr, ptr }, align 8
  // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
  // store ptr %20, ptr %gep_, align 8
  // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
  // store ptr %21, ptr %gep_8, align 8
  // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
  //
  // OR
  //
  // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
  OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
                                    StaleCI->getIterator());

  LLVMContext &Ctx = StaleCI->getParent()->getContext();

  // Proxy signature required by the tasking runtime: (i32 thread id, ptr task).
  Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
  Type *TaskPtrTy = OMPBuilder.TaskPtr;
  [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;

  auto ProxyFnTy =
      FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
                        /* isVarArg */ false);
  auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
                                  N: ".omp_target_task_proxy_func",
                                  M: Builder.GetInsertBlock()->getModule());
  Value *ThreadId = ProxyFn->getArg(i: 0);
  Value *TaskWithPrivates = ProxyFn->getArg(i: 1);
  ThreadId->setName("thread.id");
  TaskWithPrivates->setName("task");

  bool HasShareds = SharedArgsOperandNo > 0;
  bool HasOffloadingArrays = NumOffloadingArrays > 0;
  BasicBlock *EntryBB =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
  Builder.SetInsertPoint(EntryBB);

  // Rebuild the launch function's argument list: thread id, then pointers
  // into the privatized offloading arrays, then the reconstructed shareds
  // struct.
  SmallVector<Value *> KernelLaunchArgs;
  KernelLaunchArgs.reserve(N: StaleCI->arg_size());
  KernelLaunchArgs.push_back(Elt: ThreadId);

  if (HasOffloadingArrays) {
    assert(TaskTy != TaskWithPrivatesTy &&
           "If there are offloading arrays to pass to the target"
           "TaskTy cannot be the same as TaskWithPrivatesTy");
    (void)TaskTy;
    // The privates struct is the second member of the task-with-privates
    // wrapper; pass a pointer to each contained array.
    Value *Privates =
        Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 1);
    for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
      KernelLaunchArgs.push_back(
          Elt: Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i));
  }

  if (HasShareds) {
    auto *ArgStructAlloca =
        dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgsOperandNo));
    assert(ArgStructAlloca &&
           "Unable to find the alloca instruction corresponding to arguments "
           "for extracted function");
    auto *ArgStructType = cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
    std::optional<TypeSize> ArgAllocSize =
        ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
    assert(ArgStructType && ArgAllocSize &&
           "Unable to determine size of arguments for extracted function");
    uint64_t StructSize = ArgAllocSize->getFixedValue();

    // Copy the shared data out of the task descriptor into a fresh local
    // struct so the launch function sees the layout it was outlined with.
    AllocaInst *NewArgStructAlloca =
        Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");

    Value *SharedsSize = Builder.getInt64(C: StructSize);

    LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
        OMPIRBuilder&: OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);

    Builder.CreateMemCpy(
        Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
        SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
    KernelLaunchArgs.push_back(Elt: NewArgStructAlloca);
  }
  OMPBuilder.createRuntimeFunctionCall(Callee: KernelLaunchFunction, Args: KernelLaunchArgs);
  Builder.CreateRetVoid();
  return ProxyFn;
}
8719static Type *getOffloadingArrayType(Value *V) {
8720
8721 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V))
8722 return GEP->getSourceElementType();
8723 if (auto *Alloca = dyn_cast<AllocaInst>(Val: V))
8724 return Alloca->getAllocatedType();
8725
8726 llvm_unreachable("Unhandled Instruction type");
8727 return nullptr;
8728}
8729// This function returns a struct that has at most two members.
8730// The first member is always %struct.kmp_task_ompbuilder_t, that is the task
8731// descriptor. The second member, if needed, is a struct containing arrays
8732// that need to be passed to the offloaded target kernel. For example,
8733// if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
8734// the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
8735// respectively, then the types created by this function are
8736//
8737// %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
8738// %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8739// %struct.privates }
8740// %struct.task_with_privates is returned by this function.
8741// If there aren't any offloading arrays to pass to the target kernel,
8742// %struct.kmp_task_ompbuilder_t is returned.
8743static StructType *
8744createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
8745 ArrayRef<Value *> OffloadingArraysToPrivatize) {
8746
8747 if (OffloadingArraysToPrivatize.empty())
8748 return OMPIRBuilder.Task;
8749
8750 SmallVector<Type *, 4> StructFieldTypes;
8751 for (Value *V : OffloadingArraysToPrivatize) {
8752 assert(V->getType()->isPointerTy() &&
8753 "Expected pointer to array to privatize. Got a non-pointer value "
8754 "instead");
8755 Type *ArrayTy = getOffloadingArrayType(V);
8756 assert(ArrayTy && "ArrayType cannot be nullptr");
8757 StructFieldTypes.push_back(Elt: ArrayTy);
8758 }
8759 StructType *PrivatesStructTy =
8760 StructType::create(Elements: StructFieldTypes, Name: "struct.privates");
8761 return StructType::create(Elements: {OMPIRBuilder.Task, PrivatesStructTy},
8762 Name: "struct.task_with_privates");
8763}
// Outlines the target region into its own function and registers it with the
// offload-entry machinery.
//
// On success, \p OutlinedFn receives the newly created host-side outlined
// function and \p OutlinedFnID its offload entry ID (set by
// emitTargetRegionFunction). The region body itself is generated through the
// \p CBFunc and \p ArgAccessorFuncCB callbacks, which receive \p Inputs.
static Error emitTargetOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
    TargetRegionEntryInfo &EntryInfo,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    Function *&OutlinedFn, Constant *&OutlinedFnID,
    SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {

  // The callback is invoked by emitTargetRegionFunction with the unique
  // entry-point name it computed for this region. It must be a named lvalue
  // because emitTargetRegionFunction binds it by reference.
  OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
      [&](StringRef EntryFnName) {
        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
                                      FuncName: EntryFnName, Inputs, CBFunc,
                                      ArgAccessorFuncCB);
      };

  return OMPBuilder.emitTargetRegionFunction(
      EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
      OutlinedFnID);
}
8784
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
    TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
    OpenMPIRBuilder::InsertPointTy AllocaIP,
    const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    const TargetDataRTArgs &RTArgs, bool HasNoWait) {

  // The following explains the code-gen scenario for the `target` directive. A
  // similar scenario is followed for other device-related directives (e.g.
  // `target enter data`), since for those we only need to emit a task that
  // encapsulates the proper runtime call.
  //
  // When we arrive at this function, the target region itself has been
  // outlined into the function OutlinedFn.
  // So at this point, for
  // --------------------------------------------------------------
  //   void user_code_that_offloads(...) {
  //     omp target depend(..) map(from:a) map(to:b) private(i)
  //       do i = 1, 10
  //         a(i) = b(i) + n
  //   }
  //
  // --------------------------------------------------------------
  //
  // we have
  //
  // --------------------------------------------------------------
  //
  //   void user_code_that_offloads(...) {
  //     %.offload_baseptrs = alloca [2 x ptr], align 8
  //     %.offload_ptrs = alloca [2 x ptr], align 8
  //     %.offload_mappers = alloca [2 x ptr], align 8
  //     ;; target region has been outlined and now we need to
  //     ;; offload to it via a target task.
  //   }
  //   void outlined_device_function(ptr a, ptr b, ptr n) {
  //     n = *n_ptr;
  //     do i = 1, 10
  //       a(i) = b(i) + n
  //   }
  //
  // We have to now do the following
  //   (i)   Make an offloading call to outlined_device_function using the OpenMP
  //         RTL. See 'kernel_launch_function' in the pseudo code below. This is
  //         emitted by emitKernelLaunch
  //   (ii)  Create a task entry point function that calls kernel_launch_function
  //         and is the entry point for the target task. See
  //         '@.omp_target_task_proxy_func in the pseudocode below.
  //   (iii) Create a task with the task entry point created in (ii)
  //
  // That is we create the following
  //   struct task_with_privates {
  //     struct kmp_task_ompbuilder_t task_struct;
  //     struct privates {
  //       [2 x ptr] ; baseptrs
  //       [2 x ptr] ; ptrs
  //       [2 x i64] ; sizes
  //     }
  //   }
  //   void user_code_that_offloads(...) {
  //     %.offload_baseptrs = alloca [2 x ptr], align 8
  //     %.offload_ptrs = alloca [2 x ptr], align 8
  //     %.offload_sizes = alloca [2 x i64], align 8
  //
  //     %structArg = alloca { ptr, ptr, ptr }, align 8
  //     %strucArg[0] = a
  //     %strucArg[1] = b
  //     %strucArg[2] = &n
  //
  //     target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
  //                                               sizeof(kmp_task_ompbuilder_t),
  //                                               sizeof(structArg),
  //                                               @.omp_target_task_proxy_func,
  //                                               ...)
  //     memcpy(target_task_with_privates->task_struct->shareds, %structArg,
  //            sizeof(structArg))
  //     memcpy(target_task_with_privates->privates->baseptrs,
  //            offload_baseptrs, sizeof(offload_baseptrs)
  //     memcpy(target_task_with_privates->privates->ptrs,
  //            offload_ptrs, sizeof(offload_ptrs)
  //     memcpy(target_task_with_privates->privates->sizes,
  //            offload_sizes, sizeof(offload_sizes)
  //     dependencies_array = ...
  //     ;; if nowait not present
  //     call @__kmpc_omp_wait_deps(..., dependencies_array)
  //     call @__kmpc_omp_task_begin_if0(...)
  //     call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
  //     %target_task_with_privates)
  //     call @__kmpc_omp_task_complete_if0(...)
  //   }
  //
  //   define internal void @.omp_target_task_proxy_func(i32 %thread.id,
  //                                                     ptr %task) {
  //     %structArg = alloca {ptr, ptr, ptr}
  //     %task_ptr = getelementptr(%task, 0, 0)
  //     %shared_data = load (getelementptr %task_ptr, 0, 0)
  //     memcpy(%structArg, %shared_data, sizeof(%structArg))
  //
  //     %offloading_arrays = getelementptr(%task, 0, 1)
  //     %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
  //     %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
  //     %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
  //     kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
  //                            %offload_sizes, %structArg)
  //   }
  //
  // We need the proxy function because the signature of the task entry point
  // expected by kmpc_omp_task is always the same and will be different from
  // that of the kernel_launch function.
  //
  // kernel_launch_function is generated by emitKernelLaunch and has the
  // always_inline attribute. For this example, it'll look like so:
  //   void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
  //                               %offload_sizes,  %structArg) alwaysinline {
  //     %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
  //     ; load aggregated data from %structArg
  //     ; setup kernel_args using offload_baseptrs, offload_ptrs and
  //     ; offload_sizes
  //     call i32 @__tgt_target_kernel(...,
  //                                   outlined_device_function,
  //                                   ptr %kernel_args)
  //   }
  //   void outlined_device_function(ptr a, ptr b, ptr n) {
  //     n = *n_ptr;
  //     do i = 1, 10
  //       a(i) = b(i) + n
  //   }
  //

  // Split the current block twice so that the (to-be-outlined) task region
  // gets a dedicated alloca block followed by a dedicated body block.
  BasicBlock *TargetTaskBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
  BasicBlock *TargetTaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");

  InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
                                   TargetTaskAllocaBB->begin());
  InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());

  OutlineInfo OI;
  OI.EntryBB = TargetTaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));

  // Generate the task body which will subsequently be outlined.
  Builder.restoreIP(IP: TargetTaskBodyIP);
  if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
    return Err;

  // The outliner (CodeExtractor) extracts a sequence or vector of blocks that
  // it is given. These blocks are enumerated by
  // OpenMPIRBuilder::OutlineInfo::collectBlocks which expects the OI.ExitBlock
  // to be outside the region. In other words, OI.ExitBlock is expected to be
  // the start of the region after the outlining. We used to set OI.ExitBlock
  // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
  // except when the task body is a single basic block. In that case,
  // OI.ExitBlock is set to the single task body block and will get left out of
  // the outlining process. So, simply create a new empty block to which we
  // unconditionally branch from where TaskBodyCB left off.
  OI.ExitBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "target.task.cont");
  emitBlock(BB: OI.ExitBB, CurFn: Builder.GetInsertBlock()->getParent(),
            /*IsFinished=*/true);

  // A deferrable target task (nowait with a known device) may outlive the
  // stack frame that owns the offloading arrays, so those arrays must be
  // copied into the task ("privatized") rather than captured by pointer.
  // Globals and null pointers remain valid and need no copy.
  SmallVector<Value *, 2> OffloadingArraysToPrivatize;
  bool NeedsTargetTask = HasNoWait && DeviceID;
  if (NeedsTargetTask) {
    for (auto *V :
         {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
          RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
          RTArgs.SizesArray}) {
      if (V && !isa<ConstantPointerNull, GlobalVariable>(Val: V)) {
        OffloadingArraysToPrivatize.push_back(Elt: V);
        OI.ExcludeArgsFromAggregate.push_back(Elt: V);
      }
    }
  }
  // Runs after CodeExtractor has produced the outlined function; replaces the
  // stale call to it with the full task-allocation/dispatch sequence.
  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
                      DeviceID, OffloadingArraysToPrivatize](
                         Function &OutlinedFn) mutable {
    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");

    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());

    // The first argument of StaleCI is always the thread id.
    // The next few arguments are the pointers to offloading arrays
    // if any. (see OffloadingArraysToPrivatize)
    // Finally, all other local values that are live-in into the outlined region
    // end up in a structure whose pointer is passed as the last argument. This
    // piece of data is passed in the "shared" field of the task structure. So,
    // we know we have to pass shareds to the task if the number of arguments is
    // greater than OffloadingArraysToPrivatize.size() + 1. The 1 is for the
    // thread id. Further, for safety, we assert that the number of arguments of
    // StaleCI is exactly OffloadingArraysToPrivatize.size() + 2.
    const unsigned int NumStaleCIArgs = StaleCI->arg_size();
    bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
    assert((!HasShareds ||
            NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
           "Wrong number of arguments for StaleCI when shareds are present");
    // Operand index of the shareds struct in StaleCI (0 is unused when there
    // are no shareds).
    int SharedArgOperandNo =
        HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;

    StructType *TaskWithPrivatesTy =
        createTaskWithPrivatesTy(OMPIRBuilder&: *this, OffloadingArraysToPrivatize);
    StructType *PrivatesTy = nullptr;

    // Field 1 of the wrapper struct is the privates struct (field 0 is the
    // task descriptor); only present when arrays are being privatized.
    if (!OffloadingArraysToPrivatize.empty())
      PrivatesTy =
          static_cast<StructType *>(TaskWithPrivatesTy->getElementType(N: 1));

    Function *ProxyFn = emitTargetTaskProxyFunction(
        OMPBuilder&: *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
        NumOffloadingArrays: OffloadingArraysToPrivatize.size(), SharedArgsOperandNo: SharedArgOperandNo);

    LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
                      << "\n");

    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call.
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr =
        getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

    // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
    //
    // If a deferrable target task is needed (NeedsTargetTask), we call
    // @__kmpc_omp_target_task_alloc to provide the DeviceID to the deferred
    // task and also since @__kmpc_omp_target_task_alloc creates an
    // untied/async task.
    Function *TaskAllocFn =
        !NeedsTargetTask
            ? getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc)
            : getOrCreateRuntimeFunctionPtr(
                  FnID: OMPRTL___kmpc_omp_target_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // Tasksize refers to the size in bytes of kmp_task_t data structure
    // plus any other data to be passed to the target task, if any, which
    // is packed into a struct. kmp_task_t and the struct so created are
    // packed into a wrapper struct whose type is TaskWithPrivatesTy.
    Value *TaskSize = Builder.getInt64(
        C: M.getDataLayout().getTypeStoreSize(Ty: TaskWithPrivatesTy));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(C: 0);
    if (HasShareds) {
      auto *ArgStructAlloca =
          dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgOperandNo));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      std::optional<TypeSize> ArgAllocSize =
          ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
      assert(ArgAllocSize &&
             "Unable to determine size of arguments for extracted function");
      SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
    }

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // A target task is not final and is untied.
    Value *Flags = Builder.getInt32(C: 0);

    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = nullptr;

    SmallVector<llvm::Value *> TaskAllocArgs = {
        /*loc_ref=*/Ident, /*gtid=*/ThreadID,
        /*flags=*/Flags,
        /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
        /*task_func=*/ProxyFn};

    // @__kmpc_omp_target_task_alloc takes an extra trailing device_id.
    if (NeedsTargetTask) {
      assert(DeviceID && "Expected non-empty device ID.");
      TaskAllocArgs.push_back(Elt: DeviceID);
    }

    TaskData = createRuntimeFunctionCall(Callee: TaskAllocFn, Args: TaskAllocArgs);

    Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
    // Copy the shareds struct into the task's shareds area.
    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(i: SharedArgOperandNo);
      Value *TaskShareds = loadSharedDataFromTaskDescriptor(
          OMPIRBuilder&: *this, Builder, TaskWithPrivates: TaskData, TaskWithPrivatesTy);
      Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
                           Size: SharedsSize);
    }
    // Copy each offloading array into its field of the privates struct so the
    // deferred task owns its own copies.
    if (!OffloadingArraysToPrivatize.empty()) {
      Value *Privates =
          Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskData, Idx: 1);
      for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
        Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
        [[maybe_unused]] Type *ArrayType =
            getOffloadingArrayType(V: PtrToPrivatize);
        assert(ArrayType && "ArrayType cannot be nullptr");

        Type *ElementType = PrivatesTy->getElementType(N: i);
        assert(ElementType == ArrayType &&
               "ElementType should match ArrayType");
        (void)ArrayType;

        Value *Dst = Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i);
        Builder.CreateMemCpy(
            Dst, DstAlign: Alignment, Src: PtrToPrivatize, SrcAlign: Alignment,
            Size: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ElementType)));
      }
    }

    // Materialize the dependence array (nullptr when there are no
    // dependencies).
    Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);

    // ---------------------------------------------------------------
    // V5.2 13.8 target construct
    // If the nowait clause is present, execution of the target task
    // may be deferred. If the nowait clause is not present, the target task is
    // an included task.
    // ---------------------------------------------------------------
    // The above means that the lack of a nowait on the target construct
    // translates to '#pragma omp task if(0)'
    if (!NeedsTargetTask) {
      // Wait for dependencies up front, then run the task inline
      // (begin_if0 / proxy call / complete_if0).
      if (DepArray) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
        createRuntimeFunctionCall(
            Callee: TaskWaitFn,
            Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
             /*ndeps=*/Builder.getInt32(C: Dependencies.size()),
             /*dep_list=*/DepArray,
             /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
             /*noalias_dep_list=*/
             ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
      }
      // Included task.
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
      createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
      CallInst *CI = createRuntimeFunctionCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
    } else if (DepArray) {
      // HasNoWait - meaning the task may be deferred. Call
      // __kmpc_omp_task_with_deps if there are dependencies,
      // else call __kmpc_omp_task
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
      createRuntimeFunctionCall(
          Callee: TaskFn,
          Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
           DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
           ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
      createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
    }

    // The stale call (and helper instructions created for the fake thread id)
    // are no longer needed.
    StaleCI->eraseFromParent();
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };
  addOutlineInfo(OI: std::move(OI));

  LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()) << "\n");
  LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()->getParent()->getParent())
                    << "\n");
  return Builder.saveIP();
}
9168
9169Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
9170 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
9171 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
9172 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
9173 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
9174 if (Error Err =
9175 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
9176 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
9177 return Err;
9178 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
9179 return Error::success();
9180}
9181
// Emits the host-side code that launches the offloaded target region: sets up
// the offloading arrays and kernel arguments, then emits either a direct
// kernel launch, a wrapping target task (for nowait/depend), or the host
// fallback, honoring an optional `if` clause.
static void emitTargetCall(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    OpenMPIRBuilder::InsertPointTy AllocaIP,
    OpenMPIRBuilder::TargetDataInfo &Info,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
    Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
    SmallVectorImpl<Value *> &Args,
    OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
    const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    bool HasNoWait, Value *DynCGroupMem,
    OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
  // Generate a function call to the host fallback implementation of the target
  // region. This is called by the host when no offload entry was generated for
  // the target region and when the offloading call fails at runtime.
  auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
      -> OpenMPIRBuilder::InsertPointOrErrorTy {
    Builder.restoreIP(IP);
    // Ensure the host fallback has the same dyn_ptr ABI as the device.
    SmallVector<Value *> FallbackArgs(Args.begin(), Args.end());
    FallbackArgs.push_back(
        Elt: Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext())));
    OMPBuilder.createRuntimeFunctionCall(Callee: OutlinedFn, Args: FallbackArgs);
    return Builder.saveIP();
  };

  bool HasDependencies = Dependencies.size() > 0;
  bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

  // Filled in by EmitTargetCallThen and read by TaskBodyCB; must outlive both
  // lambdas.
  OpenMPIRBuilder::TargetKernelArgs KArgs;

  auto TaskBodyCB =
      [&](Value *DeviceID, Value *RTLoc,
          IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
    // Assume no error was returned because EmitTargetCallFallbackCB doesn't
    // produce any.
    llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      // emitKernelLaunch makes the necessary runtime call to offload the
      // kernel. We then outline all that code into a separate function
      // ('kernel_launch_function' in the pseudo code above). This function is
      // then called by the target task proxy function (see
      // '@.omp_target_task_proxy_func' in the pseudo code above)
      // "@.omp_target_task_proxy_func' is generated by
      // emitTargetTaskProxyFunction.
      if (OutlinedFnID && DeviceID)
        return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
                                           EmitTargetCallFallbackCB, Args&: KArgs,
                                           DeviceID, RTLoc, AllocaIP: TargetTaskAllocaIP);

      // We only need to do the outlining if `DeviceID` is set to avoid calling
      // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
      // generating the `else` branch of an `if` clause.
      //
      // When OutlinedFnID is set to nullptr, then it's not an offloading call.
      // In this case, we execute the host implementation directly.
      return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
    }());

    OMPBuilder.Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  // Host-only code path: run the fallback, inside a target task when nowait
  // or dependencies require one.
  auto &&EmitTargetCallElse =
      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
    // Assume no error was returned because EmitTargetCallFallbackCB doesn't
    // produce any.
    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      if (RequiresOuterTargetTask) {
        // Arguments that are intended to be directly forwarded to an
        // emitKernelLaunch call are passed as nullptr, since
        // OutlinedFnID=nullptr results in that call not being done.
        OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
        return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
                                         /*RTLoc=*/nullptr, AllocaIP,
                                         Dependencies, RTArgs: EmptyRTArgs, HasNoWait);
      }
      return EmitTargetCallFallbackCB(Builder.saveIP());
    }());

    Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  // Offloading code path: build the kernel arguments (KArgs) and emit the
  // kernel launch, wrapped in a target task when needed.
  auto &&EmitTargetCallThen =
      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
    Info.HasNoWait = HasNoWait;
    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());

    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
    if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
            AllocaIP, CodeGenIP: Builder.saveIP(), Info, RTArgs, CombinedInfo&: MapInfo, CustomMapperCB,
            /*IsNonContiguous=*/true,
            /*ForEndCall=*/false))
      return Err;

    // Per-dimension team counts: runtime value if present, else the default.
    SmallVector<Value *, 3> NumTeamsC;
    for (auto [DefaultVal, RuntimeVal] :
         zip_equal(t: DefaultAttrs.MaxTeams, u: RuntimeAttrs.MaxTeams))
      NumTeamsC.push_back(Elt: RuntimeVal ? RuntimeVal
                                       : Builder.getInt32(C: DefaultVal));

    // Calculate number of threads: 0 if no clauses specified, otherwise it is
    // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
    auto InitMaxThreadsClause = [&Builder](Value *Clause) {
      if (Clause)
        Clause = Builder.CreateIntCast(V: Clause, DestTy: Builder.getInt32Ty(),
                                       /*isSigned=*/false);
      return Clause;
    };
    auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
      if (Clause)
        Result =
            Result ? Builder.CreateSelect(C: Builder.CreateICmpULT(LHS: Result, RHS: Clause),
                                          True: Result, False: Clause)
                   : Clause;
    };

    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
    // the NUM_THREADS clause is overridden by THREAD_LIMIT.
    SmallVector<Value *, 3> NumThreadsC;
    Value *MaxThreadsClause =
        RuntimeAttrs.TeamsThreadLimit.size() == 1
            ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
            : nullptr;

    for (auto [TeamsVal, TargetVal] : zip_equal(
             t: RuntimeAttrs.TeamsThreadLimit, u: RuntimeAttrs.TargetThreadLimit)) {
      Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
      Value *NumThreads = InitMaxThreadsClause(TargetVal);

      CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
      CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

      NumThreadsC.push_back(Elt: NumThreads ? NumThreads : Builder.getInt32(C: 0));
    }

    unsigned NumTargetItems = Info.NumberOfPtrs;
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                               LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);

    Value *TripCount = RuntimeAttrs.LoopTripCount
                           ? Builder.CreateIntCast(V: RuntimeAttrs.LoopTripCount,
                                                   DestTy: Builder.getInt64Ty(),
                                                   /*isSigned=*/false)
                           : Builder.getInt64(C: 0);

    // Request zero groupprivate bytes by default.
    if (!DynCGroupMem)
      DynCGroupMem = Builder.getInt32(C: 0);

    KArgs = OpenMPIRBuilder::TargetKernelArgs(
        NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
        HasNoWait, DynCGroupMemFallback);

    // Assume no error was returned because TaskBodyCB and
    // EmitTargetCallFallbackCB don't produce any.
    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      // The presence of certain clauses on the target directive require the
      // explicit generation of the target task.
      if (RequiresOuterTargetTask)
        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID: RuntimeAttrs.DeviceID,
                                         RTLoc, AllocaIP, Dependencies,
                                         RTArgs: KArgs.RTArgs, HasNoWait: Info.HasNoWait);

      return OMPBuilder.emitKernelLaunch(
          Loc: Builder, OutlinedFnID, EmitTargetCallFallbackCB, Args&: KArgs,
          DeviceID: RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
    }());

    Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  // If we don't have an ID for the target region, it means an offload entry
  // wasn't created. In this case we just run the host fallback directly and
  // ignore any potential 'if' clauses.
  if (!OutlinedFnID) {
    cantFail(Err: EmitTargetCallElse(AllocaIP, Builder.saveIP()));
    return;
  }

  // If there's no 'if' clause, only generate the kernel launch code path.
  if (!IfCond) {
    cantFail(Err: EmitTargetCallThen(AllocaIP, Builder.saveIP()));
    return;
  }

  cantFail(Err: OMPBuilder.emitIfClause(Cond: IfCond, ThenGen: EmitTargetCallThen,
                                    ElseGen: EmitTargetCallElse, AllocaIP));
}
9377
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, TargetDataInfo &Info,
    TargetRegionEntryInfo &EntryInfo,
    const TargetKernelDefaultAttrs &DefaultAttrs,
    const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
    SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
    CustomMapperCallbackTy CustomMapperCB,
    const SmallVector<DependData> &Dependencies, bool HasNowait,
    Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {

  // An invalid location means nothing to emit; return a default-constructed
  // (unset) insert point.
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  Function *OutlinedFn;
  Constant *OutlinedFnID = nullptr;
  // The target region is outlined into its own function. The LLVM IR for
  // the target region itself is generated using the callbacks CBFunc
  // and ArgAccessorFuncCB.
  if (Error Err = emitTargetOutlinedFunction(
          OMPBuilder&: *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
          OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
    return Err;

  // If we are not on the target device, then we need to generate code
  // to make a remote call (offload) to the previously outlined function
  // that represents the target region. Do that now.
  if (!Config.isTargetDevice())
    emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
                   IfCond, OutlinedFn, OutlinedFnID, Args&: Inputs, GenMapInfoCB,
                   CustomMapperCB, Dependencies, HasNoWait: HasNowait, DynCGroupMem,
                   DynCGroupMemFallback);
  return Builder.saveIP();
}
9416
9417std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
9418 StringRef FirstSeparator,
9419 StringRef Separator) {
9420 SmallString<128> Buffer;
9421 llvm::raw_svector_ostream OS(Buffer);
9422 StringRef Sep = FirstSeparator;
9423 for (StringRef Part : Parts) {
9424 OS << Sep << Part;
9425 Sep = Separator;
9426 }
9427 return OS.str().str();
9428}
9429
// Builds an internal-variable name from \p Parts using the separators the
// current configuration prescribes for this platform.
std::string
OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
  return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
                                                Separator: Config.separator());
}
9435
9436GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
9437 Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
9438 auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
9439 if (Elem.second) {
9440 assert(Elem.second->getValueType() == Ty &&
9441 "OMP internal variable has different type than requested");
9442 } else {
9443 // TODO: investigate the appropriate linkage type used for the global
9444 // variable for possibly changing that to internal or private, or maybe
9445 // create different versions of the function for different OMP internal
9446 // variables.
9447 const DataLayout &DL = M.getDataLayout();
9448 // TODO: Investigate why AMDGPU expects AS 0 for globals even though the
9449 // default global AS is 1.
9450 // See double-target-call-with-declare-target.f90 and
9451 // declare-target-vars-in-target-region.f90 libomptarget
9452 // tests.
9453 unsigned AddressSpaceVal = AddressSpace ? *AddressSpace
9454 : M.getTargetTriple().isAMDGPU()
9455 ? 0
9456 : DL.getDefaultGlobalsAddressSpace();
9457 auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
9458 ? GlobalValue::InternalLinkage
9459 : GlobalValue::CommonLinkage;
9460 auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
9461 Constant::getNullValue(Ty), Elem.first(),
9462 /*InsertBefore=*/nullptr,
9463 GlobalValue::NotThreadLocal, AddressSpaceVal);
9464 const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
9465 const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpaceVal);
9466 GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
9467 Elem.second = GV;
9468 }
9469
9470 return Elem.second;
9471}
9472
9473Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
9474 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
9475 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
9476 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
9477}
9478
9479Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
9480 LLVMContext &Ctx = Builder.getContext();
9481 Value *Null =
9482 Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
9483 Value *SizeGep =
9484 Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
9485 Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
9486 return SizePtrToInt;
9487}
9488
9489GlobalVariable *
9490OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
9491 std::string VarName) {
9492 llvm::Constant *MaptypesArrayInit =
9493 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
9494 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
9495 M, MaptypesArrayInit->getType(),
9496 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
9497 VarName);
9498 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
9499 return MaptypesArrayGlobal;
9500}
9501
9502void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
9503 InsertPointTy AllocaIP,
9504 unsigned NumOperands,
9505 struct MapperAllocas &MapperAllocas) {
9506 if (!updateToLocation(Loc))
9507 return;
9508
9509 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
9510 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
9511 Builder.restoreIP(IP: AllocaIP);
9512 AllocaInst *ArgsBase = Builder.CreateAlloca(
9513 Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
9514 AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
9515 Name: ".offload_ptrs");
9516 AllocaInst *ArgSizes = Builder.CreateAlloca(
9517 Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
9518 updateToLocation(Loc);
9519 MapperAllocas.ArgsBase = ArgsBase;
9520 MapperAllocas.Args = Args;
9521 MapperAllocas.ArgSizes = ArgSizes;
9522}
9523
9524void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
9525 Function *MapperFunc, Value *SrcLocInfo,
9526 Value *MaptypesArg, Value *MapnamesArg,
9527 struct MapperAllocas &MapperAllocas,
9528 int64_t DeviceID, unsigned NumOperands) {
9529 if (!updateToLocation(Loc))
9530 return;
9531
9532 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
9533 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
9534 Value *ArgsBaseGEP =
9535 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
9536 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
9537 Value *ArgsGEP =
9538 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
9539 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
9540 Value *ArgSizesGEP =
9541 Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
9542 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
9543 Value *NullPtr =
9544 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
9545 createRuntimeFunctionCall(Callee: MapperFunc, Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
9546 Builder.getInt32(C: NumOperands),
9547 ArgsBaseGEP, ArgsGEP, ArgSizesGEP,
9548 MaptypesArg, MapnamesArg, NullPtr});
9549}
9550
// Populate \p RTArgs with begin-of-array pointers into the offload arrays
// previously materialized in \p Info, for passing to the offloading runtime.
// When \p ForEndCall is set and a separate end-of-region map-types array
// exists, that variant is used instead of the begin one.
void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
                                                   TargetDataRTArgs &RTArgs,
                                                   TargetDataInfo &Info,
                                                   bool ForEndCall) {
  assert((!ForEndCall || Info.separateBeginEndCalls()) &&
         "expected region end call to runtime only when end call is separate");
  // With opaque pointers all of these aliases denote the same unqualified
  // pointer type; the names only document intent.
  auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
  auto VoidPtrTy = UnqualPtrTy;
  auto VoidPtrPtrTy = UnqualPtrTy;
  auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
  auto Int64PtrTy = UnqualPtrTy;

  // Nothing is mapped: pass null for every runtime array argument.
  if (!Info.NumberOfPtrs) {
    RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
    RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
    RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    return;
  }

  // Each array decays to a pointer to its first element via a [0][0] GEP.
  RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
      Ptr: Info.RTArgs.BasePointersArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
      /*Idx0=*/0,
      /*Idx1=*/0);
  RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
      /*Idx0=*/0, /*Idx1=*/0);
  // For an end call, prefer the PRESENT-stripped end variant when one was
  // generated (see emitOffloadingArrays).
  RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
      Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
                                                 : Info.RTArgs.MapTypesArray,
      /*Idx0=*/0,
      /*Idx1=*/0);

  // Only emit the mapper information arrays if debug information is
  // requested.
  if (!Info.EmitDebug)
    RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
  else
    RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
        /*Idx0=*/0,
        /*Idx1=*/0);
  // If there is no user-defined mapper, set the mapper array to nullptr to
  // avoid an unnecessary data privatization
  if (!Info.HasMapper)
    RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
  else
    RTArgs.MappersArray =
        Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
}
9608
9609void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
9610 InsertPointTy CodeGenIP,
9611 MapInfosTy &CombinedInfo,
9612 TargetDataInfo &Info) {
9613 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
9614 CombinedInfo.NonContigInfo;
9615
9616 // Build an array of struct descriptor_dim and then assign it to
9617 // offload_args.
9618 //
9619 // struct descriptor_dim {
9620 // uint64_t offset;
9621 // uint64_t count;
9622 // uint64_t stride
9623 // };
9624 Type *Int64Ty = Builder.getInt64Ty();
9625 StructType *DimTy = StructType::create(
9626 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
9627 Name: "struct.descriptor_dim");
9628
9629 enum { OffsetFD = 0, CountFD, StrideFD };
9630 // We need two index variable here since the size of "Dims" is the same as
9631 // the size of Components, however, the size of offset, count, and stride is
9632 // equal to the size of base declaration that is non-contiguous.
9633 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
9634 // Skip emitting ir if dimension size is 1 since it cannot be
9635 // non-contiguous.
9636 if (NonContigInfo.Dims[I] == 1)
9637 continue;
9638 Builder.restoreIP(IP: AllocaIP);
9639 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
9640 AllocaInst *DimsAddr =
9641 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
9642 Builder.restoreIP(IP: CodeGenIP);
9643 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
9644 unsigned RevIdx = EE - II - 1;
9645 Value *DimsLVal = Builder.CreateInBoundsGEP(
9646 Ty: ArrayTy, Ptr: DimsAddr, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
9647 // Offset
9648 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
9649 Builder.CreateAlignedStore(
9650 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
9651 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
9652 // Count
9653 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
9654 Builder.CreateAlignedStore(
9655 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
9656 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9657 // Stride
9658 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
9659 Builder.CreateAlignedStore(
9660 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
9661 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9662 }
9663 // args[I] = &dims
9664 Builder.restoreIP(IP: CodeGenIP);
9665 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
9666 V: DimsAddr, DestTy: Builder.getPtrTy());
9667 Value *P = Builder.CreateConstInBoundsGEP2_32(
9668 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
9669 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
9670 Builder.CreateAlignedStore(
9671 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
9672 ++L;
9673 }
9674}
9675
// Emit the conditional allocation (IsInit) or deletion (!IsInit) prologue /
// epilogue of a user-defined mapper: if the mapped entity is an array section
// (and, for init, also when base != begin), push one component covering the
// whole array with OMP_MAP_TO/FROM stripped so the runtime only performs the
// memory allocation or deletion. Control continues at \p ExitBB when the
// condition does not hold.
void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
    Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
    Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
    BasicBlock *ExitBB, bool IsInit) {
  StringRef Prefix = IsInit ? ".init" : ".del";

  // Evaluate if this is an array section.
  BasicBlock *BodyBB = BasicBlock::Create(
      Context&: M.getContext(), Name: createPlatformSpecificName(Parts: {"omp.array", Prefix}));
  Value *IsArray =
      Builder.CreateICmpSGT(LHS: Size, RHS: Builder.getInt64(C: 1), Name: "omp.arrayinit.isarray");
  Value *DeleteBit = Builder.CreateAnd(
      LHS: MapType,
      RHS: Builder.getInt64(
          C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
  Value *DeleteCond;
  Value *Cond;
  if (IsInit) {
    // base != begin?
    Value *BaseIsBegin = Builder.CreateICmpNE(LHS: Base, RHS: Begin);
    // Init runs when mapping an array section OR a section whose base
    // differs from its begin, and only when DELETE is NOT requested.
    Cond = Builder.CreateOr(LHS: IsArray, RHS: BaseIsBegin);
    DeleteCond = Builder.CreateIsNull(
        Arg: DeleteBit,
        Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
  } else {
    // Deletion runs only for array sections with the DELETE bit set.
    Cond = IsArray;
    DeleteCond = Builder.CreateIsNotNull(
        Arg: DeleteBit,
        Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
  }
  Cond = Builder.CreateAnd(LHS: Cond, RHS: DeleteCond);
  Builder.CreateCondBr(Cond, True: BodyBB, False: ExitBB);

  emitBlock(BB: BodyBB, CurFn: MapperFn);
  // Get the array size by multiplying element size and element number (i.e., \p
  // Size).
  Value *ArraySize = Builder.CreateNUWMul(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
  // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
  // memory allocation/deletion purpose only.
  Value *MapTypeArg = Builder.CreateAnd(
      LHS: MapType,
      RHS: Builder.getInt64(
          C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_TO |
              OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
  // Mark the component as implicit so the runtime treats it accordingly.
  MapTypeArg = Builder.CreateOr(
      LHS: MapTypeArg,
      RHS: Builder.getInt64(
          C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));

  // Call the runtime API __tgt_push_mapper_component to fill up the runtime
  // data structure.
  Value *OffloadingArgs[] = {MapperHandle, Base,       Begin,
                             ArraySize,    MapTypeArg, MapName};
  createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
      Args: OffloadingArgs);
}
9736
// Emit the body of a user-defined mapper function \p FuncName with the
// runtime-mandated signature (handle, base, begin, size-in-bytes, maptype,
// mapname). The function performs an optional array-init prologue, a loop
// that maps every element via \p GenMapInfoCB (applying OpenMP 5.0 map-type
// decay against the incoming \p MapType), and an optional array-delete
// epilogue. Returns the created function, or an error propagated from the
// callbacks.
Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
    function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
                                   llvm::Value *BeginArg)>
        GenMapInfoCB,
    Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
  // Runtime-fixed signature: (ptr handle, ptr base, ptr begin, i64 size,
  // i64 maptype, ptr mapname) -> void.
  SmallVector<Type *> Params;
  Params.emplace_back(Args: Builder.getPtrTy());
  Params.emplace_back(Args: Builder.getPtrTy());
  Params.emplace_back(Args: Builder.getPtrTy());
  Params.emplace_back(Args: Builder.getInt64Ty());
  Params.emplace_back(Args: Builder.getInt64Ty());
  Params.emplace_back(Args: Builder.getPtrTy());

  auto *FnTy =
      FunctionType::get(Result: Builder.getVoidTy(), Params, /* IsVarArg */ isVarArg: false);

  SmallString<64> TyStr;
  raw_svector_ostream Out(TyStr);
  Function *MapperFn =
      Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
  MapperFn->addFnAttr(Kind: Attribute::NoInline);
  MapperFn->addFnAttr(Kind: Attribute::NoUnwind);
  MapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 4, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 5, Kind: Attribute::NoUndef);

  // Start the mapper function code generation.
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: MapperFn);
  // The caller's insertion point is restored before returning.
  auto SavedIP = Builder.saveIP();
  Builder.SetInsertPoint(EntryBB);

  Value *MapperHandle = MapperFn->getArg(i: 0);
  Value *BaseIn = MapperFn->getArg(i: 1);
  Value *BeginIn = MapperFn->getArg(i: 2);
  Value *Size = MapperFn->getArg(i: 3);
  Value *MapType = MapperFn->getArg(i: 4);
  Value *MapName = MapperFn->getArg(i: 5);

  // Compute the starting and end addresses of array elements.
  // Prepare common arguments for array initiation and deletion.
  // Convert the size in bytes into the number of array elements.
  TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(Ty: ElemTy);
  Size = Builder.CreateExactUDiv(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
  Value *PtrBegin = BeginIn;
  Value *PtrEnd = Builder.CreateGEP(Ty: ElemTy, Ptr: PtrBegin, IdxList: Size);

  // Emit array initiation if this is an array section and \p MapType indicates
  // that memory allocation is required.
  BasicBlock *HeadBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.head");
  emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
                             MapType, MapName, ElementSize, ExitBB: HeadBB,
                             /*IsInit=*/true);

  // Emit a for loop to iterate through SizeArg of elements and map all of them.

  // Emit the loop header block.
  emitBlock(BB: HeadBB, CurFn: MapperFn);
  BasicBlock *BodyBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.body");
  BasicBlock *DoneBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.done");
  // Evaluate whether the initial condition is satisfied.
  Value *IsEmpty =
      Builder.CreateICmpEQ(LHS: PtrBegin, RHS: PtrEnd, Name: "omp.arraymap.isempty");
  Builder.CreateCondBr(Cond: IsEmpty, True: DoneBB, False: BodyBB);

  // Emit the loop body block.
  emitBlock(BB: BodyBB, CurFn: MapperFn);
  // LastBB tracks the block that ends the loop body; the PHI's back-edge
  // incoming value is added from it after the body is fully emitted.
  BasicBlock *LastBB = BodyBB;
  PHINode *PtrPHI =
      Builder.CreatePHI(Ty: PtrBegin->getType(), NumReservedValues: 2, Name: "omp.arraymap.ptrcurrent");
  PtrPHI->addIncoming(V: PtrBegin, BB: HeadBB);

  // Get map clause information. Fill up the arrays with all mapped variables.
  MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
  if (!Info)
    return Info.takeError();

  // Call the runtime API __tgt_mapper_num_components to get the number of
  // pre-existing components.
  Value *OffloadingArgs[] = {MapperHandle};
  Value *PreviousSize = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_mapper_num_components),
      Args: OffloadingArgs);
  // Shift into the MEMBER_OF bit-field position so member map types can be
  // offset past pre-existing components.
  Value *ShiftedPreviousSize =
      Builder.CreateShl(LHS: PreviousSize, RHS: Builder.getInt64(C: getFlagMemberOffset()));

  // Fill up the runtime mapper handle for all components.
  for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
    Value *CurBaseArg = Info->BasePointers[I];
    Value *CurBeginArg = Info->Pointers[I];
    Value *CurSizeArg = Info->Sizes[I];
    Value *CurNameArg = Info->Names.size()
                            ? Info->Names[I]
                            : Constant::getNullValue(Ty: Builder.getPtrTy());

    // Extract the MEMBER_OF field from the map type.
    Value *OriMapType = Builder.getInt64(
        C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            Info->Types[I]));
    Value *MemberMapType =
        Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);

    // Combine the map type inherited from user-defined mapper with that
    // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
    // bits of the \a MapType, which is the input argument of the mapper
    // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
    // bits of MemberMapType.
    // [OpenMP 5.0], 1.2.6. map-type decay.
    //        | alloc |  to   | from  | tofrom | release | delete
    // ----------------------------------------------------------
    // alloc  | alloc | alloc | alloc | alloc  | release | delete
    // to     | alloc |  to   | alloc |   to   | release | delete
    // from   | alloc | alloc | from  |  from  | release | delete
    // tofrom | alloc |  to   | from  | tofrom | release | delete
    Value *LeftToFrom = Builder.CreateAnd(
        LHS: MapType,
        RHS: Builder.getInt64(
            C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO |
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    BasicBlock *AllocBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc");
    BasicBlock *AllocElseBB =
        BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc.else");
    BasicBlock *ToBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to");
    BasicBlock *ToElseBB =
        BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to.else");
    BasicBlock *FromBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.from");
    BasicBlock *EndBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.end");
    Value *IsAlloc = Builder.CreateIsNull(Arg: LeftToFrom);
    Builder.CreateCondBr(Cond: IsAlloc, True: AllocBB, False: AllocElseBB);
    // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
    emitBlock(BB: AllocBB, CurFn: MapperFn);
    Value *AllocMapType = Builder.CreateAnd(
        LHS: MemberMapType,
        RHS: Builder.getInt64(
            C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO |
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    Builder.CreateBr(Dest: EndBB);
    emitBlock(BB: AllocElseBB, CurFn: MapperFn);
    Value *IsTo = Builder.CreateICmpEQ(
        LHS: LeftToFrom,
        RHS: Builder.getInt64(
            C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO)));
    Builder.CreateCondBr(Cond: IsTo, True: ToBB, False: ToElseBB);
    // In case of to, clear OMP_MAP_FROM.
    emitBlock(BB: ToBB, CurFn: MapperFn);
    Value *ToMapType = Builder.CreateAnd(
        LHS: MemberMapType,
        RHS: Builder.getInt64(
            C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    Builder.CreateBr(Dest: EndBB);
    emitBlock(BB: ToElseBB, CurFn: MapperFn);
    Value *IsFrom = Builder.CreateICmpEQ(
        LHS: LeftToFrom,
        RHS: Builder.getInt64(
            C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    Builder.CreateCondBr(Cond: IsFrom, True: FromBB, False: EndBB);
    // In case of from, clear OMP_MAP_TO.
    emitBlock(BB: FromBB, CurFn: MapperFn);
    Value *FromMapType = Builder.CreateAnd(
        LHS: MemberMapType,
        RHS: Builder.getInt64(
            C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO)));
    // In case of tofrom, do nothing.
    emitBlock(BB: EndBB, CurFn: MapperFn);
    LastBB = EndBB;
    // Merge the four decay outcomes; the tofrom case flows in from ToElseBB
    // with MemberMapType unchanged.
    PHINode *CurMapType =
        Builder.CreatePHI(Ty: Builder.getInt64Ty(), NumReservedValues: 4, Name: "omp.maptype");
    CurMapType->addIncoming(V: AllocMapType, BB: AllocBB);
    CurMapType->addIncoming(V: ToMapType, BB: ToBB);
    CurMapType->addIncoming(V: FromMapType, BB: FromBB);
    CurMapType->addIncoming(V: MemberMapType, BB: ToElseBB);

    Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
                               CurSizeArg,   CurMapType, CurNameArg};

    auto ChildMapperFn = CustomMapperCB(I);
    if (!ChildMapperFn)
      return ChildMapperFn.takeError();
    if (*ChildMapperFn) {
      // Call the corresponding mapper function.
      createRuntimeFunctionCall(Callee: *ChildMapperFn, Args: OffloadingArgs)
          ->setDoesNotThrow();
    } else {
      // Call the runtime API __tgt_push_mapper_component to fill up the runtime
      // data structure.
      createRuntimeFunctionCall(
          Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
          Args: OffloadingArgs);
    }
  }

  // Update the pointer to point to the next element that needs to be mapped,
  // and check whether we have mapped all elements.
  Value *PtrNext = Builder.CreateConstGEP1_32(Ty: ElemTy, Ptr: PtrPHI, /*Idx0=*/1,
                                              Name: "omp.arraymap.next");
  PtrPHI->addIncoming(V: PtrNext, BB: LastBB);
  Value *IsDone = Builder.CreateICmpEQ(LHS: PtrNext, RHS: PtrEnd, Name: "omp.arraymap.isdone");
  BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.exit");
  Builder.CreateCondBr(Cond: IsDone, True: ExitBB, False: BodyBB);

  emitBlock(BB: ExitBB, CurFn: MapperFn);
  // Emit array deletion if this is an array section and \p MapType indicates
  // that deletion is required.
  emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
                             MapType, MapName, ElementSize, ExitBB: DoneBB,
                             /*IsInit=*/false);

  // Emit the function exit block.
  emitBlock(BB: DoneBB, CurFn: MapperFn, /*IsFinished=*/true);

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: SavedIP);
  return MapperFn;
}
9959
// Materialize the offload arrays (base pointers, pointers, sizes, map types,
// map names, mappers) described by \p CombinedInfo into \p Info: allocas for
// the per-entry values, constant globals for compile-time-known sizes and
// map-type flags, and — when \p IsNonContiguous — the descriptor_dim arrays.
// \p CustomMapperCB supplies an optional mapper function per entry and
// \p DeviceAddrCB is notified of device-pointer slots.
Error OpenMPIRBuilder::emitOffloadingArrays(
    InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
    TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
    bool IsNonContiguous,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB) {

  // Reset the array information.
  Info.clearArrayInfo();
  Info.NumberOfPtrs = CombinedInfo.BasePointers.size();

  if (Info.NumberOfPtrs == 0)
    return Error::success();

  Builder.restoreIP(IP: AllocaIP);
  // Detect if we have any capture size requiring runtime evaluation of the
  // size so that a constant array could be eventually used.
  ArrayType *PointerArrayType =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);

  Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");

  Info.RTArgs.PointersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
  AllocaInst *MappersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
  Info.RTArgs.MappersArray = MappersArray;

  // If we don't have any VLA types or other types that require runtime
  // evaluation, we can use a constant array for the map sizes, otherwise we
  // need to fill up the arrays as we do for the pointers.
  Type *Int64Ty = Builder.getInt64Ty();
  SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
                                     ConstantInt::get(Ty: Int64Ty, V: 0));
  // RuntimeSizes marks entries whose size must be stored at runtime.
  SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
  for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
    bool IsNonContigEntry =
        IsNonContiguous &&
        (static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
             CombinedInfo.Types[I] &
             OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG) != 0);
    // For NON_CONTIG entries, ArgSizes stores the dimension count (number of
    // descriptor_dim records), not the byte size.
    if (IsNonContigEntry) {
      assert(I < CombinedInfo.NonContigInfo.Dims.size() &&
             "Index must be in-bounds for NON_CONTIG Dims array");
      const uint64_t DimCount = CombinedInfo.NonContigInfo.Dims[I];
      assert(DimCount > 0 && "NON_CONTIG DimCount must be > 0");
      ConstSizes[I] = ConstantInt::get(Ty: Int64Ty, V: DimCount);
      continue;
    }
    // ConstantExprs/GlobalValues are Constants but their value is not known
    // here, so they must still be stored at runtime.
    if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
      if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
        ConstSizes[I] = CI;
        continue;
      }
    }
    RuntimeSizes.set(I);
  }

  if (RuntimeSizes.all()) {
    // Every size is runtime-evaluated: a plain alloca filled below suffices.
    ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
    Info.RTArgs.SizesArray = Builder.CreateAlloca(
        Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
    restoreIPandDebugLoc(Builder, IP: CodeGenIP);
  } else {
    // At least some sizes are known: emit them as a constant global, used
    // either directly (all constant) or as the memcpy source for a buffer
    // whose runtime slots are overwritten below.
    auto *SizesArrayInit = ConstantArray::get(
        T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
    std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
    auto *SizesArrayGbl =
        new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
                           GlobalValue::PrivateLinkage, SizesArrayInit, Name);
    SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    if (!RuntimeSizes.any()) {
      Info.RTArgs.SizesArray = SizesArrayGbl;
    } else {
      unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
      Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
      ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
      AllocaInst *Buffer = Builder.CreateAlloca(
          Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
      Buffer->setAlignment(OffloadSizeAlign);
      restoreIPandDebugLoc(Builder, IP: CodeGenIP);
      Builder.CreateMemCpy(
          Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
          Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
          Size: Builder.getIntN(
              N: IndexSize,
              C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));

      Info.RTArgs.SizesArray = Buffer;
    }
    restoreIPandDebugLoc(Builder, IP: CodeGenIP);
  }

  // The map types are always constant so we don't need to generate code to
  // fill arrays. Instead, we create an array constant.
  SmallVector<uint64_t, 4> Mapping;
  for (auto mapFlag : CombinedInfo.Types)
    Mapping.push_back(
        Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            mapFlag));
  std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
  auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
  Info.RTArgs.MapTypesArray = MapTypesArrayGbl;

  // The information types are only built if provided.
  if (!CombinedInfo.Names.empty()) {
    auto *MapNamesArrayGbl = createOffloadMapnames(
        Names&: CombinedInfo.Names, VarName: createPlatformSpecificName(Parts: {"offload_mapnames"}));
    Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
    Info.EmitDebug = true;
  } else {
    Info.RTArgs.MapNamesArray =
        Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
    Info.EmitDebug = false;
  }

  // If there's a present map type modifier, it must not be applied to the end
  // of a region, so generate a separate map type array in that case.
  if (Info.separateBeginEndCalls()) {
    bool EndMapTypesDiffer = false;
    for (uint64_t &Type : Mapping) {
      if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                     OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
        Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
        EndMapTypesDiffer = true;
      }
    }
    if (EndMapTypesDiffer) {
      MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
      Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
    }
  }

  // Store per-entry base pointer, pointer, runtime size, and mapper function
  // into the corresponding array slots.
  PointerType *PtrTy = Builder.getPtrTy();
  for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
    Value *BPVal = CombinedInfo.BasePointers[I];
    Value *BP = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
        Idx0: 0, Idx1: I);
    Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
                               Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));

    if (Info.requiresDevicePointerInfo()) {
      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
        // use_device_ptr: allocate a slot the device pointer is copied into.
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(IP: AllocaIP);
        Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
        Builder.restoreIP(IP: CodeGenIP);
        if (DeviceAddrCB)
          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
        // use_device_addr: the array slot itself holds the device address.
        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
        if (DeviceAddrCB)
          DeviceAddrCB(I, BP);
      }
    }

    Value *PVal = CombinedInfo.Pointers[I];
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
        Idx1: I);
    // TODO: Check alignment correct.
    Builder.CreateAlignedStore(Val: PVal, Ptr: P,
                               Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));

    if (RuntimeSizes.test(Idx: I)) {
      Value *S = Builder.CreateConstInBoundsGEP2_32(
          Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
          /*Idx0=*/0,
          /*Idx1=*/I);
      Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
                                                        DestTy: Int64Ty,
                                                        /*isSigned=*/true),
                                 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
    }
    // Fill up the mapper array.
    unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
    Value *MFunc = ConstantPointerNull::get(T: PtrTy);

    auto CustomMFunc = CustomMapperCB(I);
    if (!CustomMFunc)
      return CustomMFunc.takeError();
    if (*CustomMFunc)
      MFunc = Builder.CreatePointerCast(V: *CustomMFunc, DestTy: PtrTy);

    Value *MAddr = Builder.CreateInBoundsGEP(
        Ty: PointerArrayType, Ptr: MappersArray,
        IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
    Builder.CreateAlignedStore(
        Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
  }

  if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
      Info.NumberOfPtrs == 0)
    return Error::success();
  emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
  return Error::success();
}
10162
10163void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
10164 BasicBlock *CurBB = Builder.GetInsertBlock();
10165
10166 if (!CurBB || CurBB->getTerminator()) {
10167 // If there is no insert point or the previous block is already
10168 // terminated, don't touch it.
10169 } else {
10170 // Otherwise, create a fall-through branch.
10171 Builder.CreateBr(Dest: Target);
10172 }
10173
10174 Builder.ClearInsertionPoint();
10175}
10176
10177void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
10178 bool IsFinished) {
10179 BasicBlock *CurBB = Builder.GetInsertBlock();
10180
10181 // Fall out of the current block (if necessary).
10182 emitBranch(Target: BB);
10183
10184 if (IsFinished && BB->use_empty()) {
10185 BB->eraseFromParent();
10186 return;
10187 }
10188
10189 // Place the block after the current block, if possible, or else at
10190 // the end of the function.
10191 if (CurBB && CurBB->getParent())
10192 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
10193 else
10194 CurFn->insert(Position: CurFn->end(), BB);
10195 Builder.SetInsertPoint(BB);
10196}
10197
10198Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
10199 BodyGenCallbackTy ElseGen,
10200 InsertPointTy AllocaIP) {
10201 // If the condition constant folds and can be elided, try to avoid emitting
10202 // the condition and the dead arm of the if/else.
10203 if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
10204 auto CondConstant = CI->getSExtValue();
10205 if (CondConstant)
10206 return ThenGen(AllocaIP, Builder.saveIP());
10207
10208 return ElseGen(AllocaIP, Builder.saveIP());
10209 }
10210
10211 Function *CurFn = Builder.GetInsertBlock()->getParent();
10212
10213 // Otherwise, the condition did not fold, or we couldn't elide it. Just
10214 // emit the conditional branch.
10215 BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
10216 BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
10217 BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
10218 Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
10219 // Emit the 'then' code.
10220 emitBlock(BB: ThenBlock, CurFn);
10221 if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
10222 return Err;
10223 emitBranch(Target: ContBlock);
10224 // Emit the 'else' code if present.
10225 // There is no need to emit line number for unconditional branch.
10226 emitBlock(BB: ElseBlock, CurFn);
10227 if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
10228 return Err;
10229 // There is no need to emit line number for unconditional branch.
10230 emitBranch(Target: ContBlock);
10231 // Emit the continuation block for code after the if.
10232 emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
10233 return Error::success();
10234}
10235
10236bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
10237 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
10238 assert(!(AO == AtomicOrdering::NotAtomic ||
10239 AO == llvm::AtomicOrdering::Unordered) &&
10240 "Unexpected Atomic Ordering.");
10241
10242 bool Flush = false;
10243 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
10244
10245 switch (AK) {
10246 case Read:
10247 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
10248 AO == AtomicOrdering::SequentiallyConsistent) {
10249 FlushAO = AtomicOrdering::Acquire;
10250 Flush = true;
10251 }
10252 break;
10253 case Write:
10254 case Compare:
10255 case Update:
10256 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
10257 AO == AtomicOrdering::SequentiallyConsistent) {
10258 FlushAO = AtomicOrdering::Release;
10259 Flush = true;
10260 }
10261 break;
10262 case Capture:
10263 switch (AO) {
10264 case AtomicOrdering::Acquire:
10265 FlushAO = AtomicOrdering::Acquire;
10266 Flush = true;
10267 break;
10268 case AtomicOrdering::Release:
10269 FlushAO = AtomicOrdering::Release;
10270 Flush = true;
10271 break;
10272 case AtomicOrdering::AcquireRelease:
10273 case AtomicOrdering::SequentiallyConsistent:
10274 FlushAO = AtomicOrdering::AcquireRelease;
10275 Flush = true;
10276 break;
10277 default:
10278 // do nothing - leave silently.
10279 break;
10280 }
10281 }
10282
10283 if (Flush) {
10284 // Currently Flush RT call still doesn't take memory_ordering, so for when
10285 // that happens, this tries to do the resolution of which atomic ordering
10286 // to use with but issue the flush call
10287 // TODO: pass `FlushAO` after memory ordering support is added
10288 (void)FlushAO;
10289 emitFlush(Loc);
10290 }
10291
10292 // for AO == AtomicOrdering::Monotonic and all other case combinations
10293 // do nothing
10294 return Flush;
10295}
10296
/// Emit an OpenMP `atomic read`: atomically load *X.Var with ordering \p AO
/// and store the result into the capture variable *V.Var.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO, InsertPointTy AllocaIP) {
  // Bail out (returning the unchanged IP) if there is no valid location.
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    // Integers: a native atomic load suffices.
    LoadInst *XLD =
        Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
    XLD->setAtomic(Ordering: AO);
    XRead = cast<Value>(Val: XLD);
  } else if (XElemTy->isStructTy()) {
    // FIXME: Add checks to ensure __atomic_load is emitted iff the
    // target does not support `atomicrmw` of the size of the struct
    // Structs: go through the atomic-load libcall. The plain load below only
    // exists to derive size/alignment and is erased once that is done.
    LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
    OldVal->setAtomic(Ordering: AO);
    const DataLayout &DL = OldVal->getModule()->getDataLayout();
    unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
    auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
    XRead = AtomicLoadRes.first;
    OldVal->eraseFromParent();
  } else {
    // Float/pointer: perform the atomic load through an integer of the same
    // bit width, then cast the loaded value back to the element type.
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    LoadInst *XLoad =
        Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
    XLoad->setAtomic(Ordering: AO);
    if (XElemTy->isFloatingPointTy()) {
      XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
    } else {
      XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
    }
  }
  // A read with acquire-or-stronger ordering requires a trailing flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
  Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
  return Builder.saveIP();
}
10348
/// Emit an OpenMP `atomic write`: atomically store \p Expr into *X.Var with
/// ordering \p AO.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO, InsertPointTy AllocaIP) {
  // Bail out (returning the unchanged IP) if there is no valid location.
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    // Integers: a native atomic store suffices.
    StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
    XSt->setAtomic(Ordering: AO);
  } else if (XElemTy->isStructTy()) {
    // Structs: go through the atomic-store libcall. The plain load below only
    // exists to derive size/alignment and is erased once that is done.
    LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
    const DataLayout &DL = OldVal->getModule()->getDataLayout();
    unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
    atomicInfo.EmitAtomicStoreLibcall(AO, Source: Expr);
    OldVal->eraseFromParent();
  } else {
    // Float/pointer: bitcast the value to an integer of the same bit width
    // and store that atomically.
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    Value *ExprCast =
        Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
    XSt->setAtomic(Ordering: AO);
  }

  // A write with release-or-stronger ordering requires a trailing flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
  return Builder.saveIP();
}
10388
/// Emit an OpenMP `atomic update` of *X.Var. The actual lowering is delegated
/// to emitAtomicUpdate; afterwards a flush is emitted if the ordering
/// requires one. Returns the new insert point, or the error produced by
/// \p UpdateOp.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Input validation happens only in debug builds.
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
      AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp, VolatileX: X.IsVolatile,
      IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
  if (!AtomicResult)
    return AtomicResult.takeError();
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
  return Builder.saveIP();
}
10419
10420// FIXME: Duplicating AtomicExpand
10421Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
10422 AtomicRMWInst::BinOp RMWOp) {
10423 switch (RMWOp) {
10424 case AtomicRMWInst::Add:
10425 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
10426 case AtomicRMWInst::Sub:
10427 return Builder.CreateSub(LHS: Src1, RHS: Src2);
10428 case AtomicRMWInst::And:
10429 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
10430 case AtomicRMWInst::Nand:
10431 return Builder.CreateNeg(V: Builder.CreateAnd(LHS: Src1, RHS: Src2));
10432 case AtomicRMWInst::Or:
10433 return Builder.CreateOr(LHS: Src1, RHS: Src2);
10434 case AtomicRMWInst::Xor:
10435 return Builder.CreateXor(LHS: Src1, RHS: Src2);
10436 case AtomicRMWInst::Xchg:
10437 case AtomicRMWInst::FAdd:
10438 case AtomicRMWInst::FSub:
10439 case AtomicRMWInst::BAD_BINOP:
10440 case AtomicRMWInst::Max:
10441 case AtomicRMWInst::Min:
10442 case AtomicRMWInst::UMax:
10443 case AtomicRMWInst::UMin:
10444 case AtomicRMWInst::FMax:
10445 case AtomicRMWInst::FMin:
10446 case AtomicRMWInst::FMaximum:
10447 case AtomicRMWInst::FMinimum:
10448 case AtomicRMWInst::UIncWrap:
10449 case AtomicRMWInst::UDecWrap:
10450 case AtomicRMWInst::USubCond:
10451 case AtomicRMWInst::USubSat:
10452 llvm_unreachable("Unsupported atomic update operation");
10453 }
10454 llvm_unreachable("Unsupported atomic update operation");
10455}
10456
10457Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
10458 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
10459 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
10460 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
10461 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
10462 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
10463 // or a complex datatype.
10464 bool emitRMWOp = false;
10465 switch (RMWOp) {
10466 case AtomicRMWInst::Add:
10467 case AtomicRMWInst::And:
10468 case AtomicRMWInst::Nand:
10469 case AtomicRMWInst::Or:
10470 case AtomicRMWInst::Xor:
10471 case AtomicRMWInst::Xchg:
10472 emitRMWOp = XElemTy;
10473 break;
10474 case AtomicRMWInst::Sub:
10475 emitRMWOp = (IsXBinopExpr && XElemTy);
10476 break;
10477 default:
10478 emitRMWOp = false;
10479 }
10480 emitRMWOp &= XElemTy->isIntegerTy();
10481
10482 std::pair<Value *, Value *> Res;
10483 if (emitRMWOp) {
10484 AtomicRMWInst *RMWInst =
10485 Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
10486 if (T.isAMDGPU()) {
10487 if (IsIgnoreDenormalMode)
10488 RMWInst->setMetadata(Kind: "amdgpu.ignore.denormal.mode",
10489 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10490 if (!IsFineGrainedMemory)
10491 RMWInst->setMetadata(Kind: "amdgpu.no.fine.grained.memory",
10492 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10493 if (!IsRemoteMemory)
10494 RMWInst->setMetadata(Kind: "amdgpu.no.remote.memory",
10495 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10496 }
10497 Res.first = RMWInst;
10498 // not needed except in case of postfix captures. Generate anyway for
10499 // consistency with the else part. Will be removed with any DCE pass.
10500 // AtomicRMWInst::Xchg does not have a coressponding instruction.
10501 if (RMWOp == AtomicRMWInst::Xchg)
10502 Res.second = Res.first;
10503 else
10504 Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
10505 } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
10506 XElemTy->isStructTy()) {
10507 LoadInst *OldVal =
10508 Builder.CreateLoad(Ty: XElemTy, Ptr: X, Name: X->getName() + ".atomic.load");
10509 OldVal->setAtomic(Ordering: AO);
10510 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
10511 unsigned LoadSize =
10512 LoadDL.getTypeStoreSize(Ty: OldVal->getPointerOperand()->getType());
10513
10514 OpenMPIRBuilder::AtomicInfo atomicInfo(
10515 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10516 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
10517 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
10518 BasicBlock *CurBB = Builder.GetInsertBlock();
10519 Instruction *CurBBTI = CurBB->getTerminator();
10520 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10521 BasicBlock *ExitBB =
10522 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
10523 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
10524 BBName: X->getName() + ".atomic.cont");
10525 ContBB->getTerminator()->eraseFromParent();
10526 Builder.restoreIP(IP: AllocaIP);
10527 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
10528 NewAtomicAddr->setName(X->getName() + "x.new.val");
10529 Builder.SetInsertPoint(ContBB);
10530 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
10531 PHI->addIncoming(V: AtomicLoadRes.first, BB: CurBB);
10532 Value *OldExprVal = PHI;
10533 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
10534 if (!CBResult)
10535 return CBResult.takeError();
10536 Value *Upd = *CBResult;
10537 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
10538 AtomicOrdering Failure =
10539 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10540 auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
10541 ExpectedVal: AtomicLoadRes.second, DesiredVal: NewAtomicAddr, Success: AO, Failure);
10542 LoadInst *PHILoad = Builder.CreateLoad(Ty: XElemTy, Ptr: Result.first);
10543 PHI->addIncoming(V: PHILoad, BB: Builder.GetInsertBlock());
10544 Builder.CreateCondBr(Cond: Result.second, True: ExitBB, False: ContBB);
10545 OldVal->eraseFromParent();
10546 Res.first = OldExprVal;
10547 Res.second = Upd;
10548
10549 if (UnreachableInst *ExitTI =
10550 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
10551 CurBBTI->eraseFromParent();
10552 Builder.SetInsertPoint(ExitBB);
10553 } else {
10554 Builder.SetInsertPoint(ExitTI);
10555 }
10556 } else {
10557 IntegerType *IntCastTy =
10558 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10559 LoadInst *OldVal =
10560 Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
10561 OldVal->setAtomic(Ordering: AO);
10562 // CurBB
10563 // | /---\
10564 // ContBB |
10565 // | \---/
10566 // ExitBB
10567 BasicBlock *CurBB = Builder.GetInsertBlock();
10568 Instruction *CurBBTI = CurBB->getTerminator();
10569 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10570 BasicBlock *ExitBB =
10571 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
10572 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
10573 BBName: X->getName() + ".atomic.cont");
10574 ContBB->getTerminator()->eraseFromParent();
10575 Builder.restoreIP(IP: AllocaIP);
10576 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
10577 NewAtomicAddr->setName(X->getName() + "x.new.val");
10578 Builder.SetInsertPoint(ContBB);
10579 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
10580 PHI->addIncoming(V: OldVal, BB: CurBB);
10581 bool IsIntTy = XElemTy->isIntegerTy();
10582 Value *OldExprVal = PHI;
10583 if (!IsIntTy) {
10584 if (XElemTy->isFloatingPointTy()) {
10585 OldExprVal = Builder.CreateBitCast(V: PHI, DestTy: XElemTy,
10586 Name: X->getName() + ".atomic.fltCast");
10587 } else {
10588 OldExprVal = Builder.CreateIntToPtr(V: PHI, DestTy: XElemTy,
10589 Name: X->getName() + ".atomic.ptrCast");
10590 }
10591 }
10592
10593 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
10594 if (!CBResult)
10595 return CBResult.takeError();
10596 Value *Upd = *CBResult;
10597 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
10598 LoadInst *DesiredVal = Builder.CreateLoad(Ty: IntCastTy, Ptr: NewAtomicAddr);
10599 AtomicOrdering Failure =
10600 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10601 AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
10602 Ptr: X, Cmp: PHI, New: DesiredVal, Align: llvm::MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
10603 Result->setVolatile(VolatileX);
10604 Value *PreviousVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
10605 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10606 PHI->addIncoming(V: PreviousVal, BB: Builder.GetInsertBlock());
10607 Builder.CreateCondBr(Cond: SuccessFailureVal, True: ExitBB, False: ContBB);
10608
10609 Res.first = OldExprVal;
10610 Res.second = Upd;
10611
10612 // set Insertion point in exit block
10613 if (UnreachableInst *ExitTI =
10614 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
10615 CurBBTI->eraseFromParent();
10616 Builder.SetInsertPoint(ExitBB);
10617 } else {
10618 Builder.SetInsertPoint(ExitTI);
10619 }
10620 }
10621
10622 return Res;
10623}
10624
/// Emit an OpenMP `atomic capture`: update *X.Var and store the captured
/// value into *V.Var — the old value for a postfix form (e.g. `v = x++`),
/// the new value otherwise. When \p UpdateExpr is false, 'x' is simply
/// replaced by \p Expr via an atomic exchange.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Input validation happens only in debug builds.
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
  // 'x' is simply atomically rewritten with 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
  Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
      AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp: AtomicOp, UpdateOp, VolatileX: X.IsVolatile,
      IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
  if (!AtomicResult)
    return AtomicResult.takeError();
  // emitAtomicUpdate returns {old value, new value}; pick per capture form.
  Value *CapturedVal =
      (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
  Builder.CreateStore(Val: CapturedVal, Ptr: V.Var, isVolatile: V.IsVolatile);

  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Capture);
  return Builder.saveIP();
}
10661
10662OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
10663 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
10664 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
10665 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
10666 bool IsFailOnly) {
10667
10668 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10669 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
10670 IsPostfixUpdate, IsFailOnly, Failure);
10671}
10672
10673OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
10674 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
10675 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
10676 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
10677 bool IsFailOnly, AtomicOrdering Failure) {
10678
10679 if (!updateToLocation(Loc))
10680 return Loc.IP;
10681
10682 assert(X.Var->getType()->isPointerTy() &&
10683 "OMP atomic expects a pointer to target memory");
10684 // compare capture
10685 if (V.Var) {
10686 assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
10687 assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
10688 }
10689
10690 bool IsInteger = E->getType()->isIntegerTy();
10691
10692 if (Op == OMPAtomicCompareOp::EQ) {
10693 AtomicCmpXchgInst *Result = nullptr;
10694 if (!IsInteger) {
10695 IntegerType *IntCastTy =
10696 IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
10697 Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
10698 Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
10699 Result = Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: EBCast, New: DBCast, Align: MaybeAlign(),
10700 SuccessOrdering: AO, FailureOrdering: Failure);
10701 } else {
10702 Result =
10703 Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: E, New: D, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
10704 }
10705
10706 if (V.Var) {
10707 Value *OldValue = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
10708 if (!IsInteger)
10709 OldValue = Builder.CreateBitCast(V: OldValue, DestTy: X.ElemTy);
10710 assert(OldValue->getType() == V.ElemTy &&
10711 "OldValue and V must be of same type");
10712 if (IsPostfixUpdate) {
10713 Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
10714 } else {
10715 Value *SuccessOrFail = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10716 if (IsFailOnly) {
10717 // CurBB----
10718 // | |
10719 // v |
10720 // ContBB |
10721 // | |
10722 // v |
10723 // ExitBB <-
10724 //
10725 // where ContBB only contains the store of old value to 'v'.
10726 BasicBlock *CurBB = Builder.GetInsertBlock();
10727 Instruction *CurBBTI = CurBB->getTerminator();
10728 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10729 BasicBlock *ExitBB = CurBB->splitBasicBlock(
10730 I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
10731 BasicBlock *ContBB = CurBB->splitBasicBlock(
10732 I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
10733 ContBB->getTerminator()->eraseFromParent();
10734 CurBB->getTerminator()->eraseFromParent();
10735
10736 Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);
10737
10738 Builder.SetInsertPoint(ContBB);
10739 Builder.CreateStore(Val: OldValue, Ptr: V.Var);
10740 Builder.CreateBr(Dest: ExitBB);
10741
10742 if (UnreachableInst *ExitTI =
10743 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
10744 CurBBTI->eraseFromParent();
10745 Builder.SetInsertPoint(ExitBB);
10746 } else {
10747 Builder.SetInsertPoint(ExitTI);
10748 }
10749 } else {
10750 Value *CapturedValue =
10751 Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
10752 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
10753 }
10754 }
10755 }
10756 // The comparison result has to be stored.
10757 if (R.Var) {
10758 assert(R.Var->getType()->isPointerTy() &&
10759 "r.var must be of pointer type");
10760 assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
10761
10762 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10763 Value *ResultCast = R.IsSigned
10764 ? Builder.CreateSExt(V: SuccessFailureVal, DestTy: R.ElemTy)
10765 : Builder.CreateZExt(V: SuccessFailureVal, DestTy: R.ElemTy);
10766 Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
10767 }
10768 } else {
10769 assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
10770 "Op should be either max or min at this point");
10771 assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
10772
10773 // Reverse the ordop as the OpenMP forms are different from LLVM forms.
10774 // Let's take max as example.
10775 // OpenMP form:
10776 // x = x > expr ? expr : x;
10777 // LLVM form:
10778 // *ptr = *ptr > val ? *ptr : val;
10779 // We need to transform to LLVM form.
10780 // x = x <= expr ? x : expr;
10781 AtomicRMWInst::BinOp NewOp;
10782 if (IsXBinopExpr) {
10783 if (IsInteger) {
10784 if (X.IsSigned)
10785 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
10786 : AtomicRMWInst::Max;
10787 else
10788 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
10789 : AtomicRMWInst::UMax;
10790 } else {
10791 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
10792 : AtomicRMWInst::FMax;
10793 }
10794 } else {
10795 if (IsInteger) {
10796 if (X.IsSigned)
10797 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
10798 : AtomicRMWInst::Min;
10799 else
10800 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
10801 : AtomicRMWInst::UMin;
10802 } else {
10803 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
10804 : AtomicRMWInst::FMin;
10805 }
10806 }
10807
10808 AtomicRMWInst *OldValue =
10809 Builder.CreateAtomicRMW(Op: NewOp, Ptr: X.Var, Val: E, Align: MaybeAlign(), Ordering: AO);
10810 if (V.Var) {
10811 Value *CapturedValue = nullptr;
10812 if (IsPostfixUpdate) {
10813 CapturedValue = OldValue;
10814 } else {
10815 CmpInst::Predicate Pred;
10816 switch (NewOp) {
10817 case AtomicRMWInst::Max:
10818 Pred = CmpInst::ICMP_SGT;
10819 break;
10820 case AtomicRMWInst::UMax:
10821 Pred = CmpInst::ICMP_UGT;
10822 break;
10823 case AtomicRMWInst::FMax:
10824 Pred = CmpInst::FCMP_OGT;
10825 break;
10826 case AtomicRMWInst::Min:
10827 Pred = CmpInst::ICMP_SLT;
10828 break;
10829 case AtomicRMWInst::UMin:
10830 Pred = CmpInst::ICMP_ULT;
10831 break;
10832 case AtomicRMWInst::FMin:
10833 Pred = CmpInst::FCMP_OLT;
10834 break;
10835 default:
10836 llvm_unreachable("unexpected comparison op");
10837 }
10838 Value *NonAtomicCmp = Builder.CreateCmp(Pred, LHS: OldValue, RHS: E);
10839 CapturedValue = Builder.CreateSelect(C: NonAtomicCmp, True: E, False: OldValue);
10840 }
10841 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
10842 }
10843 }
10844
10845 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Compare);
10846
10847 return Builder.saveIP();
10848}
10849
10850OpenMPIRBuilder::InsertPointOrErrorTy
10851OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
10852 BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
10853 Value *NumTeamsUpper, Value *ThreadLimit,
10854 Value *IfExpr) {
10855 if (!updateToLocation(Loc))
10856 return InsertPointTy();
10857
10858 uint32_t SrcLocStrSize;
10859 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
10860 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
10861 Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
10862
10863 // Outer allocation basicblock is the entry block of the current function.
10864 BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
10865 if (&OuterAllocaBB == Builder.GetInsertBlock()) {
10866 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.entry");
10867 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
10868 }
10869
10870 // The current basic block is split into four basic blocks. After outlining,
10871 // they will be mapped as follows:
10872 // ```
10873 // def current_fn() {
10874 // current_basic_block:
10875 // br label %teams.exit
10876 // teams.exit:
10877 // ; instructions after teams
10878 // }
10879 //
10880 // def outlined_fn() {
10881 // teams.alloca:
10882 // br label %teams.body
10883 // teams.body:
10884 // ; instructions within teams body
10885 // }
10886 // ```
10887 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.exit");
10888 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.body");
10889 BasicBlock *AllocaBB =
10890 splitBB(Builder, /*CreateBranch=*/true, Name: "teams.alloca");
10891
10892 bool SubClausesPresent =
10893 (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
10894 // Push num_teams
10895 if (!Config.isTargetDevice() && SubClausesPresent) {
10896 assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
10897 "if lowerbound is non-null, then upperbound must also be non-null "
10898 "for bounds on num_teams");
10899
10900 if (NumTeamsUpper == nullptr)
10901 NumTeamsUpper = Builder.getInt32(C: 0);
10902
10903 if (NumTeamsLower == nullptr)
10904 NumTeamsLower = NumTeamsUpper;
10905
10906 if (IfExpr) {
10907 assert(IfExpr->getType()->isIntegerTy() &&
10908 "argument to if clause must be an integer value");
10909
10910 // upper = ifexpr ? upper : 1
10911 if (IfExpr->getType() != Int1)
10912 IfExpr = Builder.CreateICmpNE(LHS: IfExpr,
10913 RHS: ConstantInt::get(Ty: IfExpr->getType(), V: 0));
10914 NumTeamsUpper = Builder.CreateSelect(
10915 C: IfExpr, True: NumTeamsUpper, False: Builder.getInt32(C: 1), Name: "numTeamsUpper");
10916
10917 // lower = ifexpr ? lower : 1
10918 NumTeamsLower = Builder.CreateSelect(
10919 C: IfExpr, True: NumTeamsLower, False: Builder.getInt32(C: 1), Name: "numTeamsLower");
10920 }
10921
10922 if (ThreadLimit == nullptr)
10923 ThreadLimit = Builder.getInt32(C: 0);
10924
10925 // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
10926 // truncate or sign extend the passed values to match the int32 parameters.
10927 Value *NumTeamsLowerInt32 =
10928 Builder.CreateSExtOrTrunc(V: NumTeamsLower, DestTy: Builder.getInt32Ty());
10929 Value *NumTeamsUpperInt32 =
10930 Builder.CreateSExtOrTrunc(V: NumTeamsUpper, DestTy: Builder.getInt32Ty());
10931 Value *ThreadLimitInt32 =
10932 Builder.CreateSExtOrTrunc(V: ThreadLimit, DestTy: Builder.getInt32Ty());
10933
10934 Value *ThreadNum = getOrCreateThreadID(Ident);
10935
10936 createRuntimeFunctionCall(
10937 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_teams_51),
10938 Args: {Ident, ThreadNum, NumTeamsLowerInt32, NumTeamsUpperInt32,
10939 ThreadLimitInt32});
10940 }
10941 // Generate the body of teams.
10942 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
10943 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
10944 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
10945 return Err;
10946
10947 OutlineInfo OI;
10948 OI.EntryBB = AllocaBB;
10949 OI.ExitBB = ExitBB;
10950 OI.OuterAllocaBB = &OuterAllocaBB;
10951
10952 // Insert fake values for global tid and bound tid.
10953 SmallVector<Instruction *, 8> ToBeDeleted;
10954 InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
10955 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
10956 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "gid", AsPtr: true));
10957 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
10958 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "tid", AsPtr: true));
10959
10960 auto HostPostOutlineCB = [this, Ident,
10961 ToBeDeleted](Function &OutlinedFn) mutable {
10962 // The stale call instruction will be replaced with a new call instruction
10963 // for runtime call with the outlined function.
10964
10965 assert(OutlinedFn.hasOneUse() &&
10966 "there must be a single user for the outlined function");
10967 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
10968 ToBeDeleted.push_back(Elt: StaleCI);
10969
10970 assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
10971 "Outlined function must have two or three arguments only");
10972
10973 bool HasShared = OutlinedFn.arg_size() == 3;
10974
10975 OutlinedFn.getArg(i: 0)->setName("global.tid.ptr");
10976 OutlinedFn.getArg(i: 1)->setName("bound.tid.ptr");
10977 if (HasShared)
10978 OutlinedFn.getArg(i: 2)->setName("data");
10979
10980 // Call to the runtime function for teams in the current function.
10981 assert(StaleCI && "Error while outlining - no CallInst user found for the "
10982 "outlined function.");
10983 Builder.SetInsertPoint(StaleCI);
10984 SmallVector<Value *> Args = {
10985 Ident, Builder.getInt32(C: StaleCI->arg_size() - 2), &OutlinedFn};
10986 if (HasShared)
10987 Args.push_back(Elt: StaleCI->getArgOperand(i: 2));
10988 createRuntimeFunctionCall(
10989 Callee: getOrCreateRuntimeFunctionPtr(
10990 FnID: omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
10991 Args);
10992
10993 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
10994 I->eraseFromParent();
10995 };
10996
10997 if (!Config.isTargetDevice())
10998 OI.PostOutlineCB = HostPostOutlineCB;
10999
11000 addOutlineInfo(OI: std::move(OI));
11001
11002 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
11003
11004 return Builder.saveIP();
11005}
11006
11007OpenMPIRBuilder::InsertPointOrErrorTy
11008OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
11009 InsertPointTy OuterAllocaIP,
11010 BodyGenCallbackTy BodyGenCB) {
11011 if (!updateToLocation(Loc))
11012 return InsertPointTy();
11013
11014 BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();
11015
11016 if (OuterAllocaBB == Builder.GetInsertBlock()) {
11017 BasicBlock *BodyBB =
11018 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.entry");
11019 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
11020 }
11021 BasicBlock *ExitBB =
11022 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.exit");
11023 BasicBlock *BodyBB =
11024 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.body");
11025 BasicBlock *AllocaBB =
11026 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.alloca");
11027
11028 // Generate the body of distribute clause
11029 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
11030 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
11031 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
11032 return Err;
11033
11034 // When using target we use different runtime functions which require a
11035 // callback.
11036 if (Config.isTargetDevice()) {
11037 OutlineInfo OI;
11038 OI.OuterAllocaBB = OuterAllocaIP.getBlock();
11039 OI.EntryBB = AllocaBB;
11040 OI.ExitBB = ExitBB;
11041
11042 addOutlineInfo(OI: std::move(OI));
11043 }
11044 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
11045
11046 return Builder.saveIP();
11047}
11048
11049GlobalVariable *
11050OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
11051 std::string VarName) {
11052 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
11053 T: llvm::ArrayType::get(ElementType: llvm::PointerType::getUnqual(C&: M.getContext()),
11054 NumElements: Names.size()),
11055 V: Names);
11056 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
11057 M, MapNamesArrayInit->getType(),
11058 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
11059 VarName);
11060 return MapNamesArrayGlobal;
11061}
11062
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
  // Address spaces used for data pointers vs. function pointers; these can
  // differ per target (taken from the config and the module's data layout).
  unsigned DefaultTargetAS = Config.getDefaultTargetAS();
  unsigned ProgramAS = M.getDataLayout().getProgramAddressSpace();
// Each OMP_* macro below is expanded once per entry in OMPKinds.def and
// caches the created type (plus an opaque pointer type for it) in a member.
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::get(Ctx, DefaultTargetAS);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::get(Ctx, ProgramAS);
// Struct types are looked up by name first so an existing type already in
// the module is reused instead of creating a renamed duplicate.
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::get(Ctx, DefaultTargetAS);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
11085
11086void OpenMPIRBuilder::OutlineInfo::collectBlocks(
11087 SmallPtrSetImpl<BasicBlock *> &BlockSet,
11088 SmallVectorImpl<BasicBlock *> &BlockVector) {
11089 SmallVector<BasicBlock *, 32> Worklist;
11090 BlockSet.insert(Ptr: EntryBB);
11091 BlockSet.insert(Ptr: ExitBB);
11092
11093 Worklist.push_back(Elt: EntryBB);
11094 while (!Worklist.empty()) {
11095 BasicBlock *BB = Worklist.pop_back_val();
11096 BlockVector.push_back(Elt: BB);
11097 for (BasicBlock *SuccBB : successors(BB))
11098 if (BlockSet.insert(Ptr: SuccBB).second)
11099 Worklist.push_back(Elt: SuccBB);
11100 }
11101}
11102
11103void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
11104 uint64_t Size, int32_t Flags,
11105 GlobalValue::LinkageTypes,
11106 StringRef Name) {
11107 if (!Config.isGPU()) {
11108 llvm::offloading::emitOffloadingEntry(
11109 M, Kind: object::OffloadKind::OFK_OpenMP, Addr: ID,
11110 Name: Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
11111 return;
11112 }
11113 // TODO: Add support for global variables on the device after declare target
11114 // support.
11115 Function *Fn = dyn_cast<Function>(Val: Addr);
11116 if (!Fn)
11117 return;
11118
11119 // Add a function attribute for the kernel.
11120 Fn->addFnAttr(Kind: "kernel");
11121 if (T.isAMDGCN())
11122 Fn->addFnAttr(Kind: "uniform-work-group-size");
11123 Fn->addFnAttr(Kind: Attribute::MustProgress);
11124}
11125
// We only generate metadata for function that contain target regions.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  // Entries are collected here indexed by creation order, so the emitted
  // offload entries follow the order recorded in the metadata.
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(C: ConstantInt::get(Ty: Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(Context&: C, Str: V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        // identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y: EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
      };

  OffloadInfoManager.actOnTargetRegionEntriesInfo(Action: TargetRegionMetadataEmitter);

  // Create function that emits metadata for each device global variable entry;
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        // Global variables carry no source location; a zeroed location is used
        // as a placeholder.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y&: varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
      };

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      Action: DeviceGlobalVarMetadataEmitter);

  // Emit the actual offload entries in order, diagnosing entries whose ID or
  // address could not be resolved.
  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                Val: E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(Name: FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(ID: CE->getID(), Addr: CE->getAddress(),
                         /*Size=*/0, Flags: CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   Val: E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        // Under unified shared memory the device accesses host memory
        // directly, so no device-side entry is needed.
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declaret target link address is set.");
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable:
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_INDIRECT_ERROR, E.second);
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      if (auto *GV = dyn_cast<GlobalValue>(Val: CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            (Flags !=
                 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect &&
             Flags != OffloadEntriesInfoManager::
                          OMPTargetGlobalVarEntryIndirectVTable))
          continue;

      // Indirect globals need to use a special name that doesn't match the name
      // of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
          Flags ==
              OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
        createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
                           Flags, CE->getLinkage(), Name: CE->getVarName());
      else
        createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
                           Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }

  // Emit requires directive globals to a special entry so the runtime can
  // register them when the device image is loaded.
  // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
  // entries should be redesigned to better suit this use-case.
  if (Config.hasRequiresFlags() && !Config.isTargetDevice())
    offloading::emitOffloadingEntry(
        M, Kind: object::OffloadKind::OFK_OpenMP,
        Addr: Constant::getNullValue(Ty: PointerType::getUnqual(C&: M.getContext())),
        Name: ".requires", /*Size=*/0,
        Flags: OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
        Data: Config.getRequiresFlags());
}
11302
11303void TargetRegionEntryInfo::getTargetRegionEntryFnName(
11304 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
11305 unsigned FileID, unsigned Line, unsigned Count) {
11306 raw_svector_ostream OS(Name);
11307 OS << KernelNamePrefix << llvm::format(Fmt: "%x", Vals: DeviceID)
11308 << llvm::format(Fmt: "_%x_", Vals: FileID) << ParentName << "_l" << Line;
11309 if (Count)
11310 OS << "_" << Count;
11311}
11312
11313void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
11314 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
11315 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
11316 TargetRegionEntryInfo::getTargetRegionEntryFnName(
11317 Name, ParentName: EntryInfo.ParentName, DeviceID: EntryInfo.DeviceID, FileID: EntryInfo.FileID,
11318 Line: EntryInfo.Line, Count: NewCount);
11319}
11320
11321TargetRegionEntryInfo
11322OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
11323 vfs::FileSystem &VFS,
11324 StringRef ParentName) {
11325 sys::fs::UniqueID ID(0xdeadf17e, 0);
11326 auto FileIDInfo = CallBack();
11327 uint64_t FileID = 0;
11328 if (ErrorOr<vfs::Status> Status = VFS.status(Path: std::get<0>(t&: FileIDInfo))) {
11329 ID = Status->getUniqueID();
11330 FileID = Status->getUniqueID().getFile();
11331 } else {
11332 // If the inode ID could not be determined, create a hash value
11333 // the current file name and use that as an ID.
11334 FileID = hash_value(arg: std::get<0>(t&: FileIDInfo));
11335 }
11336
11337 return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
11338 std::get<1>(t&: FileIDInfo));
11339}
11340
11341unsigned OpenMPIRBuilder::getFlagMemberOffset() {
11342 unsigned Offset = 0;
11343 for (uint64_t Remain =
11344 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11345 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
11346 !(Remain & 1); Remain = Remain >> 1)
11347 Offset++;
11348 return Offset;
11349}
11350
11351omp::OpenMPOffloadMappingFlags
11352OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
11353 // Rotate by getFlagMemberOffset() bits.
11354 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
11355 << getFlagMemberOffset());
11356}
11357
11358void OpenMPIRBuilder::setCorrectMemberOfFlag(
11359 omp::OpenMPOffloadMappingFlags &Flags,
11360 omp::OpenMPOffloadMappingFlags MemberOfFlag) {
11361 // If the entry is PTR_AND_OBJ but has not been marked with the special
11362 // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
11363 // marked as MEMBER_OF.
11364 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11365 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
11366 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11367 (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
11368 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
11369 return;
11370
11371 // Entries with ATTACH are not members-of anything. They are handled
11372 // separately by the runtime after other maps have been handled.
11373 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11374 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH))
11375 return;
11376
11377 // Reset the placeholder value to prepare the flag for the assignment of the
11378 // proper MEMBER_OF value.
11379 Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
11380 Flags |= MemberOfFlag;
11381}
11382
Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
  // TODO: convert this to utilise the IRBuilder Config rather than
  // a passed down argument.
  if (OpenMPSIMD)
    return nullptr;

  // A reference pointer is only materialized for "link" captures, or for
  // to/enter captures when unified shared memory is required; all other cases
  // return nullptr.
  if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
      ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
        CaptureClause ==
            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
       Config.hasRequiresUnifiedSharedMemory())) {
    // Build the "<name>[_<fileid>]_decl_tgt_ref_ptr" reference name; the file
    // ID disambiguates symbols that are not externally visible.
    SmallString<64> PtrName;
    {
      raw_svector_ostream OS(PtrName);
      OS << MangledName;
      if (!IsExternallyVisible)
        OS << format(Fmt: "_%x", Vals: EntryInfo.FileID);
      OS << "_decl_tgt_ref_ptr";
    }

    Value *Ptr = M.getNamedValue(Name: PtrName);

    // Create the reference pointer on first use and register it with the
    // offload entries; subsequent calls reuse the existing global.
    if (!Ptr) {
      GlobalValue *GlobalValue = M.getNamedValue(Name: MangledName);
      Ptr = getOrCreateInternalVariable(Ty: LlvmPtrTy, Name: PtrName);

      auto *GV = cast<GlobalVariable>(Val: Ptr);
      GV->setLinkage(GlobalValue::WeakAnyLinkage);

      // On the host the pointer is initialized to the original variable (or a
      // caller-provided initializer); on the device it is left uninitialized.
      if (!Config.isTargetDevice()) {
        if (GlobalInitializer)
          GV->setInitializer(GlobalInitializer());
        else
          GV->setInitializer(GlobalValue);
      }

      registerTargetGlobalVariable(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          GlobalInitializer, VariableLinkage, LlvmPtrTy, Addr: cast<Constant>(Val: Ptr));
    }

    return cast<Constant>(Val: Ptr);
  }

  return nullptr;
}
11438
void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
    Constant *Addr) {
  // Only 'device(any)' entries are registered, and only when offloading is in
  // play (a target triple is present or we are compiling for a device).
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  // Case 1: to/enter capture without unified shared memory - register the
  // variable itself.
  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    GlobalValue *LlvmVal = M.getNamedValue(Name: VarName);

    // Declarations carry no storage; report size 0 for them.
    if (!IsDeclaration)
      VarSize = divideCeil(
          Numerator: M.getDataLayout().getTypeSizeInBits(Ty: LlvmVal->getValueType()), Denominator: 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName(Parts: {VarName, "ref"});

      // Emit an internal constant "<name>_ref" initialized with the address,
      // giving the variable a use the optimizer will keep.
      if (!M.getNamedValue(Name: RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Ty: Addr->getType(), Name: RefName);
        auto *GvAddrRef = cast<GlobalVariable>(Val: AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(x: GvAddrRef);
      }
    }
  } else {
    // Case 2: link capture, or to/enter under unified shared memory -
    // register the indirect "_decl_tgt_ref_ptr" reference instead.
    if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    if (Config.isTargetDevice()) {
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      // On the host, create (or reuse) the reference pointer. Note that
      // getAddrOfDeclareTargetVar may call back into this function for a
      // newly created pointer.
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}
11517
11518/// Loads all the offload entries information from the host IR
11519/// metadata.
11520void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
11521 // If we are in target mode, load the metadata from the host IR. This code has
11522 // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
11523
11524 NamedMDNode *MD = M.getNamedMetadata(Name: ompOffloadInfoName);
11525 if (!MD)
11526 return;
11527
11528 for (MDNode *MN : MD->operands()) {
11529 auto &&GetMDInt = [MN](unsigned Idx) {
11530 auto *V = cast<ConstantAsMetadata>(Val: MN->getOperand(I: Idx));
11531 return cast<ConstantInt>(Val: V->getValue())->getZExtValue();
11532 };
11533
11534 auto &&GetMDString = [MN](unsigned Idx) {
11535 auto *V = cast<MDString>(Val: MN->getOperand(I: Idx));
11536 return V->getString();
11537 };
11538
11539 switch (GetMDInt(0)) {
11540 default:
11541 llvm_unreachable("Unexpected metadata!");
11542 break;
11543 case OffloadEntriesInfoManager::OffloadEntryInfo::
11544 OffloadingEntryInfoTargetRegion: {
11545 TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
11546 /*DeviceID=*/GetMDInt(1),
11547 /*FileID=*/GetMDInt(2),
11548 /*Line=*/GetMDInt(4),
11549 /*Count=*/GetMDInt(5));
11550 OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
11551 /*Order=*/GetMDInt(6));
11552 break;
11553 }
11554 case OffloadEntriesInfoManager::OffloadEntryInfo::
11555 OffloadingEntryInfoDeviceGlobalVar:
11556 OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
11557 /*MangledName=*/Name: GetMDString(1),
11558 Flags: static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
11559 /*Flags=*/GetMDInt(2)),
11560 /*Order=*/GetMDInt(3));
11561 break;
11562 }
11563 }
11564}
11565
11566void OpenMPIRBuilder::loadOffloadInfoMetadata(vfs::FileSystem &VFS,
11567 StringRef HostFilePath) {
11568 if (HostFilePath.empty())
11569 return;
11570
11571 auto Buf = VFS.getBufferForFile(Name: HostFilePath);
11572 if (std::error_code Err = Buf.getError()) {
11573 report_fatal_error(reason: ("error opening host file from host file path inside of "
11574 "OpenMPIRBuilder: " +
11575 Err.message())
11576 .c_str());
11577 }
11578
11579 LLVMContext Ctx;
11580 auto M = expectedToErrorOrAndEmitErrors(
11581 Ctx, Val: parseBitcodeFile(Buffer: Buf.get()->getMemBufferRef(), Context&: Ctx));
11582 if (std::error_code Err = M.getError()) {
11583 report_fatal_error(
11584 reason: ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
11585 .c_str());
11586 }
11587
11588 loadOffloadInfoMetadata(M&: *M.get());
11589}
11590
/// Emit a canonical loop with \p TripCount iterations, invoking \p BodyGen to
/// populate the body with the induction variable. Returns an insert point in
/// the continuation block after the loop, or an error from the callback.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createIteratorLoop(
    LocationDescription Loc, llvm::Value *TripCount, IteratorBodyGenTy BodyGen,
    llvm::StringRef Name) {
  Builder.restoreIP(IP: Loc.IP);

  BasicBlock *CurBB = Builder.GetInsertBlock();
  assert(CurBB &&
         "expected a valid insertion block for creating an iterator loop");
  Function *F = CurBB->getParent();

  // If the insert point is at the end of a block that already has a
  // terminator, split in front of that terminator rather than after it.
  InsertPointTy SplitIP = Builder.saveIP();
  if (SplitIP.getPoint() == CurBB->end())
    if (Instruction *Terminator = CurBB->getTerminator())
      SplitIP = InsertPointTy(CurBB, Terminator->getIterator());

  // Continuation block that receives control after the loop finishes.
  BasicBlock *ContBB =
      splitBB(IP: SplitIP, /*CreateBranch=*/false,
              DL: Builder.getCurrentDebugLocation(), Name: "omp.it.cont");

  CanonicalLoopInfo *CLI =
      createLoopSkeleton(DL: Builder.getCurrentDebugLocation(), TripCount, F,
                         /*PreInsertBefore=*/ContBB,
                         /*PostInsertBefore=*/ContBB, Name);

  // Enter loop from original block.
  redirectTo(Source: CurBB, Target: CLI->getPreheader(), DL: Builder.getCurrentDebugLocation());

  // Remove the unconditional branch inserted by createLoopSkeleton in the body
  if (Instruction *T = CLI->getBody()->getTerminator())
    T->eraseFromParent();

  // Let the callback fill in the loop body using the induction variable.
  InsertPointTy BodyIP = CLI->getBodyIP();
  if (llvm::Error Err = BodyGen(BodyIP, CLI->getIndVar()))
    return Err;

  // Body must either fallthrough to the latch or branch directly to it.
  if (Instruction *BodyTerminator = CLI->getBody()->getTerminator()) {
    auto *BodyBr = dyn_cast<BranchInst>(Val: BodyTerminator);
    if (!BodyBr || !BodyBr->isUnconditional() ||
        BodyBr->getSuccessor(Idx: 0) != CLI->getLatch()) {
      return make_error<StringError>(
          Args: "iterator bodygen must terminate the canonical body with an "
          "unconditional branch to the loop latch",
          Args: inconvertibleErrorCode());
    }
  } else {
    // Ensure we end the loop body by jumping to the latch.
    Builder.SetInsertPoint(CLI->getBody());
    Builder.CreateBr(Dest: CLI->getLatch());
  }

  // Link After -> ContBB
  Builder.SetInsertPoint(TheBB: CLI->getAfter(), IP: CLI->getAfter()->begin());
  if (!CLI->getAfter()->getTerminator())
    Builder.CreateBr(Dest: ContBB);

  return InsertPointTy{ContBB, ContBB->begin()};
}
11649
11650/// Mangle the parameter part of the vector function name according to
11651/// their OpenMP classification. The mangling function is defined in
11652/// section 4.5 of the AAVFABI(2021Q1).
11653static std::string mangleVectorParameters(
11654 ArrayRef<llvm::OpenMPIRBuilder::DeclareSimdAttrTy> ParamAttrs) {
11655 SmallString<256> Buffer;
11656 llvm::raw_svector_ostream Out(Buffer);
11657 for (const auto &ParamAttr : ParamAttrs) {
11658 switch (ParamAttr.Kind) {
11659 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Linear:
11660 Out << 'l';
11661 break;
11662 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearRef:
11663 Out << 'R';
11664 break;
11665 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearUVal:
11666 Out << 'U';
11667 break;
11668 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearVal:
11669 Out << 'L';
11670 break;
11671 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Uniform:
11672 Out << 'u';
11673 break;
11674 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Vector:
11675 Out << 'v';
11676 break;
11677 }
11678 if (ParamAttr.HasVarStride)
11679 Out << "s" << ParamAttr.StrideOrArg;
11680 else if (ParamAttr.Kind ==
11681 llvm::OpenMPIRBuilder::DeclareSimdKindTy::Linear ||
11682 ParamAttr.Kind ==
11683 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearRef ||
11684 ParamAttr.Kind ==
11685 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearUVal ||
11686 ParamAttr.Kind ==
11687 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearVal) {
11688 // Don't print the step value if it is not present or if it is
11689 // equal to 1.
11690 if (ParamAttr.StrideOrArg < 0)
11691 Out << 'n' << -ParamAttr.StrideOrArg;
11692 else if (ParamAttr.StrideOrArg != 1)
11693 Out << ParamAttr.StrideOrArg;
11694 }
11695
11696 if (!!ParamAttr.Alignment)
11697 Out << 'a' << ParamAttr.Alignment;
11698 }
11699
11700 return std::string(Out.str());
11701}
11702
11703void OpenMPIRBuilder::emitX86DeclareSimdFunction(
11704 llvm::Function *Fn, unsigned NumElts, const llvm::APSInt &VLENVal,
11705 llvm::ArrayRef<DeclareSimdAttrTy> ParamAttrs, DeclareSimdBranch Branch) {
11706 struct ISADataTy {
11707 char ISA;
11708 unsigned VecRegSize;
11709 };
11710 ISADataTy ISAData[] = {
11711 {.ISA: 'b', .VecRegSize: 128}, // SSE
11712 {.ISA: 'c', .VecRegSize: 256}, // AVX
11713 {.ISA: 'd', .VecRegSize: 256}, // AVX2
11714 {.ISA: 'e', .VecRegSize: 512}, // AVX512
11715 };
11716 llvm::SmallVector<char, 2> Masked;
11717 switch (Branch) {
11718 case DeclareSimdBranch::Undefined:
11719 Masked.push_back(Elt: 'N');
11720 Masked.push_back(Elt: 'M');
11721 break;
11722 case DeclareSimdBranch::Notinbranch:
11723 Masked.push_back(Elt: 'N');
11724 break;
11725 case DeclareSimdBranch::Inbranch:
11726 Masked.push_back(Elt: 'M');
11727 break;
11728 }
11729 for (char Mask : Masked) {
11730 for (const ISADataTy &Data : ISAData) {
11731 llvm::SmallString<256> Buffer;
11732 llvm::raw_svector_ostream Out(Buffer);
11733 Out << "_ZGV" << Data.ISA << Mask;
11734 if (!VLENVal) {
11735 assert(NumElts && "Non-zero simdlen/cdtsize expected");
11736 Out << llvm::APSInt::getUnsigned(X: Data.VecRegSize / NumElts);
11737 } else {
11738 Out << VLENVal;
11739 }
11740 Out << mangleVectorParameters(ParamAttrs);
11741 Out << '_' << Fn->getName();
11742 Fn->addFnAttr(Kind: Out.str());
11743 }
11744 }
11745}
11746
11747// Function used to add the attribute. The parameter `VLEN` is templated to
11748// allow the use of `x` when targeting scalable functions for SVE.
11749template <typename T>
11750static void addAArch64VectorName(T VLEN, StringRef LMask, StringRef Prefix,
11751 char ISA, StringRef ParSeq,
11752 StringRef MangledName, bool OutputBecomesInput,
11753 llvm::Function *Fn) {
11754 SmallString<256> Buffer;
11755 llvm::raw_svector_ostream Out(Buffer);
11756 Out << Prefix << ISA << LMask << VLEN;
11757 if (OutputBecomesInput)
11758 Out << 'v';
11759 Out << ParSeq << '_' << MangledName;
11760 Fn->addFnAttr(Kind: Out.str());
11761}
11762
11763// Helper function to generate the Advanced SIMD names depending on the value
11764// of the NDS when simdlen is not present.
11765static void addAArch64AdvSIMDNDSNames(unsigned NDS, StringRef Mask,
11766 StringRef Prefix, char ISA,
11767 StringRef ParSeq, StringRef MangledName,
11768 bool OutputBecomesInput,
11769 llvm::Function *Fn) {
11770 switch (NDS) {
11771 case 8:
11772 addAArch64VectorName(VLEN: 8, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11773 OutputBecomesInput, Fn);
11774 addAArch64VectorName(VLEN: 16, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11775 OutputBecomesInput, Fn);
11776 break;
11777 case 16:
11778 addAArch64VectorName(VLEN: 4, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11779 OutputBecomesInput, Fn);
11780 addAArch64VectorName(VLEN: 8, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11781 OutputBecomesInput, Fn);
11782 break;
11783 case 32:
11784 addAArch64VectorName(VLEN: 2, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11785 OutputBecomesInput, Fn);
11786 addAArch64VectorName(VLEN: 4, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11787 OutputBecomesInput, Fn);
11788 break;
11789 case 64:
11790 case 128:
11791 addAArch64VectorName(VLEN: 2, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11792 OutputBecomesInput, Fn);
11793 break;
11794 default:
11795 llvm_unreachable("Scalar type is too wide.");
11796 }
11797}
11798
11799/// Emit vector function attributes for AArch64, as defined in the AAVFABI.
11800void OpenMPIRBuilder::emitAArch64DeclareSimdFunction(
11801 llvm::Function *Fn, unsigned UserVLEN,
11802 llvm::ArrayRef<DeclareSimdAttrTy> ParamAttrs, DeclareSimdBranch Branch,
11803 char ISA, unsigned NarrowestDataSize, bool OutputBecomesInput) {
11804 assert((ISA == 'n' || ISA == 's') && "Expected ISA either 's' or 'n'.");
11805
11806 // Sort out parameter sequence.
11807 const std::string ParSeq = mangleVectorParameters(ParamAttrs);
11808 StringRef Prefix = "_ZGV";
11809 StringRef MangledName = Fn->getName();
11810
11811 // Generate simdlen from user input (if any).
11812 if (UserVLEN) {
11813 if (ISA == 's') {
11814 // SVE generates only a masked function.
11815 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
11816 OutputBecomesInput, Fn);
11817 return;
11818 }
11819
11820 switch (Branch) {
11821 case DeclareSimdBranch::Undefined:
11822 addAArch64VectorName(VLEN: UserVLEN, LMask: "N", Prefix, ISA, ParSeq, MangledName,
11823 OutputBecomesInput, Fn);
11824 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
11825 OutputBecomesInput, Fn);
11826 break;
11827 case DeclareSimdBranch::Inbranch:
11828 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
11829 OutputBecomesInput, Fn);
11830 break;
11831 case DeclareSimdBranch::Notinbranch:
11832 addAArch64VectorName(VLEN: UserVLEN, LMask: "N", Prefix, ISA, ParSeq, MangledName,
11833 OutputBecomesInput, Fn);
11834 break;
11835 }
11836 return;
11837 }
11838
11839 if (ISA == 's') {
11840 // SVE, section 3.4.1, item 1.
11841 addAArch64VectorName(VLEN: "x", LMask: "M", Prefix, ISA, ParSeq, MangledName,
11842 OutputBecomesInput, Fn);
11843 return;
11844 }
11845
11846 switch (Branch) {
11847 case DeclareSimdBranch::Undefined:
11848 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "N", Prefix, ISA, ParSeq,
11849 MangledName, OutputBecomesInput, Fn);
11850 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "M", Prefix, ISA, ParSeq,
11851 MangledName, OutputBecomesInput, Fn);
11852 break;
11853 case DeclareSimdBranch::Inbranch:
11854 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "M", Prefix, ISA, ParSeq,
11855 MangledName, OutputBecomesInput, Fn);
11856 break;
11857 case DeclareSimdBranch::Notinbranch:
11858 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "N", Prefix, ISA, ParSeq,
11859 MangledName, OutputBecomesInput, Fn);
11860 break;
11861 }
11862}
11863
11864//===----------------------------------------------------------------------===//
11865// OffloadEntriesInfoManager
11866//===----------------------------------------------------------------------===//
11867
11868bool OffloadEntriesInfoManager::empty() const {
11869 return OffloadEntriesTargetRegion.empty() &&
11870 OffloadEntriesDeviceGlobalVar.empty();
11871}
11872
11873unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
11874 const TargetRegionEntryInfo &EntryInfo) const {
11875 auto It = OffloadEntriesTargetRegionCount.find(
11876 x: getTargetRegionEntryCountKey(EntryInfo));
11877 if (It == OffloadEntriesTargetRegionCount.end())
11878 return 0;
11879 return It->second;
11880}
11881
11882void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
11883 const TargetRegionEntryInfo &EntryInfo) {
11884 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
11885 EntryInfo.Count + 1;
11886}
11887
11888/// Initialize target region entry.
11889void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
11890 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
11891 OffloadEntriesTargetRegion[EntryInfo] =
11892 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
11893 OMPTargetRegionEntryTargetRegion);
11894 ++OffloadingEntriesNum;
11895}
11896
11897void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
11898 TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
11899 OMPTargetRegionEntryKind Flags) {
11900 assert(EntryInfo.Count == 0 && "expected default EntryInfo");
11901
11902 // Update the EntryInfo with the next available count for this location.
11903 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11904
11905 // If we are emitting code for a target, the entry is already initialized,
11906 // only has to be registered.
11907 if (OMPBuilder->Config.isTargetDevice()) {
11908 // This could happen if the device compilation is invoked standalone.
11909 if (!hasTargetRegionEntryInfo(EntryInfo)) {
11910 return;
11911 }
11912 auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
11913 Entry.setAddress(Addr);
11914 Entry.setID(ID);
11915 Entry.setFlags(Flags);
11916 } else {
11917 if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
11918 hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
11919 return;
11920 assert(!hasTargetRegionEntryInfo(EntryInfo) &&
11921 "Target region entry already registered!");
11922 OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
11923 OffloadEntriesTargetRegion[EntryInfo] = Entry;
11924 ++OffloadingEntriesNum;
11925 }
11926 incrementTargetRegionEntryInfoCount(EntryInfo);
11927}
11928
11929bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
11930 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
11931
11932 // Update the EntryInfo with the next available count for this location.
11933 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11934
11935 auto It = OffloadEntriesTargetRegion.find(x: EntryInfo);
11936 if (It == OffloadEntriesTargetRegion.end()) {
11937 return false;
11938 }
11939 // Fail if this entry is already registered.
11940 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
11941 return false;
11942 return true;
11943}
11944
11945void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
11946 const OffloadTargetRegionEntryInfoActTy &Action) {
11947 // Scan all target region entries and perform the provided action.
11948 for (const auto &It : OffloadEntriesTargetRegion) {
11949 Action(It.first, It.second);
11950 }
11951}
11952
11953void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
11954 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
11955 OffloadEntriesDeviceGlobalVar.try_emplace(Key: Name, Args&: Order, Args&: Flags);
11956 ++OffloadingEntriesNum;
11957}
11958
11959void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
11960 StringRef VarName, Constant *Addr, int64_t VarSize,
11961 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
11962 if (OMPBuilder->Config.isTargetDevice()) {
11963 // This could happen if the device compilation is invoked standalone.
11964 if (!hasDeviceGlobalVarEntryInfo(VarName))
11965 return;
11966 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11967 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
11968 if (Entry.getVarSize() == 0) {
11969 Entry.setVarSize(VarSize);
11970 Entry.setLinkage(Linkage);
11971 }
11972 return;
11973 }
11974 Entry.setVarSize(VarSize);
11975 Entry.setLinkage(Linkage);
11976 Entry.setAddress(Addr);
11977 } else {
11978 if (hasDeviceGlobalVarEntryInfo(VarName)) {
11979 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11980 assert(Entry.isValid() && Entry.getFlags() == Flags &&
11981 "Entry not initialized!");
11982 if (Entry.getVarSize() == 0) {
11983 Entry.setVarSize(VarSize);
11984 Entry.setLinkage(Linkage);
11985 }
11986 return;
11987 }
11988 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
11989 Flags ==
11990 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
11991 OffloadEntriesDeviceGlobalVar.try_emplace(Key: VarName, Args&: OffloadingEntriesNum,
11992 Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage,
11993 Args: VarName.str());
11994 else
11995 OffloadEntriesDeviceGlobalVar.try_emplace(
11996 Key: VarName, Args&: OffloadingEntriesNum, Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage, Args: "");
11997 ++OffloadingEntriesNum;
11998 }
11999}
12000
12001void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
12002 const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
12003 // Scan all target region entries and perform the provided action.
12004 for (const auto &E : OffloadEntriesDeviceGlobalVar)
12005 Action(E.getKey(), E.getValue());
12006}
12007
12008//===----------------------------------------------------------------------===//
12009// CanonicalLoopInfo
12010//===----------------------------------------------------------------------===//
12011
12012void CanonicalLoopInfo::collectControlBlocks(
12013 SmallVectorImpl<BasicBlock *> &BBs) {
12014 // We only count those BBs as control block for which we do not need to
12015 // reverse the CFG, i.e. not the loop body which can contain arbitrary control
12016 // flow. For consistency, this also means we do not add the Body block, which
12017 // is just the entry to the body code.
12018 BBs.reserve(N: BBs.size() + 6);
12019 BBs.append(IL: {getPreheader(), Header, Cond, Latch, Exit, getAfter()});
12020}
12021
12022BasicBlock *CanonicalLoopInfo::getPreheader() const {
12023 assert(isValid() && "Requires a valid canonical loop");
12024 for (BasicBlock *Pred : predecessors(BB: Header)) {
12025 if (Pred != Latch)
12026 return Pred;
12027 }
12028 llvm_unreachable("Missing preheader");
12029}
12030
12031void CanonicalLoopInfo::setTripCount(Value *TripCount) {
12032 assert(isValid() && "Requires a valid canonical loop");
12033
12034 Instruction *CmpI = &getCond()->front();
12035 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
12036 CmpI->setOperand(i: 1, Val: TripCount);
12037
12038#ifndef NDEBUG
12039 assertOK();
12040#endif
12041}
12042
12043void CanonicalLoopInfo::mapIndVar(
12044 llvm::function_ref<Value *(Instruction *)> Updater) {
12045 assert(isValid() && "Requires a valid canonical loop");
12046
12047 Instruction *OldIV = getIndVar();
12048
12049 // Record all uses excluding those introduced by the updater. Uses by the
12050 // CanonicalLoopInfo itself to keep track of the number of iterations are
12051 // excluded.
12052 SmallVector<Use *> ReplacableUses;
12053 for (Use &U : OldIV->uses()) {
12054 auto *User = dyn_cast<Instruction>(Val: U.getUser());
12055 if (!User)
12056 continue;
12057 if (User->getParent() == getCond())
12058 continue;
12059 if (User->getParent() == getLatch())
12060 continue;
12061 ReplacableUses.push_back(Elt: &U);
12062 }
12063
12064 // Run the updater that may introduce new uses
12065 Value *NewIV = Updater(OldIV);
12066
12067 // Replace the old uses with the value returned by the updater.
12068 for (Use *U : ReplacableUses)
12069 U->set(NewIV);
12070
12071#ifndef NDEBUG
12072 assertOK();
12073#endif
12074}
12075
12076void CanonicalLoopInfo::assertOK() const {
12077#ifndef NDEBUG
12078 // No constraints if this object currently does not describe a loop.
12079 if (!isValid())
12080 return;
12081
12082 BasicBlock *Preheader = getPreheader();
12083 BasicBlock *Body = getBody();
12084 BasicBlock *After = getAfter();
12085
12086 // Verify standard control-flow we use for OpenMP loops.
12087 assert(Preheader);
12088 assert(isa<BranchInst>(Preheader->getTerminator()) &&
12089 "Preheader must terminate with unconditional branch");
12090 assert(Preheader->getSingleSuccessor() == Header &&
12091 "Preheader must jump to header");
12092
12093 assert(Header);
12094 assert(isa<BranchInst>(Header->getTerminator()) &&
12095 "Header must terminate with unconditional branch");
12096 assert(Header->getSingleSuccessor() == Cond &&
12097 "Header must jump to exiting block");
12098
12099 assert(Cond);
12100 assert(Cond->getSinglePredecessor() == Header &&
12101 "Exiting block only reachable from header");
12102
12103 assert(isa<BranchInst>(Cond->getTerminator()) &&
12104 "Exiting block must terminate with conditional branch");
12105 assert(size(successors(Cond)) == 2 &&
12106 "Exiting block must have two successors");
12107 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
12108 "Exiting block's first successor jump to the body");
12109 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
12110 "Exiting block's second successor must exit the loop");
12111
12112 assert(Body);
12113 assert(Body->getSinglePredecessor() == Cond &&
12114 "Body only reachable from exiting block");
12115 assert(!isa<PHINode>(Body->front()));
12116
12117 assert(Latch);
12118 assert(isa<BranchInst>(Latch->getTerminator()) &&
12119 "Latch must terminate with unconditional branch");
12120 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
12121 // TODO: To support simple redirecting of the end of the body code that has
12122 // multiple; introduce another auxiliary basic block like preheader and after.
12123 assert(Latch->getSinglePredecessor() != nullptr);
12124 assert(!isa<PHINode>(Latch->front()));
12125
12126 assert(Exit);
12127 assert(isa<BranchInst>(Exit->getTerminator()) &&
12128 "Exit block must terminate with unconditional branch");
12129 assert(Exit->getSingleSuccessor() == After &&
12130 "Exit block must jump to after block");
12131
12132 assert(After);
12133 assert(After->getSinglePredecessor() == Exit &&
12134 "After block only reachable from exit block");
12135 assert(After->empty() || !isa<PHINode>(After->front()));
12136
12137 Instruction *IndVar = getIndVar();
12138 assert(IndVar && "Canonical induction variable not found?");
12139 assert(isa<IntegerType>(IndVar->getType()) &&
12140 "Induction variable must be an integer");
12141 assert(cast<PHINode>(IndVar)->getParent() == Header &&
12142 "Induction variable must be a PHI in the loop header");
12143 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
12144 assert(
12145 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
12146 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
12147
12148 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
12149 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
12150 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
12151 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
12152 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
12153 ->isOne());
12154
12155 Value *TripCount = getTripCount();
12156 assert(TripCount && "Loop trip count not found?");
12157 assert(IndVar->getType() == TripCount->getType() &&
12158 "Trip count and induction variable must have the same type");
12159
12160 auto *CmpI = cast<CmpInst>(&Cond->front());
12161 assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
12162 "Exit condition must be a signed less-than comparison");
12163 assert(CmpI->getOperand(0) == IndVar &&
12164 "Exit condition must compare the induction variable");
12165 assert(CmpI->getOperand(1) == TripCount &&
12166 "Exit condition must compare with the trip count");
12167#endif
12168}
12169
12170void CanonicalLoopInfo::invalidate() {
12171 Header = nullptr;
12172 Cond = nullptr;
12173 Latch = nullptr;
12174 Exit = nullptr;
12175}
12176