1//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9///
10/// This file implements the OpenMPIRBuilder class, which is used as a
11/// convenient way to create LLVM instructions for OpenMP directives.
12///
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/SmallSet.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/ADT/StringRef.h"
20#include "llvm/Analysis/AssumptionCache.h"
21#include "llvm/Analysis/CodeMetrics.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/Analysis/OptimizationRemarkEmitter.h"
24#include "llvm/Analysis/PostDominators.h"
25#include "llvm/Analysis/ScalarEvolution.h"
26#include "llvm/Analysis/TargetLibraryInfo.h"
27#include "llvm/Bitcode/BitcodeReader.h"
28#include "llvm/Frontend/Offloading/Utility.h"
29#include "llvm/Frontend/OpenMP/OMPGridValues.h"
30#include "llvm/IR/Attributes.h"
31#include "llvm/IR/BasicBlock.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/CallingConv.h"
34#include "llvm/IR/Constant.h"
35#include "llvm/IR/Constants.h"
36#include "llvm/IR/DIBuilder.h"
37#include "llvm/IR/DebugInfoMetadata.h"
38#include "llvm/IR/DerivedTypes.h"
39#include "llvm/IR/Function.h"
40#include "llvm/IR/GlobalVariable.h"
41#include "llvm/IR/IRBuilder.h"
42#include "llvm/IR/InstIterator.h"
43#include "llvm/IR/IntrinsicInst.h"
44#include "llvm/IR/LLVMContext.h"
45#include "llvm/IR/MDBuilder.h"
46#include "llvm/IR/Metadata.h"
47#include "llvm/IR/PassInstrumentation.h"
48#include "llvm/IR/PassManager.h"
49#include "llvm/IR/ReplaceConstant.h"
50#include "llvm/IR/Value.h"
51#include "llvm/MC/TargetRegistry.h"
52#include "llvm/Support/CommandLine.h"
53#include "llvm/Support/Error.h"
54#include "llvm/Support/ErrorHandling.h"
55#include "llvm/Support/FileSystem.h"
56#include "llvm/Support/NVVMAttributes.h"
57#include "llvm/Support/VirtualFileSystem.h"
58#include "llvm/Target/TargetMachine.h"
59#include "llvm/Target/TargetOptions.h"
60#include "llvm/Transforms/Utils/BasicBlockUtils.h"
61#include "llvm/Transforms/Utils/Cloning.h"
62#include "llvm/Transforms/Utils/CodeExtractor.h"
63#include "llvm/Transforms/Utils/LoopPeel.h"
64#include "llvm/Transforms/Utils/UnrollLoop.h"
65
66#include <cstdint>
67#include <optional>
68
69#define DEBUG_TYPE "openmp-ir-builder"
70
71using namespace llvm;
72using namespace omp;
73
// Debug knob: annotate runtime-call declarations with optimistic "as-if"
// attributes. Hidden because incorrect optimism can miscompile.
static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(Val: false));

// Multiplier applied to the unroll threshold; compensates for simplification
// passes that will still shrink the loop body after the unroll decision.
static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(Val: 1.5));
85
86#ifndef NDEBUG
87/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
88/// at position IP1 may change the meaning of IP2 or vice-versa. This is because
89/// an InsertPoint stores the instruction before something is inserted. For
90/// instance, if both point to the same instruction, two IRBuilders alternating
91/// creating instruction will cause the instructions to be interleaved.
92static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
93 IRBuilder<>::InsertPoint IP2) {
94 if (!IP1.isSet() || !IP2.isSet())
95 return false;
96 return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
97}
98
/// Sanity-check (assert-builds only) that \p SchedType is one of the
/// ordered/unordered/nomerge schedule encodings the runtime understands, and
/// that at most one monotonicity modifier bit is set.
static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  // (Monotonicity bits are masked off here and checked separately below.)
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  // Note: "Orderd" (sic) is the actual enumerator spelling in OMPConstants.
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
  case OMPScheduleType::OrderedDistributeChunked:
  case OMPScheduleType::OrderedDistribute:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
158#endif
159
160/// This is wrapper over IRBuilderBase::restoreIP that also restores the current
161/// debug location to the last instruction in the specified basic block if the
162/// insert point points to the end of the block.
163static void restoreIPandDebugLoc(llvm::IRBuilderBase &Builder,
164 llvm::IRBuilderBase::InsertPoint IP) {
165 Builder.restoreIP(IP);
166 llvm::BasicBlock *BB = Builder.GetInsertBlock();
167 llvm::BasicBlock::iterator I = Builder.GetInsertPoint();
168 if (!BB->empty() && I == BB->end())
169 Builder.SetCurrentDebugLocation(BB->back().getStableDebugLoc());
170}
171
172static bool hasGridValue(const Triple &T) {
173 return T.isAMDGPU() || T.isNVPTX() || T.isSPIRV();
174}
175
176static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
177 if (T.isAMDGPU()) {
178 StringRef Features =
179 Kernel->getFnAttribute(Kind: "target-features").getValueAsString();
180 if (Features.count(Str: "+wavefrontsize64"))
181 return omp::getAMDGPUGridValues<64>();
182 return omp::getAMDGPUGridValues<32>();
183 }
184 if (T.isNVPTX())
185 return omp::NVPTXGridValues;
186 if (T.isSPIRV())
187 return omp::SPIRVGridValues;
188 llvm_unreachable("No grid value available for this architecture!");
189}
190
191/// Determine which scheduling algorithm to use, determined from schedule clause
192/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasDistScheduleChunks) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    // A chunk-size argument selects the chunked runtime entry point.
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    // The simd modifier selects the simd-specialized guided schedule.
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  case OMP_SCHEDULE_Distribute:
    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
                                 : OMPScheduleType::BaseDistribute;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
218
219/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering nor monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations: the simd-specialized schedules have no ordered
  // variant, so map them to the closest ordered non-simd schedule.
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}
242
243/// Adds monotonicity modifier flags to schedule type.
/// Adds monotonicity modifier flags to \p ScheduleType, applying the OpenMP
/// 5.1 defaulting rules when neither modifier was written explicitly.
/// NOTE: HasSimdModifier is accepted for interface symmetry with the other
/// schedule helpers but is not consulted in this function.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // The monotonic is used by default in openmp runtime library, so no need
      // to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}
278
279/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause,
                          bool HasDistScheduleChunks) {
  // Compose the final encoding in three steps: base algorithm, then ordering
  // bits, then monotonicity bits.
  OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
      ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseScheduleType: BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      ScheduleType: OrderedSchedule, HasSimdModifier, HasMonotonic: HasMonotonicModifier,
      HasNonmonotonic: HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}
296
297/// Make \p Source branch to \p Target.
298///
299/// Handles two situations:
300/// * \p Source already has an unconditional branch.
301/// * \p Source is a degenerate block (no terminator because the BB is
302/// the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  // Case 1: Source already ends in an unconditional branch. Retarget it and
  // remove Source from the old successor's PHI incoming lists.
  if (Instruction *Term = Source->getTerminatorOrNull()) {
    auto *Br = cast<UncondBrInst>(Val: Term);
    BasicBlock *Succ = Br->getSuccessor();
    Succ->removePredecessor(Pred: Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(Target);
    return;
  }

  // Case 2: degenerate (terminator-less) block; append a fresh branch carrying
  // the requested debug location.
  auto *NewBr = UncondBrInst::Create(Target, InsertBefore: Source);
  NewBr->setDebugLoc(DL);
}
315
/// Move all instructions from \p IP to the end of its block into \p New,
/// optionally terminating the old block with a branch to \p New.
void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch, DebugLoc DL) {
  // Moved instructions land at New->begin(); PHIs there would end up after
  // non-PHI instructions, which is invalid IR.
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  // If the `Old` block is empty then there are no instructions to move. But in
  // the new debug scheme, it could have trailing debug records which will be
  // moved to `New` in `spliceDebugInfoEmptyBlock`. We dont want that for 2
  // reasons:
  // 1. If `New` is also empty, `BasicBlock::splice` crashes.
  // 2. Even if `New` is not empty, the rationale to move those records to `New`
  // (in `spliceDebugInfoEmptyBlock`) does not apply here. That function
  // assumes that `Old` is optimized out and is going away. This is not the case
  // here. The `Old` block is still being used e.g. a branch instruction is
  // added to it later in this function.
  // So we call `BasicBlock::splice` only when `Old` is not empty.
  if (!Old->empty())
    New->splice(ToIt: New->begin(), FromBB: Old, FromBeginIt: IP.getPoint(), FromEndIt: Old->end());

  if (CreateBranch) {
    auto *NewBr = UncondBrInst::Create(Target: New, InsertBefore: Old);
    NewBr->setDebugLoc(DL);
  }
}
342
343void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
344 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
345 BasicBlock *Old = Builder.GetInsertBlock();
346
347 spliceBB(IP: Builder.saveIP(), New, CreateBranch, DL: DebugLoc);
348 if (CreateBranch)
349 Builder.SetInsertPoint(Old->getTerminator());
350 else
351 Builder.SetInsertPoint(Old);
352
353 // SetInsertPoint also updates the Builder's debug location, but we want to
354 // keep the one the Builder was configured to use.
355 Builder.SetCurrentDebugLocation(DebugLoc);
356}
357
/// Split the block at \p IP: everything after the insert point moves into a
/// freshly created block placed directly after the old one. Returns the new
/// block.
BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          DebugLoc DL, llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  // The new block inherits the old block's name unless an explicit name is
  // given.
  BasicBlock *New = BasicBlock::Create(
      Context&: Old->getContext(), Name: Name.isTriviallyEmpty() ? Old->getName() : Name,
      Parent: Old->getParent(), InsertBefore: Old->getNextNode());
  spliceBB(IP, New, CreateBranch, DL);
  // New received Old's trailing instructions (including the terminator, if
  // any); successor PHIs that referred to Old must now refer to New.
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}
368
/// Builder-based overload of splitBB: splits at the builder's current position
/// and leaves the builder inserting at the end of the old block (before the
/// new branch, if one was created).
BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}
382
383BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
384 llvm::Twine Name) {
385 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
386 BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
387 if (CreateBranch)
388 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
389 else
390 Builder.SetInsertPoint(Builder.GetInsertBlock());
391 // SetInsertPoint also updates the Builder's debug location, but we want to
392 // keep the one the Builder was configured to use.
393 Builder.SetCurrentDebugLocation(DebugLoc);
394 return New;
395}
396
397BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
398 llvm::Twine Suffix) {
399 BasicBlock *Old = Builder.GetInsertBlock();
400 return splitBB(Builder, CreateBranch, Name: Old->getName() + Suffix);
401}
402
403// This function creates a fake integer value and a fake use for the integer
404// value. It returns the fake value created. This is useful in modeling the
405// extra arguments to the outlined functions.
/// Create a placeholder integer value (an alloca in the outer scope, either
/// the pointer itself or a load of it) plus an artificial use of it in the
/// inner scope. All created instructions are recorded in \p ToBeDeleted so
/// the caller can erase them once outlining has consumed them.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true,
                        bool Is64Bit = false) {
  Builder.restoreIP(IP: OuterAllocaIP);
  IntegerType *IntTy = Is64Bit ? Builder.getInt64Ty() : Builder.getInt32Ty();
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Ty: IntTy, ArraySize: nullptr, Name: Name + ".addr");
  ToBeDeleted.push_back(Elt: FakeValAddr);

  if (AsPtr) {
    // Hand back the address itself as the placeholder.
    FakeVal = FakeValAddr;
  } else {
    // Hand back the loaded value; the load is also temporary.
    FakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeValAddr, Name: Name + ".val");
    ToBeDeleted.push_back(Elt: FakeVal);
  }

  // Generate a fake use of this value so the outliner sees it as live inside
  // the inner region.
  Builder.restoreIP(IP: InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeVal, Name: Name + ".use");
  } else {
    // An add with a constant is a throwaway use of the loaded value.
    UseFakeVal = cast<BinaryOperator>(Val: Builder.CreateAdd(
        LHS: FakeVal, RHS: Is64Bit ? Builder.getInt64(C: 10) : Builder.getInt32(C: 10)));
  }
  ToBeDeleted.push_back(Elt: UseFakeVal);
  return FakeVal;
}
438
439//===----------------------------------------------------------------------===//
440// OpenMPIRBuilderConfig
441//===----------------------------------------------------------------------===//
442
namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
/// The flags combine bitwise; OMP_REQ_UNDEFINED means no clause has been
/// recorded yet, in which case getRequiresFlags() reports OMP_REQ_NONE.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace
463
// Default configuration: no 'requires' clause has been recorded yet.
OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}
466
// Construct a config from individual booleans, folding the four 'requires'
// clauses into the RequiresFlags bitmask.
OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}
483
// Query whether 'requires reverse_offload' was recorded.
bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

// Query whether 'requires unified_address' was recorded.
bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

// Query whether 'requires unified_shared_memory' was recorded.
bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

// Query whether 'requires dynamic_allocators' was recorded.
bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}
499
500int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
501 return hasRequiresFlags() ? RequiresFlags
502 : static_cast<int64_t>(OMP_REQ_NONE);
503}
504
// Set or clear the reverse_offload bit.
void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

// Set or clear the unified_address bit.
void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

// Set or clear the unified_shared_memory bit.
void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

// Set or clear the dynamic_allocators bit.
void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}
532
533//===----------------------------------------------------------------------===//
534// OpenMPIRBuilder
535//===----------------------------------------------------------------------===//
536
/// Assemble the argument vector for a target-kernel launch runtime call from
/// \p KernelArgs. Teams/threads counts are packed into 3-element i32 arrays
/// (unspecified trailing dimensions stay 0).
void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(C: KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(C&: Builder.getContext());
  constexpr size_t MaxDim = 3;
  Value *ZeroArray = Constant::getNullValue(Ty: ArrayType::get(ElementType: Int32Ty, NumElements: MaxDim));

  // Flags layout (from the construction below): bit 0 = nowait, bits 2+ hold
  // the DynCGroupMemFallback value shifted left by 2.
  Value *HasNoWaitFlag = Builder.getInt64(C: KernelArgs.HasNoWait);

  Value *DynCGroupMemFallbackFlag =
      Builder.getInt64(C: static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
  DynCGroupMemFallbackFlag = Builder.CreateShl(LHS: DynCGroupMemFallbackFlag, RHS: 2);
  Value *Flags = Builder.CreateOr(LHS: HasNoWaitFlag, RHS: DynCGroupMemFallbackFlag);

  assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());

  // Start each array from the zero constant and insert the provided
  // dimensions, clamped to MaxDim.
  Value *NumTeams3D =
      Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumTeams[0], Idxs: {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumThreads[0], Idxs: {0});
  for (unsigned I :
       seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumTeams.size(), b: MaxDim)))
    NumTeams3D =
        Builder.CreateInsertValue(Agg: NumTeams3D, Val: KernelArgs.NumTeams[I], Idxs: {I});
  for (unsigned I :
       seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumThreads.size(), b: MaxDim)))
    NumThreads3D =
        Builder.CreateInsertValue(Agg: NumThreads3D, Val: KernelArgs.NumThreads[I], Idxs: {I});

  // Argument order must match what the runtime's kernel-launch entry point
  // expects for OMP_KERNEL_ARG_VERSION.
  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGroupMem};
}
582
/// Attach the attribute sets declared in OMPKinds.def for runtime function
/// \p FnID to the declaration \p Fn (function, return, and per-argument
/// attributes). Functions without an OMP_RTL_ATTRS entry are left unchanged.
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Args: Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  // SExt/ZExt are translated through TargetLibraryInfo so the extension kind
  // matches the target ABI for i32 parameters/returns.
  // NOTE(review): `T` here is not a local; presumably the builder's target
  // triple member — confirm against the class definition.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Kind: Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Kind: Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, Signed: HasSignExt))
          FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, Signed: HasSignExt))
        FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
    } else {
      FnAS = FnAS.addAttributes(C&: Ctx, AS);
    }
  };

// Materialize the named attribute sets from the .def file.
#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}
632
/// Look up (or declare) the OpenMP runtime function \p FnID in module \p M and
/// return it together with its expected function type. Newly created
/// declarations get the configured calling convention, the callback metadata
/// for fork calls, and the attribute sets from OMPKinds.def.
FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }
    Fn->setCallingConv(Config.getRuntimeCC());
    // Add information if the runtime function takes a callback function
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(KindID: LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            KindID: LLVMContext::MD_callback,
            MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
                                CalleeArgNo: 2, Arguments: {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, Fn&: *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
689
/// Lazily create the finalization block for this region. On first call, a new
/// ".fini" block is created in the current function and populated through the
/// user-supplied finalization callback; subsequent calls return the cached
/// block. Errors from the callback are propagated.
Expected<BasicBlock *>
OpenMPIRBuilder::FinalizationInfo::getFiniBB(IRBuilderBase &Builder) {
  if (!FiniBB) {
    Function *ParentFunc = Builder.GetInsertBlock()->getParent();
    // Preserve the caller's insertion point across the callback.
    IRBuilderBase::InsertPointGuard Guard(Builder);
    FiniBB = BasicBlock::Create(Context&: Builder.getContext(), Name: ".fini", Parent: ParentFunc);
    Builder.SetInsertPoint(FiniBB);
    // FiniCB adds the branch to the exit stub.
    if (Error Err = FiniCB(Builder.saveIP()))
      return Err;
  }
  return FiniBB;
}
703
704Error OpenMPIRBuilder::FinalizationInfo::mergeFiniBB(IRBuilderBase &Builder,
705 BasicBlock *OtherFiniBB) {
706 // Simple case: FiniBB does not exist yet: re-use OtherFiniBB.
707 if (!FiniBB) {
708 FiniBB = OtherFiniBB;
709
710 Builder.SetInsertPoint(FiniBB->getFirstNonPHIIt());
711 if (Error Err = FiniCB(Builder.saveIP()))
712 return Err;
713
714 return Error::success();
715 }
716
717 // Move instructions from FiniBB to the start of OtherFiniBB.
718 auto EndIt = FiniBB->end();
719 if (FiniBB->size() >= 1)
720 if (auto Prev = std::prev(x: EndIt); Prev->isTerminator())
721 EndIt = Prev;
722 OtherFiniBB->splice(ToIt: OtherFiniBB->getFirstNonPHIIt(), FromBB: FiniBB, FromBeginIt: FiniBB->begin(),
723 FromEndIt: EndIt);
724
725 FiniBB->replaceAllUsesWith(V: OtherFiniBB);
726 FiniBB->eraseFromParent();
727 FiniBB = OtherFiniBB;
728 return Error::success();
729}
730
731Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
732 FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
733 auto *Fn = dyn_cast<llvm::Function>(Val: RTLFn.getCallee());
734 assert(Fn && "Failed to create OpenMP runtime function pointer");
735 return Fn;
736}
737
738CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
739 ArrayRef<Value *> Args,
740 StringRef Name) {
741 CallInst *Call = Builder.CreateCall(Callee, Args, Name);
742 Call->setCallingConv(Config.getRuntimeCC());
743 return Call;
744}
745
// Set up the OpenMP type cache for module M before any codegen.
void OpenMPIRBuilder::initialize() { initializeTypes(M); }
747
/// Move allocas whose array size is a ConstantData out of non-entry blocks to
/// the start of \p Function's entry block so they behave as static
/// allocations. Dynamically sized allocas are left where they are.
static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(x: Function->begin(), n: 1); Block != Function->end();
       Block++) {
    // NOTE(review): `Block->getReverseIterator()->begin()` dereferences back
    // to this same block; it looks equivalent to `Block->begin()` — confirm
    // intent.
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Val&: Inst)) {
        // Advance before the move so the iterator stays valid.
        Inst++;
        if (!isa<ConstantData>(Val: AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MovePos: MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}
770
/// Hoist the statically sized allocas of \p Block to just before the entry
/// block's terminator.
static void hoistNonEntryAllocasToEntryBlock(llvm::BasicBlock &Block) {
  llvm::SmallVector<llvm::Instruction *> AllocasToMove;

  auto ShouldHoistAlloca = [](const llvm::AllocaInst &AllocaInst) {
    // TODO: For now, we support simple static allocations, we might need to
    // move non-static ones as well. However, this will need further analysis to
    // move the length arguments as well.
    return !AllocaInst.isArrayAllocation();
  };

  // Collect first, move second: moving while iterating would invalidate the
  // block's instruction iterator.
  for (llvm::Instruction &Inst : Block)
    if (auto *AllocaInst = llvm::dyn_cast<llvm::AllocaInst>(Val: &Inst))
      if (ShouldHoistAlloca(*AllocaInst))
        AllocasToMove.push_back(Elt: AllocaInst);

  auto InsertPoint =
      Block.getParent()->getEntryBlock().getTerminator()->getIterator();

  for (llvm::Instruction *AllocaInst : AllocasToMove)
    AllocaInst->moveBefore(InsertPos: InsertPoint);
}
792
/// Hoist allocas to the entry block for every block of \p Func that properly
/// post-dominates the entry block, i.e. blocks that are reached on every
/// normal path through the function.
static void hoistNonEntryAllocasToEntryBlock(llvm::Function *Func) {
  PostDominatorTree PostDomTree(*Func);
  for (llvm::BasicBlock &BB : *Func)
    if (PostDomTree.properlyDominates(A: &BB, B: &Func->getEntryBlock()))
      hoistNonEntryAllocasToEntryBlock(Block&: BB);
}
799
// Finalize all pending outline requests (or only those belonging to \p Fn if
// given), then raise deferred constant allocas, emit offload metadata, and
// mark the builder finalized.
void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not finalized yet; may happen with nested
    // function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(Elt: OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate
    // struct for aggregate params in the device default alloca address space.
    // OpenMP runtime requires that the params of the extracted functions are
    // passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    // Values the caller wants passed individually rather than via the
    // aggregate argument struct.
    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(Arg: V);

    Function *OutlinedFn =
        Extractor.extractCodeRegion(CEAC, Inputs&: OI.Inputs, Outputs&: OI.Outputs);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute(Kind: "target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(Attr: TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute(Kind: "target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(Attr: TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(where: OuterFn->getIterator(), New: OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away, we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined region
      // and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      // Iterate in reverse so that each instruction moved to the front of
      // OI.EntryBB keeps the original relative order.
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator()) {
          // Absorb any debug value that terminator may have
          if (Instruction *TI = OI.EntryBB->getTerminatorOrNull())
            TI->adoptDbgRecords(BB: &ArtificialEntry, It: I.getIterator(), InsertAtHead: false);
          continue;
        }

        I.moveBeforePreserving(BB&: *OI.EntryBB, I: OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(MovePos: &ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->hasNUses(1));

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);

    if (OI.FixUpNonEntryAllocas)
      hoistNonEntryAllocasToEntryBlock(Func: OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embeds user written code into
  // the target region which may inject allocas which need to
  // be moved to the entry block of our target or risk malformed
  // optimisations by later passes, this is only relevant for
  // the device pass which appears to be a little more delicate
  // when it comes to optimisations (however, we do not block on
  // that here, it's up to the inserter to the list to do so).
  // This notably has to occur after the OutlinedInfo candidates
  // have been extracted so we have an end product that will not
  // be implicitly adversely affected by any raises unless
  // intentionally appended to the list.
  // NOTE: This only does so for ConstantData, it could be extended
  // to ConstantExpr's with further effort, however, they should
  // largely be folded when they get here. Extending it to runtime
  // defined/read+writeable allocation sizes would be non-trivial
  // (need to factor in movement of any stores to variables the
  // allocation size depends on, as well as the usual loads,
  // otherwise it'll yield the wrong result after movement) and
  // likely be more suitable as an LLVM optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, Function: F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFunction&: ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(u: false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable(Name: "__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed(Name: "llvm.compiler.used", List: LLVMCompilerUsed);
  }

  IsFinalized = true;
}
953
954bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }
955
// The builder must not be destroyed while outline requests are still
// pending; finalize() is responsible for draining OutlineInfos.
OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}
959
960GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
961 IntegerType *I32Ty = Type::getInt32Ty(C&: M.getContext());
962 auto *GV =
963 new GlobalVariable(M, I32Ty,
964 /* isConstant = */ true, GlobalValue::WeakODRLinkage,
965 ConstantInt::get(Ty: I32Ty, V: Value), Name);
966 GV->setVisibility(GlobalValue::HiddenVisibility);
967
968 return GV;
969}
970
971void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
972 if (List.empty())
973 return;
974
975 // Convert List to what ConstantArray needs.
976 SmallVector<Constant *, 8> UsedArray;
977 UsedArray.resize(N: List.size());
978 for (unsigned I = 0, E = List.size(); I != E; ++I)
979 UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
980 C: cast<Constant>(Val: &*List[I]), Ty: Builder.getPtrTy());
981
982 if (UsedArray.empty())
983 return;
984 ArrayType *ATy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: UsedArray.size());
985
986 auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
987 ConstantArray::get(T: ATy, V: UsedArray), Name);
988
989 GV->setSection("llvm.metadata");
990}
991
992GlobalVariable *
993OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
994 OMPTgtExecModeFlags Mode) {
995 auto *Int8Ty = Builder.getInt8Ty();
996 auto *GVMode = new GlobalVariable(
997 M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
998 ConstantInt::get(Ty: Int8Ty, V: Mode), Twine(KernelName, "_exec_mode"));
999 GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
1000 return GVMode;
1001}
1002
// Return (creating and caching on first use) the ident_t struct for the
// given source-location string and flags, cast to the ident pointer type.
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  // Cache key combines the string with the flag words; a null entry means we
  // have not emitted this ident yet.
  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Ty: Int32, V: uint32_t(LocFlags)),
                             ConstantInt::get(Ty: Int32, V: Reserve2Flags),
                             ConstantInt::get(Ty: Int32, V: SrcLocStrSize), SrcLocStr};

    // The source-location string may live in a different address space than
    // the ident struct member expects; insert an addrspacecast if so.
    size_t SrcLocStrArgIdx = 4;
    if (OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx)
            ->getPointerAddressSpace() !=
        IdentData[SrcLocStrArgIdx]->getType()->getPointerAddressSpace())
      IdentData[SrcLocStrArgIdx] = ConstantExpr::getAddrSpaceCast(
          C: SrcLocStr, Ty: OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx));
    Constant *Initializer =
        ConstantStruct::get(T: OpenMPIRBuilder::Ident, V: IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    // No matching global found; emit a fresh private constant.
    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(C: Ident, Ty: IdentPtr);
}
1049
// Return (creating and caching on first use) a global string constant for
// \p LocStr; \p SrcLocStrSize is set to the string's length.
Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  // A null cache slot means no global has been associated with LocStr yet.
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(Context&: M.getContext(), Initializer: LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        // Cache and return the existing global, cast to i8*.
        return SrcLocStr = ConstantExpr::getPointerCast(C: &GV, Ty: Int8Ptr);

    SrcLocStr = Builder.CreateGlobalString(
        Str: LocStr, /*Name=*/"", AddressSpace: M.getDataLayout().getDefaultGlobalsAddressSpace(),
        M: &M);
  }
  return SrcLocStr;
}
1071
1072Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
1073 StringRef FileName,
1074 unsigned Line, unsigned Column,
1075 uint32_t &SrcLocStrSize) {
1076 SmallString<128> Buffer;
1077 Buffer.push_back(Elt: ';');
1078 Buffer.append(RHS: FileName);
1079 Buffer.push_back(Elt: ';');
1080 Buffer.append(RHS: FunctionName);
1081 Buffer.push_back(Elt: ';');
1082 Buffer.append(RHS: std::to_string(val: Line));
1083 Buffer.push_back(Elt: ';');
1084 Buffer.append(RHS: std::to_string(val: Column));
1085 Buffer.push_back(Elt: ';');
1086 Buffer.push_back(Elt: ';');
1087 return getOrCreateSrcLocStr(LocStr: Buffer.str(), SrcLocStrSize);
1088}
1089
1090Constant *
1091OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
1092 StringRef UnknownLoc = ";unknown;unknown;0;0;;";
1093 return getOrCreateSrcLocStr(LocStr: UnknownLoc, SrcLocStrSize);
1094}
1095
1096Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
1097 uint32_t &SrcLocStrSize,
1098 Function *F) {
1099 DILocation *DIL = DL.get();
1100 if (!DIL)
1101 return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1102 StringRef FileName = M.getName();
1103 if (DIFile *DIF = DIL->getFile())
1104 if (std::optional<StringRef> Source = DIF->getSource())
1105 FileName = *Source;
1106 StringRef Function = DIL->getScope()->getSubprogram()->getName();
1107 if (Function.empty() && F)
1108 Function = F->getName();
1109 return getOrCreateSrcLocStr(FunctionName: Function, FileName, Line: DIL->getLine(),
1110 Column: DIL->getColumn(), SrcLocStrSize);
1111}
1112
1113Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
1114 uint32_t &SrcLocStrSize) {
1115 return getOrCreateSrcLocStr(DL: Loc.DL, SrcLocStrSize,
1116 F: Loc.IP.getBlock()->getParent());
1117}
1118
1119Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
1120 return createRuntimeFunctionCall(
1121 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num), Args: Ident,
1122 Name: "omp_global_thread_num");
1123}
1124
1125OpenMPIRBuilder::InsertPointOrErrorTy
1126OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
1127 bool ForceSimpleCall, bool CheckCancelFlag) {
1128 if (!updateToLocation(Loc))
1129 return Loc.IP;
1130
1131 // Build call __kmpc_cancel_barrier(loc, thread_id) or
1132 // __kmpc_barrier(loc, thread_id);
1133
1134 IdentFlag BarrierLocFlags;
1135 switch (Kind) {
1136 case OMPD_for:
1137 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
1138 break;
1139 case OMPD_sections:
1140 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
1141 break;
1142 case OMPD_single:
1143 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
1144 break;
1145 case OMPD_barrier:
1146 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
1147 break;
1148 default:
1149 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
1150 break;
1151 }
1152
1153 uint32_t SrcLocStrSize;
1154 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1155 Value *Args[] = {
1156 getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: BarrierLocFlags),
1157 getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize))};
1158
1159 // If we are in a cancellable parallel region, barriers are cancellation
1160 // points.
1161 // TODO: Check why we would force simple calls or to ignore the cancel flag.
1162 bool UseCancelBarrier =
1163 !ForceSimpleCall && isLastFinalizationInfoCancellable(DK: OMPD_parallel);
1164
1165 Value *Result = createRuntimeFunctionCall(
1166 Callee: getOrCreateRuntimeFunctionPtr(FnID: UseCancelBarrier
1167 ? OMPRTL___kmpc_cancel_barrier
1168 : OMPRTL___kmpc_barrier),
1169 Args);
1170
1171 if (UseCancelBarrier && CheckCancelFlag)
1172 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective: OMPD_parallel))
1173 return Err;
1174
1175 return Builder.saveIP();
1176}
1177
// Emit an "omp cancel" for \p CanceledDirective, optionally guarded by
// \p IfCondition; the else-branch still emits a cancellation point.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition) {
    SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: UI, ThenTerm: &ThenTI, ElseTerm: &ElseTI);

    // Even if the if condition evaluates to false, this should count as a
    // cancellation point
    Builder.SetInsertPoint(ElseTI);
    auto ElseIP = Builder.saveIP();

    InsertPointOrErrorTy IPOrErr = createCancellationPoint(
        Loc: LocationDescription{ElseIP, Loc.DL}, CanceledDirective);
    if (!IPOrErr)
      return IPOrErr;
  }

  Builder.SetInsertPoint(ThenTI);

  // Map the directive to the runtime's cancel-kind constant via OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancel), Args);

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1233
// Emit an "omp cancellation point" for \p CanceledDirective: a call to
// __kmpc_cancellationpoint followed by the shared cancellation check.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
                                         omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();
  Builder.SetInsertPoint(UI);

  // Map the directive to the runtime's cancel-kind constant via OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancellationpoint), Args);

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1272
1273OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1274 const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
1275 Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
1276 Value *HostPtr, ArrayRef<Value *> KernelArgs) {
1277 if (!updateToLocation(Loc))
1278 return Loc.IP;
1279
1280 Builder.restoreIP(IP: AllocaIP);
1281 auto *KernelArgsPtr =
1282 Builder.CreateAlloca(Ty: OpenMPIRBuilder::KernelArgs, ArraySize: nullptr, Name: "kernel_args");
1283 updateToLocation(Loc);
1284
1285 for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
1286 llvm::Value *Arg =
1287 Builder.CreateStructGEP(Ty: OpenMPIRBuilder::KernelArgs, Ptr: KernelArgsPtr, Idx: I);
1288 Builder.CreateAlignedStore(
1289 Val: KernelArgs[I], Ptr: Arg,
1290 Align: M.getDataLayout().getPrefTypeAlign(Ty: KernelArgs[I]->getType()));
1291 }
1292
1293 SmallVector<Value *> OffloadingArgs{Ident, DeviceID, NumTeams,
1294 NumThreads, HostPtr, KernelArgsPtr};
1295
1296 Return = createRuntimeFunctionCall(
1297 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_target_kernel),
1298 Args: OffloadingArgs);
1299
1300 return Builder.saveIP();
1301}
1302
// Emit the target-kernel launch plus the control flow that runs the host
// fallback (EmitTargetCallFallbackCB) when the offloading call fails.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Value *OutlinedFnID,
    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(KernelArgs&: Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(IP: emitTargetKernel(
      Loc: Builder, AllocaIP, Return, Ident: RTLoc, DeviceID, NumTeams: Args.NumTeams.front(),
      NumThreads: Args.NumThreads.front(), HostPtr: OutlinedFnID, KernelArgs: ArgsVector));

  // Branch on the runtime result: non-zero means offloading failed and the
  // host fallback has to run.
  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Arg: Return);
  Builder.CreateCondBr(Cond: Failed, True: OffloadFailedBlock, False: OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(BB: OffloadFailedBlock, CurFn);
  InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  emitBranch(Target: OffloadContBlock);
  emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}
1365
// Branch on \p CancelFlag: zero continues normally, non-zero jumps through a
// cancellation block into the finalization block of the innermost region.
Error OpenMPIRBuilder::emitCancelationCheckImpl(
    Value *CancelFlag, omp::Directive CanceledDirective) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        Context&: BB->getContext(), Name: BB->getName() + ".cont", Parent: BB->getParent());
  } else {
    // Split at the insertion point; the split's unconditional branch is
    // erased so we can emit our own conditional branch from BB below.
    NonCancellationBlock = SplitBlock(Old: BB, SplitPt: &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      Context&: BB->getContext(), Name: BB->getName() + ".cncl", Parent: BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(Arg: CancelFlag);
  Builder.CreateCondBr(Cond: Cmp, True: NonCancellationBlock, False: CancellationBlock,
                       /* TODO weight */ BranchWeights: nullptr, Unpredictable: nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  auto &FI = FinalizationStack.back();
  Expected<BasicBlock *> FiniBBOrErr = FI.getFiniBB(Builder);
  if (!FiniBBOrErr)
    return FiniBBOrErr.takeError();
  Builder.SetInsertPoint(CancellationBlock);
  Builder.CreateBr(Dest: *FiniBBOrErr);

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(TheBB: NonCancellationBlock, IP: NonCancellationBlock->begin());
  return Error::success();
}
1405
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_60)
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // The extractor left exactly one call to OutlinedFn; it is replaced by the
  // runtime call built below.
  CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);

  // Add alloca for kernel args
  OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(TheBB: OuterAllocaBB, IP: OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add address space cast if array for storing arguments is not allocated
  // in address space 0
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(V: ArgsAlloca, DestTy: PtrTy);
  Builder.restoreIP(IP: CurrentIP);

  // Store captured vars which are used by kmpc_parallel_60. The captured
  // variables start at argument index 2 of the original call.
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars), Ptr: Args, Idx0: 0, Idx1: Idx);
    Builder.CreateStore(Val: V, Ptr: StoreAddress);
  }

  // A missing if-clause is encoded as the constant 1 (always parallel).
  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32)
                  : Builder.getInt32(C: 1);

  // Build kmpc_parallel_60 call
  Value *Parallel60CallArgs[] = {
      /* identifier*/ Ident,
      /* global thread num*/ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(C: -1),
      /* Proc bind */ Builder.getInt32(C: -1),
      /* outlined function */ &OutlinedFn,
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function*/ Args,
      /* number of arguments */ Builder.getInt64(C: NumCapturedVars),
      /* strict for number of threads */ Builder.getInt32(C: 0)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_parallel_60);

  OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: Parallel60CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_60 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
                      Ptr: PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  // Dispose of instructions the caller marked as obsolete once the runtime
  // call is in place.
  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1493
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the host.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call);
  }
  if (auto *F = dyn_cast<Function>(Val: RTLFn.getCallee())) {
    if (!F->hasMetadata(KindID: LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      // - The callback callee is argument number 2 (microtask).
      // - The first two arguments of the callback callee are unknown (-1).
      // - All variadic arguments to the __kmpc_fork_call are passed to the
      //   callback callee.
      F->addMetadata(KindID: LLVMContext::MD_callback,
                     MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
                                        CalleeArgNo: 2, Arguments: {-1, -1},
                                        /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
  OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // The extractor left exactly one call to OutlinedFn; it is replaced by the
  // fork call built below.
  CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
  Value *ForkCallArgs[] = {Ident, Builder.getInt32(C: NumCapturedVars),
                           &OutlinedFn};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(in_start: std::begin(arr&: ForkCallArgs), in_end: std::end(arr&: ForkCallArgs));
  if (IfCondition) {
    // __kmpc_fork_call_if takes the (truncated) condition right after the
    // microtask argument.
    Value *Cond = Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32);
    RealArgs.push_back(Elt: Cond);
  }
  RealArgs.append(in_start: CI->arg_begin() + /* tid & bound tid */ 2, in_end: CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
    RealArgs.push_back(Elt: NullPtrValue);
  }

  OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
                      Ptr: PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  // Dispose of instructions the caller marked as obsolete once the fork call
  // is in place.
  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1578
// Generate an OpenMP 'parallel' region. The region body produced by \p
// BodyGenCB is placed into a set of dedicated blocks which are later outlined
// into a separate function by the CodeExtractor; the registered
// PostOutlineCB then replaces the stale call with the proper fork-call (host)
// or target-parallel (device) runtime invocation. Returns the insertion point
// after the (eventual) parallel region, or an error from one of the user
// callbacks.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // Only materialize the thread ID when one of the runtime calls below (or
  // the device post-outline callback) actually consumes it.
  const bool NeedThreadID = NumThreads || Config.isTargetDevice() ||
                            (ProcBind != OMP_PROC_BIND_default);
  Value *ThreadID = NeedThreadID ? getOrCreateThreadID(Ident) : nullptr;
  // If we generate code for the target device, we need to allocate
  // struct for aggregate params in the device default alloca address space.
  // OpenMP runtime requires that the params of the extracted functions are
  // passed as zero address space pointers. This flag ensures that extracted
  // function arguments are declared in zero address space
  bool ArgsInZeroAddressSpace = Config.isTargetDevice();

  // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
  // only if we compile for host side.
  if (NumThreads && !Config.isTargetDevice()) {
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
  Builder.restoreIP(IP: NewOuter);
  AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
  AllocaInst *ZeroAddrAlloca =
      Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
  Instruction *TIDAddr = TIDAddrAlloca;
  Instruction *ZeroAddr = ZeroAddrAlloca;
  if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
    // Add additional casts to enforce pointers in zero address space
    TIDAddr = new AddrSpaceCastInst(
        TIDAddrAlloca, PointerType ::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
    TIDAddr->insertAfter(InsertPos: TIDAddrAlloca->getIterator());
    ToBeDeleted.push_back(Elt: TIDAddr);
    ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
                                     PointerType ::get(C&: M.getContext(), AddressSpace: 0),
                                     "zero.addr.ascast");
    ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca->getIterator());
    ToBeDeleted.push_back(Elt: ZeroAddr);
  }

  // We only need TIDAddr and ZeroAddr for modeling purposes to get the
  // associated arguments in the outlined function, so we delete them later.
  ToBeDeleted.push_back(Elt: TIDAddrAlloca);
  ToBeDeleted.push_back(Elt: ZeroAddrAlloca);

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  // Split off the four region blocks; UI stays in the last one (the exit) and
  // is erased at the very end once AfterIP has been computed from it.
  BasicBlock *EntryBB = UI->getParent();
  BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
  BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
  BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  FinalizationStack.push_back(Elt: {FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");

  // Add some fake uses for OpenMP provided arguments. These loads force the
  // CodeExtractor to treat TIDAddr/ZeroAddr as inputs of the region; they are
  // deleted again by the post-outline callback.
  ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
  ToBeDeleted.push_back(Elt: ZeroAddrUse);

  // EntryBB
  //    |
  //    V
  // PRegionEntryBB         <- Privatization allocas are placed here.
  //    |
  //    V
  // PRegionBodyBB          <- BodeGen is invoked here.
  //    |
  //    V
  // PRegPreFiniBB          <- The block we will start finalization from.
  //    |
  //    V
  // PRegionExitBB          <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
    return Err;

  LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");

  OutlineInfo OI;
  if (Config.isTargetDevice()) {
    // Generate OpenMP target specific runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                             ThreadID, ToBeDeleted: ToBeDeletedVec);
    };
    OI.FixUpNonEntryAllocas = true;
  } else {
    // Generate OpenMP host runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
                           PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
    };
    OI.FixUpNonEntryAllocas = true;
  }

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

  // Find inputs to, outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);

  Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands,
                              /*CollectGlobalInputs=*/true);

  // The ident_t global is reachable from inside the region but must not be
  // turned into an argument of the outlined function.
  Inputs.remove_if(P: [&](Value *I) {
    if (auto *GV = dyn_cast_if_present<GlobalVariable>(Val: I))
      return GV->getValueType() == OpenMPIRBuilder::Ident;

    return false;
  });

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);

  // Privatize a single captured value: route uses inside the region either to
  // the local thread ID or to whatever the PrivCB callback produces. Non-
  // pointer inputs are spilled to an outer-scope alloca and reloaded inside
  // the region first, since the fork call passes captures as pointers.
  auto PrivHelper = [&](Value &V) -> Error {
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(Elt: &V);
      return Error::success();
    }

    // Only rewrite uses that live inside the parallel region blocks.
    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
        if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
          Uses.insert(X: &U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(IP: OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr, Name: V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the entry
      // block of the to-be-outlined region.
      Builder.SetInsertPoint(TheBB: InsertBB,
                             IP: InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(Val: &V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(IP: InnerAllocaIP);
      Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(Val: &V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      // Captured thread-ID queries are redirected to the local TID.
      ReplacementValue = PrivTID;
    } else {
      InsertPointOrErrorTy AfterIP =
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
      // PrivCB may have grown the alloca block; re-anchor at its terminator.
      InnerAllocaIP = {
          InnerAllocaIP.getBlock(),
          InnerAllocaIP.getBlock()->getTerminator()->getIterator()};

      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      if (ReplacementValue == &V)
        return Error::success();
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);

    return Error::success();
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of zero address so
  // that they are available in the generated body and so that the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    if (Error Err = PrivHelper(*Input))
      return Err;
  }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  Expected<BasicBlock *> FiniBBOrErr = FiniInfo.getFiniBB(Builder);
  if (!FiniBBOrErr)
    return FiniBBOrErr.takeError();
  {
    IRBuilderBase::InsertPointGuard Guard(Builder);
    Builder.restoreIP(IP: PreFiniIP);
    Builder.CreateBr(Dest: *FiniBBOrErr);
    // There's currently a branch to omp.par.exit. Delete it. We will get there
    // via the fini block
    if (Instruction *Term = Builder.GetInsertBlock()->getTerminator())
      Term->eraseFromParent();
  }

  // Register the outlined info.
  addOutlineInfo(OI: std::move(OI));

  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
1914
1915void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1916 // Build call void __kmpc_flush(ident_t *loc)
1917 uint32_t SrcLocStrSize;
1918 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1919 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1920
1921 createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush),
1922 Args);
1923}
1924
1925void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1926 if (!updateToLocation(Loc))
1927 return;
1928 emitFlush(Loc);
1929}
1930
1931void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1932 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1933 // global_tid);
1934 uint32_t SrcLocStrSize;
1935 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1936 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1937 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1938
1939 // Ignore return result until untied tasks are supported.
1940 createRuntimeFunctionCall(
1941 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait), Args);
1942}
1943
1944void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1945 if (!updateToLocation(Loc))
1946 return;
1947 emitTaskwaitImpl(Loc);
1948}
1949
1950void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1951 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1952 uint32_t SrcLocStrSize;
1953 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1954 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1955 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1956 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1957
1958 createRuntimeFunctionCall(
1959 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield), Args);
1960}
1961
1962void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1963 if (!updateToLocation(Loc))
1964 return;
1965 emitTaskyieldImpl(Loc);
1966}
1967
// Processes the dependencies in Dependencies and does the following
// - Allocates space on the stack of an array of DependInfo objects
// - Populates each DependInfo object with relevant information of
//   the corresponding dependence.
// - All code is inserted in the entry block of the current function.
//
// Returns the stack array of DependInfo objects, or nullptr when
// \p Dependencies is empty.
static Value *emitTaskDependencies(
    OpenMPIRBuilder &OMPBuilder,
    const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
  // Early return if we have no dependencies to process
  if (Dependencies.empty())
    return nullptr;

  // Given a vector of DependData objects, in this function we create an
  // array on the stack that holds kmp_dep_info objects corresponding
  // to each dependency. This is then passed to the OpenMP runtime.
  // For example, if there are 'n' dependencies then the following pseudo
  // code is generated. Assume the first dependence is on a variable 'a'
  //
  // \code{c}
  // DepArray = alloc(n x sizeof(kmp_depend_info));
  // idx = 0;
  // DepArray[idx].base_addr = ptrtoint(&a);
  // DepArray[idx].len = 8;
  // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
  // ++idx;
  // DepArray[idx].base_addr = ...;
  // \endcode

  IRBuilderBase &Builder = OMPBuilder.Builder;
  Type *DependInfo = OMPBuilder.DependInfo;
  Module &M = OMPBuilder.M;

  Value *DepArray = nullptr;
  // The alloca must live in the function entry block; temporarily move the
  // insertion point there and restore it afterwards.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  Builder.SetInsertPoint(
      OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());

  Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
  DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");

  Builder.restoreIP(IP: OldIP);

  // Fill in one DependInfo entry (base_addr, len, flags) per dependency at
  // the original insertion point.
  for (const auto &[DepIdx, Dep] : enumerate(First: Dependencies)) {
    Value *Base =
        Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
    // Store the pointer to the variable
    Value *Addr = Builder.CreateStructGEP(
        Ty: DependInfo, Ptr: Base,
        Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
    Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
    Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
    // Store the size of the variable
    Value *Size = Builder.CreateStructGEP(
        Ty: DependInfo, Ptr: Base, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
    Builder.CreateStore(
        Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
        Ptr: Size);
    // Store the dependency kind
    Value *Flags = Builder.CreateStructGEP(
        Ty: DependInfo, Ptr: Base,
        Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
    Builder.CreateStore(
        Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
                          V: static_cast<unsigned int>(Dep.DepKind)),
        Ptr: Flags);
  }
  return DepArray;
}
2036
/// Create the task duplication function passed to kmpc_taskloop.
///
/// \param PrivatesTy     Struct type holding the task's private data; it is
///                       laid out directly after the kmp_task_t header.
/// \param PrivatesIndex  Index of the context pointer field inside
///                       \p PrivatesTy that \p DupCB operates on.
/// \param DupCB          User callback that emits the per-task copy code
///                       given destination and source context pointers.
///
/// \returns A null function pointer when \p DupCB is not provided (the
/// runtime accepts null), the generated "omp_taskloop_dup" function
/// otherwise, or an error propagated from \p DupCB.
Expected<Value *> OpenMPIRBuilder::createTaskDuplicationFunction(
    Type *PrivatesTy, int32_t PrivatesIndex, TaskDupCallbackTy DupCB) {
  unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
  // No callback: pass a null task_dup pointer to the runtime.
  if (!DupCB)
    return Constant::getNullValue(
        Ty: PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace));

  // From OpenMP Runtime p_task_dup_t:
  // Routine optionally generated by the compiler for setting the lastprivate
  // flag and calling needed constructors for private/firstprivate objects (used
  // to form taskloop tasks from pattern task) Parameters: dest task, src task,
  // lastprivate flag.
  // typedef void (*p_task_dup_t)(kmp_task_t *, kmp_task_t *, kmp_int32);

  auto *VoidPtrTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);

  FunctionType *DupFuncTy = FunctionType::get(
      Result: Builder.getVoidTy(), Params: {VoidPtrTy, VoidPtrTy, Builder.getInt32Ty()},
      /*isVarArg=*/false);

  Function *DupFunction = Function::Create(Ty: DupFuncTy, Linkage: Function::InternalLinkage,
                                           N: "omp_taskloop_dup", M);
  Value *DestTaskArg = DupFunction->getArg(i: 0);
  Value *SrcTaskArg = DupFunction->getArg(i: 1);
  Value *LastprivateFlagArg = DupFunction->getArg(i: 2);
  DestTaskArg->setName("dest_task");
  SrcTaskArg->setName("src_task");
  LastprivateFlagArg->setName("lastprivate_flag");

  // All IR below goes into the new function; the guard restores the caller's
  // insertion point when we return.
  IRBuilderBase::InsertPointGuard Guard(Builder);
  Builder.SetInsertPoint(
      BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: DupFunction));

  // Compute the address of the context-pointer field inside the privates
  // struct of a kmp_task_t argument ({ kmp_task_t, PrivatesTy } layout).
  auto GetTaskContextPtrFromArg = [&](Value *Arg) -> Value * {
    Type *TaskWithPrivatesTy =
        StructType::get(Context&: Builder.getContext(), Elements: {Task, PrivatesTy});
    Value *TaskPrivates = Builder.CreateGEP(
        Ty: TaskWithPrivatesTy, Ptr: Arg, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
    Value *ContextPtr = Builder.CreateGEP(
        Ty: PrivatesTy, Ptr: TaskPrivates,
        IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: PrivatesIndex)});
    return ContextPtr;
  };

  Value *DestTaskContextPtr = GetTaskContextPtrFromArg(DestTaskArg);
  Value *SrcTaskContextPtr = GetTaskContextPtrFromArg(SrcTaskArg);

  DestTaskContextPtr->setName("destPtr");
  SrcTaskContextPtr->setName("srcPtr");

  // Let the user callback fill in the body, then terminate with ret void.
  InsertPointTy AllocaIP(&DupFunction->getEntryBlock(),
                         DupFunction->getEntryBlock().begin());
  InsertPointTy CodeGenIP = Builder.saveIP();
  Expected<IRBuilderBase::InsertPoint> AfterIPOrError =
      DupCB(AllocaIP, CodeGenIP, DestTaskContextPtr, SrcTaskContextPtr);
  if (!AfterIPOrError)
    return AfterIPOrError.takeError();
  Builder.restoreIP(IP: *AfterIPOrError);

  Builder.CreateRetVoid();

  return DupFunction;
}
2101
2102OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
2103 const LocationDescription &Loc, InsertPointTy AllocaIP,
2104 BodyGenCallbackTy BodyGenCB,
2105 llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> LoopInfo,
2106 Value *LBVal, Value *UBVal, Value *StepVal, bool Untied, Value *IfCond,
2107 Value *GrainSize, bool NoGroup, int Sched, Value *Final, bool Mergeable,
2108 Value *Priority, uint64_t NumOfCollapseLoops, TaskDupCallbackTy DupCB,
2109 Value *TaskContextStructPtrVal) {
2110
2111 if (!updateToLocation(Loc))
2112 return InsertPointTy();
2113
2114 uint32_t SrcLocStrSize;
2115 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2116 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2117
2118 BasicBlock *TaskloopExitBB =
2119 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.exit");
2120 BasicBlock *TaskloopBodyBB =
2121 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.body");
2122 BasicBlock *TaskloopAllocaBB =
2123 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.alloca");
2124
2125 InsertPointTy TaskloopAllocaIP =
2126 InsertPointTy(TaskloopAllocaBB, TaskloopAllocaBB->begin());
2127 InsertPointTy TaskloopBodyIP =
2128 InsertPointTy(TaskloopBodyBB, TaskloopBodyBB->begin());
2129
2130 if (Error Err = BodyGenCB(TaskloopAllocaIP, TaskloopBodyIP))
2131 return Err;
2132
2133 llvm::Expected<llvm::CanonicalLoopInfo *> result = LoopInfo();
2134 if (!result) {
2135 return result.takeError();
2136 }
2137
2138 llvm::CanonicalLoopInfo *CLI = result.get();
2139 OutlineInfo OI;
2140 OI.EntryBB = TaskloopAllocaBB;
2141 OI.OuterAllocaBB = AllocaIP.getBlock();
2142 OI.ExitBB = TaskloopExitBB;
2143
2144 // Add the thread ID argument.
2145 SmallVector<Instruction *> ToBeDeleted;
2146 // dummy instruction to be used as a fake argument
2147 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2148 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskloopAllocaIP, Name: "global.tid", AsPtr: false));
2149 Value *FakeLB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2150 InnerAllocaIP: TaskloopAllocaIP, Name: "lb", AsPtr: false, Is64Bit: true);
2151 Value *FakeUB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2152 InnerAllocaIP: TaskloopAllocaIP, Name: "ub", AsPtr: false, Is64Bit: true);
2153 Value *FakeStep = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2154 InnerAllocaIP: TaskloopAllocaIP, Name: "step", AsPtr: false, Is64Bit: true);
2155 // For Taskloop, we want to force the bounds being the first 3 inputs in the
2156 // aggregate struct
2157 OI.Inputs.insert(X: FakeLB);
2158 OI.Inputs.insert(X: FakeUB);
2159 OI.Inputs.insert(X: FakeStep);
2160 if (TaskContextStructPtrVal)
2161 OI.Inputs.insert(X: TaskContextStructPtrVal);
2162 assert(((TaskContextStructPtrVal && DupCB) ||
2163 (!TaskContextStructPtrVal && !DupCB)) &&
2164 "Task context struct ptr and duplication callback must be both set "
2165 "or both null");
2166
2167 // It isn't safe to run the duplication bodygen callback inside the post
2168 // outlining callback so this has to be run now before we know the real task
2169 // shareds structure type.
2170 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2171 Type *PointerTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2172 Type *FakeSharedsTy = StructType::get(
2173 Context&: Builder.getContext(),
2174 Elements: {FakeLB->getType(), FakeUB->getType(), FakeStep->getType(), PointerTy});
2175 Expected<Value *> TaskDupFnOrErr = createTaskDuplicationFunction(
2176 PrivatesTy: FakeSharedsTy,
2177 /*PrivatesIndex: the pointer after the three indices above*/ PrivatesIndex: 3, DupCB);
2178 if (!TaskDupFnOrErr) {
2179 return TaskDupFnOrErr.takeError();
2180 }
2181 Value *TaskDupFn = *TaskDupFnOrErr;
2182
2183 OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
2184 TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
2185 IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
2186 FakeStep, FakeSharedsTy, Final, Mergeable, Priority,
2187 NumOfCollapseLoops](Function &OutlinedFn) mutable {
2188 // Replace the Stale CI by appropriate RTL function call.
2189 assert(OutlinedFn.hasOneUse() &&
2190 "there must be a single user for the outlined function");
2191 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2192
2193 /* Create the casting for the Bounds Values that can be used when outlining
2194 * to replace the uses of the fakes with real values */
2195 BasicBlock *CodeReplBB = StaleCI->getParent();
2196 Builder.SetInsertPoint(CodeReplBB->getFirstInsertionPt());
2197 Value *CastedLBVal =
2198 Builder.CreateIntCast(V: LBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "lb64");
2199 Value *CastedUBVal =
2200 Builder.CreateIntCast(V: UBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "ub64");
2201 Value *CastedStepVal =
2202 Builder.CreateIntCast(V: StepVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "step64");
2203
2204 Builder.SetInsertPoint(StaleCI);
2205
2206 // Gather the arguments for emitting the runtime call for
2207 // @__kmpc_omp_task_alloc
2208 Function *TaskAllocFn =
2209 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2210
2211 Value *ThreadID = getOrCreateThreadID(Ident);
2212
2213 if (!NoGroup) {
2214 // Emit runtime call for @__kmpc_taskgroup
2215 Function *TaskgroupFn =
2216 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2217 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2218 }
2219
2220 // `flags` Argument Configuration
2221 // Task is tied if (Flags & 1) == 1.
2222 // Task is untied if (Flags & 1) == 0.
2223 // Task is final if (Flags & 2) == 2.
2224 // Task is not final if (Flags & 2) == 0.
2225 // Task is mergeable if (Flags & 4) == 4.
2226 // Task is not mergeable if (Flags & 4) == 0.
2227 // Task is priority if (Flags & 32) == 32.
2228 // Task is not priority if (Flags & 32) == 0.
2229 Value *Flags = Builder.getInt32(C: Untied ? 0 : 1);
2230 if (Final)
2231 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 2), RHS: Flags);
2232 if (Mergeable)
2233 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2234 if (Priority)
2235 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2236
2237 Value *TaskSize = Builder.getInt64(
2238 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2239
2240 AllocaInst *ArgStructAlloca =
2241 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2242 assert(ArgStructAlloca &&
2243 "Unable to find the alloca instruction corresponding to arguments "
2244 "for extracted function");
2245 std::optional<TypeSize> ArgAllocSize =
2246 ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
2247 assert(ArgAllocSize &&
2248 "Unable to determine size of arguments for extracted function");
2249 Value *SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
2250
2251 // Emit the @__kmpc_omp_task_alloc runtime call
2252 // The runtime call returns a pointer to an area where the task captured
2253 // variables must be copied before the task is run (TaskData)
2254 CallInst *TaskData = Builder.CreateCall(
2255 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2256 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2257 /*task_func=*/&OutlinedFn});
2258
2259 Value *Shareds = StaleCI->getArgOperand(i: 1);
2260 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2261 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2262 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2263 Size: SharedsSize);
2264 // Get the pointer to loop lb, ub, step from task ptr
2265 // and set up the lowerbound,upperbound and step values
2266 llvm::Value *Lb = Builder.CreateGEP(
2267 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
2268
2269 llvm::Value *Ub = Builder.CreateGEP(
2270 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2271
2272 llvm::Value *Step = Builder.CreateGEP(
2273 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2)});
2274 llvm::Value *Loadstep = Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: Step);
2275
2276 // set up the arguments for emitting kmpc_taskloop runtime call
2277 // setting values for ifval, nogroup, sched, grainsize, task_dup
2278 Value *IfCondVal =
2279 IfCond ? Builder.CreateIntCast(V: IfCond, DestTy: Builder.getInt32Ty(), isSigned: true)
2280 : Builder.getInt32(C: 1);
2281 // As __kmpc_taskgroup is called manually in OMPIRBuilder, NoGroupVal should
2282 // always be 1 when calling __kmpc_taskloop to ensure it is not called again
2283 Value *NoGroupVal = Builder.getInt32(C: 1);
2284 Value *SchedVal = Builder.getInt32(C: Sched);
2285 Value *GrainSizeVal =
2286 GrainSize ? Builder.CreateIntCast(V: GrainSize, DestTy: Builder.getInt64Ty(), isSigned: true)
2287 : Builder.getInt64(C: 0);
2288 Value *TaskDup = TaskDupFn;
2289
2290 Value *Args[] = {Ident, ThreadID, TaskData, IfCondVal, Lb, Ub,
2291 Loadstep, NoGroupVal, SchedVal, GrainSizeVal, TaskDup};
2292
2293 // taskloop runtime call
2294 Function *TaskloopFn =
2295 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskloop);
2296 Builder.CreateCall(Callee: TaskloopFn, Args);
2297
2298 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup if
2299 // nogroup is not defined
2300 if (!NoGroup) {
2301 Function *EndTaskgroupFn =
2302 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2303 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2304 }
2305
2306 StaleCI->eraseFromParent();
2307
2308 Builder.SetInsertPoint(TheBB: TaskloopAllocaBB, IP: TaskloopAllocaBB->begin());
2309
2310 LoadInst *SharedsOutlined =
2311 Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2312 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2313 New: SharedsOutlined,
2314 ShouldReplace: [SharedsOutlined](Use &U) { return U.getUser() != SharedsOutlined; });
2315
2316 Value *IV = CLI->getIndVar();
2317 Type *IVTy = IV->getType();
2318 Constant *One = ConstantInt::get(Ty: Builder.getInt64Ty(), V: 1);
2319
2320 // When outlining, CodeExtractor will create GEP's to the LowerBound and
2321 // UpperBound. These GEP's can be reused for loading the tasks respective
2322 // bounds.
2323 Value *TaskLB = nullptr;
2324 Value *TaskUB = nullptr;
2325 Value *LoadTaskLB = nullptr;
2326 Value *LoadTaskUB = nullptr;
2327 for (Instruction &I : *TaskloopAllocaBB) {
2328 if (I.getOpcode() == Instruction::GetElementPtr) {
2329 GetElementPtrInst &Gep = cast<GetElementPtrInst>(Val&: I);
2330 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: Gep.getOperand(i_nocapture: 2))) {
2331 switch (CI->getZExtValue()) {
2332 case 0:
2333 TaskLB = &I;
2334 break;
2335 case 1:
2336 TaskUB = &I;
2337 break;
2338 }
2339 }
2340 } else if (I.getOpcode() == Instruction::Load) {
2341 LoadInst &Load = cast<LoadInst>(Val&: I);
2342 if (Load.getPointerOperand() == TaskLB) {
2343 assert(TaskLB != nullptr && "Expected value for TaskLB");
2344 LoadTaskLB = &I;
2345 } else if (Load.getPointerOperand() == TaskUB) {
2346 assert(TaskUB != nullptr && "Expected value for TaskUB");
2347 LoadTaskUB = &I;
2348 }
2349 }
2350 }
2351
2352 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2353
2354 assert(LoadTaskLB != nullptr && "Expected value for LoadTaskLB");
2355 assert(LoadTaskUB != nullptr && "Expected value for LoadTaskUB");
2356 Value *TripCountMinusOne =
2357 Builder.CreateSDiv(LHS: Builder.CreateSub(LHS: LoadTaskUB, RHS: LoadTaskLB), RHS: FakeStep);
2358 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One, Name: "trip_cnt");
2359 Value *CastedTripCount = Builder.CreateIntCast(V: TripCount, DestTy: IVTy, isSigned: true);
2360 Value *CastedTaskLB = Builder.CreateIntCast(V: LoadTaskLB, DestTy: IVTy, isSigned: true);
2361 // set the trip count in the CLI
2362 CLI->setTripCount(CastedTripCount);
2363
2364 Builder.SetInsertPoint(TheBB: CLI->getBody(),
2365 IP: CLI->getBody()->getFirstInsertionPt());
2366
2367 if (NumOfCollapseLoops > 1) {
2368 llvm::SmallVector<User *> UsersToReplace;
2369 // When using the collapse clause, the bounds of the loop have to be
2370 // adjusted to properly represent the iterator of the outer loop.
2371 Value *IVPlusTaskLB = Builder.CreateAdd(
2372 LHS: CLI->getIndVar(),
2373 RHS: Builder.CreateSub(LHS: CastedTaskLB, RHS: ConstantInt::get(Ty: IVTy, V: 1)));
2374 // To ensure every Use is correctly captured, we first want to record
2375 // which users to replace the value in, and then replace the value.
2376 for (auto IVUse = CLI->getIndVar()->uses().begin();
2377 IVUse != CLI->getIndVar()->uses().end(); IVUse++) {
2378 User *IVUser = IVUse->getUser();
2379 if (auto *Op = dyn_cast<BinaryOperator>(Val: IVUser)) {
2380 if (Op->getOpcode() == Instruction::URem ||
2381 Op->getOpcode() == Instruction::UDiv) {
2382 UsersToReplace.push_back(Elt: IVUser);
2383 }
2384 }
2385 }
2386 for (User *User : UsersToReplace) {
2387 User->replaceUsesOfWith(From: CLI->getIndVar(), To: IVPlusTaskLB);
2388 }
2389 } else {
2390 // The canonical loop is generated with a fixed lower bound. We need to
2391 // update the index calculation code to use the task's lower bound. The
2392 // generated code looks like this:
2393 // %omp_loop.iv = phi ...
2394 // ...
2395 // %tmp = mul [type] %omp_loop.iv, step
2396 // %user_index = add [type] tmp, lb
2397 // OpenMPIRBuilder constructs canonical loops to have exactly three uses
2398 // of the normalised induction variable:
2399 // 1. This one: converting the normalised IV to the user IV
2400 // 2. The increment (add)
2401 // 3. The comparison against the trip count (icmp)
2402 // (1) is the only use that is a mul followed by an add so this cannot
2403 // match other IR.
2404 assert(CLI->getIndVar()->getNumUses() == 3 &&
2405 "Canonical loop should have exactly three uses of the ind var");
2406 for (User *IVUser : CLI->getIndVar()->users()) {
2407 if (auto *Mul = dyn_cast<BinaryOperator>(Val: IVUser)) {
2408 if (Mul->getOpcode() == Instruction::Mul) {
2409 for (User *MulUser : Mul->users()) {
2410 if (auto *Add = dyn_cast<BinaryOperator>(Val: MulUser)) {
2411 if (Add->getOpcode() == Instruction::Add) {
2412 Add->setOperand(i_nocapture: 1, Val_nocapture: CastedTaskLB);
2413 }
2414 }
2415 }
2416 }
2417 }
2418 }
2419 }
2420
2421 FakeLB->replaceAllUsesWith(V: CastedLBVal);
2422 FakeUB->replaceAllUsesWith(V: CastedUBVal);
2423 FakeStep->replaceAllUsesWith(V: CastedStepVal);
2424 for (Instruction *I : llvm::reverse(C&: ToBeDeleted)) {
2425 I->eraseFromParent();
2426 }
2427 };
2428
2429 addOutlineInfo(OI: std::move(OI));
2430 Builder.SetInsertPoint(TheBB: TaskloopExitBB, IP: TaskloopExitBB->begin());
2431 return Builder.saveIP();
2432}
2433
2434llvm::StructType *OpenMPIRBuilder::getKmpTaskAffinityInfoTy() {
2435 llvm::Type *IntPtrTy = llvm::Type::getIntNTy(
2436 C&: M.getContext(), N: M.getDataLayout().getPointerSizeInBits());
2437 return llvm::StructType::get(elt1: IntPtrTy, elts: IntPtrTy,
2438 elts: llvm::Type::getInt32Ty(C&: M.getContext()));
2439}
2440
// Generate IR for an OpenMP `task` construct. The task body produced by
// \p BodyGenCB is outlined into a separate function; the placeholder call to
// that function is later rewritten (in PostOutlineCB) into the
// __kmpc_omp_task_alloc / __kmpc_omp_task runtime sequence, including
// dependence, affinity, priority, detach and `if`-clause handling.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
    SmallVector<DependData> Dependencies, AffinityData Affinities,
    bool Mergeable, Value *EventHandle, Value *Priority) {

  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %task.exit
  //   task.exit:
  //     ; instructions after task
  // }
  // def outlined_fn() {
  //   task.alloca:
  //     br label %task.body
  //   task.body:
  //     ret void
  // }
  // ```
  BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
  BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
  BasicBlock *TaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");

  // Let the caller emit the task body between the alloca and body blocks.
  InsertPointTy TaskAllocaIP =
      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
  if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
    return Err;

  // Describe the region to outline; the actual extraction and runtime-call
  // rewriting happen later, when the OutlineInfo registered below is
  // processed.
  OutlineInfo OI;
  OI.EntryBB = TaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();
  OI.ExitBB = TaskExitBB;

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));

  // Invoked after CodeExtractor has produced OutlinedFn. Replaces the stale
  // direct call to OutlinedFn with the task runtime-call sequence.
  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
                      Affinities, Mergeable, Priority, EventHandle,
                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
    // Replace the Stale CI by appropriate RTL function call.
    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());

    // HasShareds is true if any variables are captured in the outlined region,
    // false otherwise.
    bool HasShareds = StaleCI->arg_size() > 1;
    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call for
    // @__kmpc_omp_task_alloc
    Function *TaskAllocFn =
        getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // Task is mergeable iff (Flags & 4) == 4.
    // Task is not mergeable iff (Flags & 4) == 0.
    // Task is priority iff (Flags & 32) == 32.
    // Task is not priority iff (Flags & 32) == 0.
    // TODO: Handle the other flags.
    Value *Flags = Builder.getInt32(C: Tied);
    if (Final) {
      // `final` may be a runtime condition; select the flag bit dynamically.
      Value *FinalFlag =
          Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
      Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
    }

    if (Mergeable)
      Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
    if (Priority)
      Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // Tasksize refers to the size in bytes of kmp_task_t data structure
    // including private vars accessed in task.
    // TODO: add kmp_task_t_with_privates (privates)
    Value *TaskSize = Builder.getInt64(
        C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(C: 0);
    if (HasShareds) {
      // The aggregate of captured variables is the second argument of the
      // stale call; its alloca's size is the shareds size.
      AllocaInst *ArgStructAlloca =
          dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      std::optional<TypeSize> ArgAllocSize =
          ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
      assert(ArgAllocSize &&
             "Unable to determine size of arguments for extracted function");
      SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
    }
    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = createRuntimeFunctionCall(
        Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
                      /*task_func=*/&OutlinedFn});

    // Register affinity information for the task, if any was provided.
    if (Affinities.Count && Affinities.Info) {
      Function *RegAffFn = getOrCreateRuntimeFunctionPtr(
          FnID: OMPRTL___kmpc_omp_reg_task_with_affinity);

      createRuntimeFunctionCall(Callee: RegAffFn, Args: {Ident, ThreadID, TaskData,
                                           Affinities.Count, Affinities.Info});
    }

    // Emit detach clause initialization.
    // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
    // task_descriptor);
    if (EventHandle) {
      Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
          FnID: OMPRTL___kmpc_task_allow_completion_event);
      llvm::Value *EventVal =
          createRuntimeFunctionCall(Callee: TaskDetachFn, Args: {Ident, ThreadID, TaskData});
      llvm::Value *EventHandleAddr =
          Builder.CreatePointerBitCastOrAddrSpaceCast(V: EventHandle,
                                                      DestTy: Builder.getPtrTy(AddrSpace: 0));
      // The event handle variable stores the event pointer as an i64.
      EventVal = Builder.CreatePtrToInt(V: EventVal, DestTy: Builder.getInt64Ty());
      Builder.CreateStore(Val: EventVal, Ptr: EventHandleAddr);
    }
    // Copy the arguments for outlined function
    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(i: 1);
      Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
      // The first field of kmp_task_t is the shareds pointer; load it and
      // memcpy the captured variables into the runtime-owned area.
      Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
      Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
                           Size: SharedsSize);
    }

    if (Priority) {
      //
      // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
      // we populate the priority information into the "kmp_task_t" here
      //
      // The struct "kmp_task_t" definition is available in kmp.h
      // kmp_task_t = { shareds, routine, part_id, data1, data2 }
      // data2 is used for priority
      //
      Type *Int32Ty = Builder.getInt32Ty();
      Constant *Zero = ConstantInt::get(Ty: Int32Ty, V: 0);
      // kmp_task_t* => { ptr }
      Type *TaskPtr = StructType::get(elt1: VoidPtr);
      Value *TaskGEP =
          Builder.CreateInBoundsGEP(Ty: TaskPtr, Ptr: TaskData, IdxList: {Zero, Zero});
      // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
      Type *TaskStructType = StructType::get(
          elt1: VoidPtr, elts: VoidPtr, elts: Builder.getInt32Ty(), elts: VoidPtr, elts: VoidPtr);
      Value *PriorityData = Builder.CreateInBoundsGEP(
          Ty: TaskStructType, Ptr: TaskGEP, IdxList: {Zero, ConstantInt::get(Ty: Int32Ty, V: 4)});
      // kmp_cmplrdata_t => { ptr, ptr }
      Type *CmplrStructType = StructType::get(elt1: VoidPtr, elts: VoidPtr);
      Value *CmplrData = Builder.CreateInBoundsGEP(Ty: CmplrStructType,
                                                   Ptr: PriorityData, IdxList: {Zero, Zero});
      Builder.CreateStore(Val: Priority, Ptr: CmplrData);
    }

    Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);

    // In the presence of the `if` clause, the following IR is generated:
    //    ...
    //    %data = call @__kmpc_omp_task_alloc(...)
    //    br i1 %if_condition, label %then, label %else
    //  then:
    //    call @__kmpc_omp_task(...)
    //    br label %exit
    //  else:
    //    ;; Wait for resolution of dependencies, if any, before
    //    ;; beginning the task
    //    call @__kmpc_omp_wait_deps(...)
    //    call @__kmpc_omp_task_begin_if0(...)
    //    call @outlined_fn(...)
    //    call @__kmpc_omp_task_complete_if0(...)
    //    br label %exit
    //  exit:
    //    ...
    if (IfCondition) {
      // `SplitBlockAndInsertIfThenElse` requires the block to have a
      // terminator.
      splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
      Instruction *IfTerminator =
          Builder.GetInsertPoint()->getParent()->getTerminator();
      Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
      Builder.SetInsertPoint(IfTerminator);
      SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
                                    ElseTerm: &ElseTI);
      Builder.SetInsertPoint(ElseTI);

      if (Dependencies.size()) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
        createRuntimeFunctionCall(
            Callee: TaskWaitFn,
            Args: {Ident, ThreadID, Builder.getInt32(C: Dependencies.size()), DepArray,
             ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
             ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
      }
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
      createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
      CallInst *CI = nullptr;
      if (HasShareds)
        CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
      else
        CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
      Builder.SetInsertPoint(ThenTI);
    }

    if (Dependencies.size()) {
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
      createRuntimeFunctionCall(
          Callee: TaskFn,
          Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
           DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
           ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});

    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
      createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
    }

    // The placeholder call has been fully replaced by runtime calls.
    StaleCI->eraseFromParent();

    Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
    if (HasShareds) {
      // Inside the outlined function, arg 1 is the kmp_task_t*; load the
      // shareds pointer once and redirect all other uses of the argument to
      // it (the load itself must keep using the raw argument).
      LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
      OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
          New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
    }

    // Remove the scaffolding instructions created by createFakeIntVal, in
    // reverse order so uses are deleted before their definitions.
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };

  addOutlineInfo(OI: std::move(OI));
  Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());

  return Builder.saveIP();
}
2711
2712OpenMPIRBuilder::InsertPointOrErrorTy
2713OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2714 InsertPointTy AllocaIP,
2715 BodyGenCallbackTy BodyGenCB) {
2716 if (!updateToLocation(Loc))
2717 return InsertPointTy();
2718
2719 uint32_t SrcLocStrSize;
2720 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2721 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2722 Value *ThreadID = getOrCreateThreadID(Ident);
2723
2724 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2725 Function *TaskgroupFn =
2726 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2727 createRuntimeFunctionCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2728
2729 BasicBlock *TaskgroupExitBB = splitBB(Builder, CreateBranch: true, Name: "taskgroup.exit");
2730 if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2731 return Err;
2732
2733 Builder.SetInsertPoint(TaskgroupExitBB);
2734 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2735 Function *EndTaskgroupFn =
2736 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2737 createRuntimeFunctionCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2738
2739 return Builder.saveIP();
2740}
2741
// Generate IR for an OpenMP `sections` construct: the sections are emitted as
// the cases of a switch inside a canonical loop, which is then lowered as a
// statically scheduled workshare loop.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
  assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Register the finalization so nested cancellation points can find it; it
  // is popped and applied after the workshare loop is built.
  // NOTE(review): PrivCB is not referenced in this body — confirm whether
  // privatization is expected to be handled entirely by the callers.
  FinalizationStack.push_back(Elt: {FiniCB, OMPD_sections, IsCancellable});

  // Each section is emitted as a switch case
  // Each finalization callback is handled from clang.EmitOMPSectionDirective()
  // -> OMP.createSection() which generates the IR for each section
  // Iterate through all sections and emit a switch construct:
  // switch (IV) {
  //   case 0:
  //     <SectionStmt[0]>;
  //     break;
  // ...
  //   case <NumSection> - 1:
  //     <SectionStmt[<NumSection> - 1]>;
  //     break;
  // }
  // ...
  // section_loop.after:
  // <FiniCB>;
  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    BasicBlock *Continue =
        splitBBWithSuffix(Builder, /*CreateBranch=*/false, Suffix: ".sections.after");
    Function *CurFn = Continue->getParent();
    // Default (out-of-range IV) falls through to the continuation block.
    SwitchInst *SwitchStmt = Builder.CreateSwitch(V: IndVar, Dest: Continue);

    unsigned CaseNumber = 0;
    for (auto SectionCB : SectionCBs) {
      // One case block per section; each ends with a branch back to the
      // continuation, and the section body is emitted before that branch.
      BasicBlock *CaseBB = BasicBlock::Create(
          Context&: M.getContext(), Name: "omp_section_loop.body.case", Parent: CurFn, InsertBefore: Continue);
      SwitchStmt->addCase(OnVal: Builder.getInt32(C: CaseNumber), Dest: CaseBB);
      Builder.SetInsertPoint(CaseBB);
      UncondBrInst *CaseEndBr = Builder.CreateBr(Dest: Continue);
      if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
                                                  CaseEndBr->getIterator()}))
        return Err;
      CaseNumber++;
    }
    // remove the existing terminator from body BB since there can be no
    // terminators after switch/case
    return Error::success();
  };
  // Loop body ends here
  // LowerBound, UpperBound, and STride for createCanonicalLoop
  Type *I32Ty = Type::getInt32Ty(C&: M.getContext());
  Value *LB = ConstantInt::get(Ty: I32Ty, V: 0);
  Value *UB = ConstantInt::get(Ty: I32Ty, V: SectionCBs.size());
  Value *ST = ConstantInt::get(Ty: I32Ty, V: 1);
  Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
      Loc, BodyGenCB: LoopBodyGenCB, Start: LB, Stop: UB, Step: ST, IsSigned: true, InclusiveStop: false, ComputeIP: AllocaIP, Name: "section_loop");
  if (!LoopInfo)
    return LoopInfo.takeError();

  // Lower the canonical loop as a static workshare loop; a trailing barrier
  // is emitted unless `nowait` was requested.
  InsertPointOrErrorTy WsloopIP =
      applyStaticWorkshareLoop(DL: Loc.DL, CLI: *LoopInfo, AllocaIP,
                               LoopType: WorksharingLoopType::ForStaticLoop, NeedsBarrier: !IsNowait);
  if (!WsloopIP)
    return WsloopIP.takeError();
  InsertPointTy AfterIP = *WsloopIP;

  BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
  assert(LoopFini && "Bad structure of static workshare loop finalization");

  // Apply the finalization callback in LoopAfterBB
  auto FiniInfo = FinalizationStack.pop_back_val();
  assert(FiniInfo.DK == OMPD_sections &&
         "Unexpected finalization stack state!");
  if (Error Err = FiniInfo.mergeFiniBB(Builder, OtherFiniBB: LoopFini))
    return Err;

  return AfterIP;
}
2822
// Generate IR for a single `section` inside a `sections` construct. The body
// is emitted as an inlined region; the finalization callback is wrapped so it
// also works when invoked from a cancellation block.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // If the insert point is not at the block's end, the block still has a
    // terminator and FiniCB can run unchanged.
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at cancelation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from cancelation
    // to exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    auto *CaseBB = Loc.IP.getBlock();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
    // Re-terminate the cancellation block with a branch to the exit block,
    // then run FiniCB just before that new branch.
    Instruction *I = Builder.CreateBr(Dest: ExitBB);
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  Directive OMPD = Directive::OMPD_sections;
  // Since we are using Finalization Callback here, HasFinalize
  // and IsCancellable have to be true
  return EmitOMPInlinedRegion(OMPD, EntryCall: nullptr, ExitCall: nullptr, BodyGenCB, FiniCB: FiniCBWrapper,
                              /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true,
                              /*IsCancellable*/ true);
}
2857
2858static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2859 BasicBlock::iterator IT(I);
2860 IT++;
2861 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2862}
2863
2864Value *OpenMPIRBuilder::getGPUThreadID() {
2865 return createRuntimeFunctionCall(
2866 Callee: getOrCreateRuntimeFunction(M,
2867 FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block),
2868 Args: {});
2869}
2870
2871Value *OpenMPIRBuilder::getGPUWarpSize() {
2872 return createRuntimeFunctionCall(
2873 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___kmpc_get_warp_size), Args: {});
2874}
2875
2876Value *OpenMPIRBuilder::getNVPTXWarpID() {
2877 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2878 return Builder.CreateAShr(LHS: getGPUThreadID(), RHS: LaneIDBits, Name: "nvptx_warp_id");
2879}
2880
2881Value *OpenMPIRBuilder::getNVPTXLaneID() {
2882 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2883 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2884 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
2885 return Builder.CreateAnd(LHS: getGPUThreadID(), RHS: Builder.getInt32(C: LaneIDMask),
2886 Name: "nvptx_lane_id");
2887}
2888
// Reinterpret \p From as \p ToType, picking the cheapest legal strategy:
// no-op for identical types, bitcast for same-size types, integer cast for
// int-to-int, or a store/load round-trip through a temporary alloca
// otherwise. The alloca is emitted at \p AllocaIP; the store/load stay at
// the current insert point.
Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
                                        Type *ToType) {
  Type *FromType = From->getType();
  uint64_t FromSize = M.getDataLayout().getTypeStoreSize(Ty: FromType);
  uint64_t ToSize = M.getDataLayout().getTypeStoreSize(Ty: ToType);
  assert(FromSize > 0 && "From size must be greater than zero");
  assert(ToSize > 0 && "To size must be greater than zero");
  // Identical types need no work at all.
  if (FromType == ToType)
    return From;
  // Same store size: a bitcast reinterprets the bits in place.
  if (FromSize == ToSize)
    return Builder.CreateBitCast(V: From, DestTy: ToType);
  // Integer widening/narrowing is always done as a signed cast here.
  if (ToType->isIntegerTy() && FromType->isIntegerTy())
    return Builder.CreateIntCast(V: From, DestTy: ToType, /*isSigned*/ true);
  // Mixed kinds of differing size: spill to a temporary sized for the
  // destination type and reload with the destination type.
  InsertPointTy SaveIP = Builder.saveIP();
  Builder.restoreIP(IP: AllocaIP);
  Value *CastItem = Builder.CreateAlloca(Ty: ToType);
  Builder.restoreIP(IP: SaveIP);

  // Cast the alloca to a generic (addrspace 0) pointer before storing.
  Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: CastItem, DestTy: Builder.getPtrTy(AddrSpace: 0));
  Builder.CreateStore(Val: From, Ptr: ValCastItem);
  return Builder.CreateLoad(Ty: ToType, Ptr: CastItem);
}
2912
2913Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2914 Value *Element,
2915 Type *ElementType,
2916 Value *Offset) {
2917 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElementType);
2918 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2919
2920 // Cast all types to 32- or 64-bit values before calling shuffle routines.
2921 Type *CastTy = Builder.getIntNTy(N: Size <= 4 ? 32 : 64);
2922 Value *ElemCast = castValueToType(AllocaIP, From: Element, ToType: CastTy);
2923 Value *WarpSize =
2924 Builder.CreateIntCast(V: getGPUWarpSize(), DestTy: Builder.getInt16Ty(), isSigned: true);
2925 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2926 FnID: Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2927 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2928 Value *WarpSizeCast =
2929 Builder.CreateIntCast(V: WarpSize, DestTy: Builder.getInt16Ty(), /*isSigned=*/true);
2930 Value *ShuffleCall =
2931 createRuntimeFunctionCall(Callee: ShuffleFunc, Args: {ElemCast, Offset, WarpSizeCast});
2932 return castValueToType(AllocaIP, From: ShuffleCall, ToType: CastTy);
2933}
2934
// Shuffle the element at \p SrcAddr across the warp by \p Offset lanes and
// store the result at \p DstAddr, splitting the element into the widest
// integer chunks (8, 4, 2, then 1 bytes) the remaining size allows.
// NOTE(review): ReductionArrayTy and IsByRefElem are not referenced in this
// body — confirm whether they are needed or are leftovers of the interface.
void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
                                      Value *DstAddr, Type *ElemType,
                                      Value *Offset, Type *ReductionArrayTy,
                                      bool IsByRefElem) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElemType);
  // Create the loop over the big sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *ElemPtr = DstAddr;
  Value *Ptr = SrcAddr;
  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
    // Skip chunk widths larger than what remains to be copied.
    if (Size < IntSize)
      continue;
    Type *IntType = Builder.getIntNTy(N: IntSize * 8);
    // Normalize both cursors to generic (addrspace 0) pointers.
    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: Ptr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: Ptr->getName() + ".ascast");
    // One-past-the-end of the source element; used as the loop bound below.
    Value *SrcAddrGEP =
        Builder.CreateGEP(Ty: ElemType, Ptr: SrcAddr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: ElemPtr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: ElemPtr->getName() + ".ascast");

    Function *CurFunc = Builder.GetInsertBlock()->getParent();
    if ((Size / IntSize) > 1) {
      // Multiple chunks of this width: build an IR loop
      // (pre_cond -> then -> pre_cond ... -> exit) that advances the two
      // cursors one chunk per iteration via the PHIs below.
      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: SrcAddrGEP, DestTy: Builder.getPtrTy());
      BasicBlock *PreCondBB =
          BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.pre_cond");
      BasicBlock *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.then");
      BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.exit");
      BasicBlock *CurrentBB = Builder.GetInsertBlock();
      emitBlock(BB: PreCondBB, CurFn: CurFunc);
      PHINode *PhiSrc =
          Builder.CreatePHI(Ty: Ptr->getType(), /*NumReservedValues=*/2);
      PhiSrc->addIncoming(V: Ptr, BB: CurrentBB);
      PHINode *PhiDest =
          Builder.CreatePHI(Ty: ElemPtr->getType(), /*NumReservedValues=*/2);
      PhiDest->addIncoming(V: ElemPtr, BB: CurrentBB);
      Ptr = PhiSrc;
      ElemPtr = PhiDest;
      // Continue while at least one full chunk remains before PtrEnd.
      Value *PtrDiff = Builder.CreatePtrDiff(
          ElemTy: Builder.getInt8Ty(), LHS: PtrEnd,
          RHS: Builder.CreatePointerBitCastOrAddrSpaceCast(V: Ptr, DestTy: Builder.getPtrTy()));
      Builder.CreateCondBr(
          Cond: Builder.CreateICmpSGT(LHS: PtrDiff, RHS: Builder.getInt64(C: IntSize - 1)), True: ThenBB,
          False: ExitBB);
      emitBlock(BB: ThenBB, CurFn: CurFunc);
      // Shuffle one chunk in from the remote lane and store it locally.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP,
          Element: Builder.CreateAlignedLoad(
              Ty: IntType, Ptr, Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType)),
          ElementType: IntType, Offset);
      Builder.CreateAlignedStore(Val: Res, Ptr: ElemPtr,
                                 Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType));
      // Advance both cursors by one chunk and loop back.
      Value *LocalPtr =
          Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      Value *LocalElemPtr =
          Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      PhiSrc->addIncoming(V: LocalPtr, BB: ThenBB);
      PhiDest->addIncoming(V: LocalElemPtr, BB: ThenBB);
      emitBranch(Target: PreCondBB);
      emitBlock(BB: ExitBB, CurFn: CurFunc);
    } else {
      // Exactly one chunk of this width: shuffle and store it straight-line.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP, Element: Builder.CreateLoad(Ty: IntType, Ptr), ElementType: IntType, Offset);
      // Narrow the shuffled value back down for small integer elements.
      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
                                         Res->getType()->getScalarSizeInBits())
        Res = Builder.CreateTrunc(V: Res, DestTy: ElemType);
      Builder.CreateStore(Val: Res, Ptr: ElemPtr);
      Ptr = Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      ElemPtr =
          Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
    }
    // The remainder (if any) is handled by narrower chunk widths on later
    // iterations.
    Size = Size % IntSize;
  }
}
3019
/// Copy one reduce list into another, element by element.
///
/// \p Action selects the mode: for CopyAction::RemoteLaneToThread each
/// element gets a fresh private alloca and its value is shuffled in from a
/// remote lane (offset taken from \p CopyOptions.RemoteLaneOffset); for
/// CopyAction::ThreadCopy the value is copied into the already-existing
/// destination element. \p SrcBase and \p DestBase point to the source and
/// destination reduce-list arrays of type \p ReductionArrayTy. \p IsByRef
/// (possibly empty) marks, per element, reductions whose list slot holds a
/// descriptor rather than the data itself; for those, the per-element
/// DataPtrPtrGen callback is used to reach the actual data, and a fresh
/// descriptor is generated for the destination.
///
/// \returns Error::success(), or the first error propagated from a by-ref
/// callback / descriptor generation.
Error OpenMPIRBuilder::emitReductionListCopy(
    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
    ArrayRef<bool> IsByRef, CopyOptionsTy CopyOptions) {
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;

  // Iterates, element-by-element, through the source Reduce list and
  // make a copy.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *SrcElementAddr = nullptr;
    AllocaInst *DestAlloca = nullptr;
    Value *DestElementAddr = nullptr;
    Value *DestElementPtrAddr = nullptr;
    // Should we shuffle in an element from a remote lane?
    bool ShuffleInElement = false;
    // Set to true to update the pointer in the dest Reduce list to a
    // newly created element.
    bool UpdateDestListPtr = false;

    // Step 1.1: Get the address for the src element in the Reduce list.
    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
        Ty: ReductionArrayTy, Ptr: SrcBase,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    SrcElementAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrAddr);

    // Step 1.2: Create a temporary to store the element in the destination
    // Reduce list.
    DestElementPtrAddr = Builder.CreateInBoundsGEP(
        Ty: ReductionArrayTy, Ptr: DestBase,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    bool IsByRefElem = (!IsByRef.empty() && IsByRef[En.index()]);
    switch (Action) {
    case CopyAction::RemoteLaneToThread: {
      // Materialize a thread-private destination element in the alloca
      // region, address-space-cast to the generic pointer type, then resume
      // code generation where we left off.
      InsertPointTy CurIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      Type *DestAllocaType =
          IsByRefElem ? RI.ByRefAllocatedType : RI.ElementType;
      DestAlloca = Builder.CreateAlloca(Ty: DestAllocaType, ArraySize: nullptr,
                                        Name: ".omp.reduction.element");
      DestAlloca->setAlignment(
          M.getDataLayout().getPrefTypeAlign(Ty: DestAllocaType));
      DestElementAddr = DestAlloca;
      DestElementAddr =
          Builder.CreateAddrSpaceCast(V: DestElementAddr, DestTy: Builder.getPtrTy(),
                                      Name: DestElementAddr->getName() + ".ascast");
      Builder.restoreIP(IP: CurIP);
      ShuffleInElement = true;
      UpdateDestListPtr = true;
      break;
    }
    case CopyAction::ThreadCopy: {
      // The destination element already exists; copy into it in place.
      DestElementAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DestElementPtrAddr);
      break;
    }
    }

    // Now that all active lanes have read the element in the
    // Reduce list, shuffle over the value from the remote lane.
    if (ShuffleInElement) {
      Type *ShuffleType = RI.ElementType;
      Value *ShuffleSrcAddr = SrcElementAddr;
      Value *ShuffleDestAddr = DestElementAddr;
      AllocaInst *LocalStorage = nullptr;

      if (IsByRefElem) {
        assert(RI.ByRefElementType && "Expected by-ref element type to be set");
        assert(RI.ByRefAllocatedType &&
               "Expected by-ref allocated type to be set");
        // For by-ref reductions, we need to copy from the remote lane the
        // actual value of the partial reduction computed by that remote lane;
        // rather than, for example, a pointer to that data or, even worse, a
        // pointer to the descriptor of the by-ref reduction element.
        ShuffleType = RI.ByRefElementType;

        // Navigate from the descriptor to its data-pointer field, then load
        // the actual data address to shuffle from.
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ShuffleSrcAddr, ShuffleSrcAddr);

        if (!GenResult)
          return GenResult.takeError();

        ShuffleSrcAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ShuffleSrcAddr);

        {
          // Thread-private storage that receives the shuffled-in data.
          InsertPointTy OldIP = Builder.saveIP();
          Builder.restoreIP(IP: AllocaIP);

          LocalStorage = Builder.CreateAlloca(Ty: ShuffleType);
          Builder.restoreIP(IP: OldIP);
          ShuffleDestAddr = LocalStorage;
        }
      }

      shuffleAndStore(AllocaIP, SrcAddr: ShuffleSrcAddr, DstAddr: ShuffleDestAddr, ElemType: ShuffleType,
                      Offset: RemoteLaneOffset, ReductionArrayTy, IsByRefElem);

      if (IsByRefElem) {
        // Copy descriptor from source and update base_ptr to shuffled data
        Value *DestDescriptorAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
            V: DestAlloca, DestTy: Builder.getPtrTy(), Name: ".ascast");

        InsertPointOrErrorTy GenResult = generateReductionDescriptor(
            DescriptorAddr: DestDescriptorAddr, DataPtr: LocalStorage, SrcDescriptorAddr: SrcElementAddr,
            DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);

        if (!GenResult)
          return GenResult.takeError();
      }
    } else {
      // Same-thread copy: pick the copy strategy from the element's
      // evaluation kind.
      switch (RI.EvaluationKind) {
      case EvalKind::Scalar: {
        Value *Elem = Builder.CreateLoad(Ty: RI.ElementType, Ptr: SrcElementAddr);
        // Store the source element value to the dest element address.
        Builder.CreateStore(Val: Elem, Ptr: DestElementAddr);
        break;
      }
      case EvalKind::Complex: {
        // Copy real and imaginary parts field-by-field.
        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
        Value *SrcReal = Builder.CreateLoad(
            Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
        Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
        Value *SrcImg = Builder.CreateLoad(
            Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");

        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
        Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
        Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
        Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
        break;
      }
      case EvalKind::Aggregate: {
        // Bulk-copy the whole aggregate.
        Value *SizeVal = Builder.getInt64(
            C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
        Builder.CreateMemCpy(
            Dst: DestElementAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
            Src: SrcElementAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
            Size: SizeVal, isVolatile: false);
        break;
      }
      };
    }

    // Step 3.1: Modify reference in dest Reduce list as needed.
    // Modifying the reference in Reduce list to point to the newly
    // created element. The element is live in the current function
    // scope and that of functions it invokes (i.e., reduce_function).
    // RemoteReduceData[i] = (void*)&RemoteElem
    if (UpdateDestListPtr) {
      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: DestElementAddr, DestTy: Builder.getPtrTy(),
          Name: DestElementAddr->getName() + ".ascast");
      Builder.CreateStore(Val: CastDestAddr, Ptr: DestElementPtrAddr);
    }
  }

  return Error::success();
}
3185
/// Emit the "_omp_reduction_inter_warp_copy_func" helper.
///
/// The generated function takes the thread-local reduce list and the number
/// of active warps. For every reduce element it copies the partial value
/// from lane 0 of each warp into a shared (addrspace 3) transfer-medium
/// array, then lets the first NumWarps threads (warp 0) read the values back
/// into their own reduce list. Elements are moved in 4-, 2-, then 1-byte
/// chunks so arbitrary element sizes are handled; kmpc barriers separate the
/// write and read phases. For by-ref elements the per-element DataPtrPtrGen
/// callback is used to reach the data behind the descriptor.
///
/// \returns the new function, or an Error propagated from barrier emission
/// or a by-ref callback.
Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  InsertPointTy SavedIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(), Params: {Builder.getPtrTy(), Builder.getInt32Ty()},
      /* IsVarArg */ isVarArg: false);
  Function *WcFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_inter_warp_copy_func", M: &M);
  WcFunc->setAttributes(FuncAttrs);
  WcFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  WcFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: WcFunc);
  Builder.SetInsertPoint(EntryBB);

  // ReduceList: thread local Reduce list.
  // At the stage of the computation when this function is called, partially
  // aggregated values reside in the first lane of every active warp.
  Argument *ReduceListArg = WcFunc->getArg(i: 0);
  // NumWarps: number of warps active in the parallel region. This could
  // be smaller than 32 (max warps in a CTA) for partial block reduction.
  Argument *NumWarpsArg = WcFunc->getArg(i: 1);

  // This array is used as a medium to transfer, one reduce element at a time,
  // the data from the first lane of every warp to lanes in the first warp
  // in order to perform the final step of a reduction in a parallel region
  // (reduction across warps). The array is placed in NVPTX __shared__ memory
  // for reduced latency, as well as to have a distinct copy for concurrently
  // executing target regions. The array is declared with common linkage so
  // as to be shared across compilation units.
  StringRef TransferMediumName =
      "__openmp_nvptx_data_transfer_temporary_storage";
  GlobalVariable *TransferMedium = M.getGlobalVariable(Name: TransferMediumName);
  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
  ArrayType *ArrayTy = ArrayType::get(ElementType: Builder.getInt32Ty(), NumElements: WarpSize);
  if (!TransferMedium) {
    TransferMedium = new GlobalVariable(
        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
        UndefValue::get(T: ArrayTy), TransferMediumName,
        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
        /*AddressSpace=*/3);
  }

  // Get the CUDA thread id of the current OpenMP thread on the GPU.
  Value *GPUThreadID = getGPUThreadID();
  // nvptx_lane_id = nvptx_id % warpsize
  Value *LaneID = getNVPTXLaneID();
  // nvptx_warp_id = nvptx_id / warpsize
  Value *WarpID = getNVPTXWarpID();

  // Spill the two arguments to allocas at the top of entry; code generation
  // resumes after the last alloca.
  InsertPointTy AllocaIP =
      InsertPointTy(Builder.GetInsertBlock(),
                    Builder.GetInsertBlock()->getFirstInsertionPt());
  Type *Arg0Type = ReduceListArg->getType();
  Type *Arg1Type = NumWarpsArg->getType();
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
      Ty: Arg0Type, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  AllocaInst *NumWarpsAlloca =
      Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: NumWarpsArg->getName() + ".addr");
  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListAlloca, DestTy: Arg0Type, Name: ReduceListAlloca->getName() + ".ascast");
  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: NumWarpsAlloca, DestTy: Builder.getPtrTy(AddrSpace: 0),
      Name: NumWarpsAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
  Builder.CreateStore(Val: NumWarpsArg, Ptr: NumWarpsAddrCast);
  AllocaIP = getInsertPointAfterInstr(I: NumWarpsAlloca);
  InsertPointTy CodeGenIP =
      getInsertPointAfterInstr(I: &Builder.GetInsertBlock()->back());
  Builder.restoreIP(IP: CodeGenIP);

  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListAddrCast);

  for (auto En : enumerate(First&: ReductionInfos)) {
    //
    // Warp master copies reduce element to transfer medium in __shared__
    // memory.
    //
    const ReductionInfo &RI = En.value();
    bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(
        Ty: IsByRefElem ? RI.ByRefElementType : RI.ElementType);
    // Copy the element in power-of-two chunks (4, 2, then 1 bytes) so any
    // remaining tail size is eventually covered.
    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
      Type *CType = Builder.getIntNTy(N: TySize * 8);

      unsigned NumIters = RealTySize / TySize;
      // Skip chunk sizes larger than what remains to be copied.
      if (NumIters == 0)
        continue;
      Value *Cnt = nullptr;
      Value *CntAddr = nullptr;
      BasicBlock *PrecondBB = nullptr;
      BasicBlock *ExitBB = nullptr;
      // More than one chunk of this size: emit a counted loop around the
      // copy (counter lives in an alloca, initialized to zero).
      if (NumIters > 1) {
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(IP: AllocaIP);
        CntAddr =
            Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr, Name: ".cnt.addr");

        CntAddr = Builder.CreateAddrSpaceCast(V: CntAddr, DestTy: Builder.getPtrTy(),
                                              Name: CntAddr->getName() + ".ascast");
        Builder.restoreIP(IP: CodeGenIP);
        Builder.CreateStore(Val: Constant::getNullValue(Ty: Builder.getInt32Ty()),
                            Ptr: CntAddr,
                            /*Volatile=*/isVolatile: false);
        PrecondBB = BasicBlock::Create(Context&: Ctx, Name: "precond");
        ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit");
        BasicBlock *BodyBB = BasicBlock::Create(Context&: Ctx, Name: "body");
        emitBlock(BB: PrecondBB, CurFn: Builder.GetInsertBlock()->getParent());
        Cnt = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: CntAddr,
                                 /*Volatile=*/isVolatile: false);
        Value *Cmp = Builder.CreateICmpULT(
            LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), V: NumIters));
        Builder.CreateCondBr(Cond: Cmp, True: BodyBB, False: ExitBB);
        emitBlock(BB: BodyBB, CurFn: Builder.GetInsertBlock()->getParent());
      }

      // kmpc_barrier.
      InsertPointOrErrorTy BarrierIP1 =
          createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        Kind: omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP1)
        return BarrierIP1.takeError();
      BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
      BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
      BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

      // if (lane_id == 0)
      Value *IsWarpMaster = Builder.CreateIsNull(Arg: LaneID, Name: "warp_master");
      Builder.CreateCondBr(Cond: IsWarpMaster, True: ThenBB, False: ElseBB);
      emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());

      // Reduce element = LocalReduceList[i]
      auto *RedListArrayTy =
          ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
      Type *IndexTy = Builder.getIndexTy(
          DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
      Value *ElemPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                              ConstantInt::get(Ty: IndexTy, V: En.index())});
      // elemptr = ((CopyType*)(elemptrptr)) + I
      Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);

      // For by-ref elements, follow the descriptor to the actual data.
      if (IsByRefElem) {
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
      }

      if (NumIters > 1)
        ElemPtr = Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: ElemPtr, IdxList: Cnt);

      // Get pointer to location in transfer medium.
      // MediumPtr = &medium[warp_id]
      Value *MediumPtr = Builder.CreateInBoundsGEP(
          Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), WarpID});
      // elem = *elemptr
      //*MediumPtr = elem
      Value *Elem = Builder.CreateLoad(Ty: CType, Ptr: ElemPtr);
      // Store the source element value to the dest element address.
      Builder.CreateStore(Val: Elem, Ptr: MediumPtr,
                          /*IsVolatile*/ isVolatile: true);
      Builder.CreateBr(Dest: MergeBB);

      // else
      emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(Dest: MergeBB);

      // endif
      emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
      InsertPointOrErrorTy BarrierIP2 =
          createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        Kind: omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP2)
        return BarrierIP2.takeError();

      // Warp 0 copies reduce element from transfer medium
      BasicBlock *W0ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
      BasicBlock *W0ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
      BasicBlock *W0MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

      Value *NumWarpsVal =
          Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: NumWarpsAddrCast);
      // Up to 32 threads in warp 0 are active.
      Value *IsActiveThread =
          Builder.CreateICmpULT(LHS: GPUThreadID, RHS: NumWarpsVal, Name: "is_active_thread");
      Builder.CreateCondBr(Cond: IsActiveThread, True: W0ThenBB, False: W0ElseBB);

      emitBlock(BB: W0ThenBB, CurFn: Builder.GetInsertBlock()->getParent());

      // SecMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
          Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), GPUThreadID});
      // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
      Value *TargetElemPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                              ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *TargetElemPtrVal =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtrPtr);
      Value *TargetElemPtr = TargetElemPtrVal;

      // For by-ref elements, follow the descriptor to the actual data.
      if (IsByRefElem) {
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), TargetElemPtr, TargetElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        TargetElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtr);
      }

      if (NumIters > 1)
        TargetElemPtr =
            Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: TargetElemPtr, IdxList: Cnt);

      // *TargetElemPtr = SrcMediumVal;
      Value *SrcMediumValue =
          Builder.CreateLoad(Ty: CType, Ptr: SrcMediumPtrVal, /*IsVolatile*/ isVolatile: true);
      Builder.CreateStore(Val: SrcMediumValue, Ptr: TargetElemPtr);
      Builder.CreateBr(Dest: W0MergeBB);

      emitBlock(BB: W0ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(Dest: W0MergeBB);

      emitBlock(BB: W0MergeBB, CurFn: Builder.GetInsertBlock()->getParent());

      // Increment the chunk counter and loop back to the precondition.
      if (NumIters > 1) {
        Cnt = Builder.CreateNSWAdd(
            LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), /*V=*/1));
        Builder.CreateStore(Val: Cnt, Ptr: CntAddr, /*Volatile=*/isVolatile: false);

        auto *CurFn = Builder.GetInsertBlock()->getParent();
        emitBranch(Target: PrecondBB);
        emitBlock(BB: ExitBB, CurFn);
      }
      // The tail that did not fit this chunk size is handled by the next,
      // smaller chunk size.
      RealTySize %= TySize;
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: SavedIP);

  return WcFunc;
}
3444
/// Emit the "_omp_reduction_shuffle_and_reduce_func" helper.
///
/// The emitted function takes the thread-local reduce list, the (logical)
/// lane id, the offset of the remote source lane, and the algorithm version
/// (expected to be a compile-time constant 0, 1 or 2). It copies each
/// element from the remote lane into a stack-hosted remote reduce list,
/// conditionally combines that list into the local one by calling
/// \p ReduceFn, and — for AlgoVer==1 lanes with LaneId >= Offset — copies
/// the remote list back over the local one. \p IsByRef marks elements whose
/// list slot holds a descriptor rather than the data itself.
///
/// \returns the new function, or an Error propagated from the reduction
/// list copies.
Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy =
      FunctionType::get(Result: Builder.getVoidTy(),
                        Params: {Builder.getPtrTy(), Builder.getInt16Ty(),
                                Builder.getInt16Ty(), Builder.getInt16Ty()},
                        /* IsVarArg */ isVarArg: false);
  Function *SarFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_shuffle_and_reduce_func", M: &M);
  SarFunc->setAttributes(FuncAttrs);
  SarFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::SExt);
  SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::SExt);
  SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::SExt);
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: SarFunc);
  Builder.SetInsertPoint(EntryBB);

  // Thread local Reduce list used to host the values of data to be reduced.
  Argument *ReduceListArg = SarFunc->getArg(i: 0);
  // Current lane id; could be logical.
  Argument *LaneIDArg = SarFunc->getArg(i: 1);
  // Offset of the remote source lane relative to the current lane.
  Argument *RemoteLaneOffsetArg = SarFunc->getArg(i: 2);
  // Algorithm version. This is expected to be known at compile time.
  Argument *AlgoVerArg = SarFunc->getArg(i: 3);

  // Spill all four arguments to allocas so their addresses can be taken.
  Type *ReduceListArgType = ReduceListArg->getType();
  Type *LaneIDArgType = LaneIDArg->getType();
  Type *LaneIDArgPtrType = Builder.getPtrTy(AddrSpace: 0);
  Value *ReduceListAlloca = Builder.CreateAlloca(
      Ty: ReduceListArgType, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  Value *LaneIdAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
                                             Name: LaneIDArg->getName() + ".addr");
  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
      Ty: LaneIDArgType, ArraySize: nullptr, Name: RemoteLaneOffsetArg->getName() + ".addr");
  Value *AlgoVerAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
                                              Name: AlgoVerArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // Create a local thread-private variable to host the Reduce list
  // from a remote lane.
  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
      Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.remote_reduce_list");

  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListAlloca, DestTy: ReduceListArgType,
      Name: ReduceListAlloca->getName() + ".ascast");
  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LaneIdAlloca, DestTy: LaneIDArgPtrType, Name: LaneIdAlloca->getName() + ".ascast");
  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteLaneOffsetAlloca, DestTy: LaneIDArgPtrType,
      Name: RemoteLaneOffsetAlloca->getName() + ".ascast");
  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: AlgoVerAlloca, DestTy: LaneIDArgPtrType, Name: AlgoVerAlloca->getName() + ".ascast");
  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteReductionListAlloca, DestTy: Builder.getPtrTy(),
      Name: RemoteReductionListAlloca->getName() + ".ascast");

  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
  Builder.CreateStore(Val: LaneIDArg, Ptr: LaneIdAddrCast);
  Builder.CreateStore(Val: RemoteLaneOffsetArg, Ptr: RemoteLaneOffsetAddrCast);
  Builder.CreateStore(Val: AlgoVerArg, Ptr: AlgoVerAddrCast);

  Value *ReduceList = Builder.CreateLoad(Ty: ReduceListArgType, Ptr: ReduceListAddrCast);
  Value *LaneId = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: LaneIdAddrCast);
  Value *RemoteLaneOffset =
      Builder.CreateLoad(Ty: LaneIDArgType, Ptr: RemoteLaneOffsetAddrCast);
  Value *AlgoVer = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: AlgoVerAddrCast);

  // New allocas for per-element temporaries go right after the remote list.
  InsertPointTy AllocaIP = getInsertPointAfterInstr(I: RemoteReductionListAlloca);

  // This loop iterates through the list of reduce elements and copies,
  // element by element, from a remote lane in the warp to RemoteReduceList,
  // hosted on the thread's stack.
  Error EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, Action: CopyAction::RemoteLaneToThread, ReductionArrayTy: RedListArrayTy, ReductionInfos,
      SrcBase: ReduceList, DestBase: RemoteListAddrCast, IsByRef,
      CopyOptions: {.RemoteLaneOffset: RemoteLaneOffset, .ScratchpadIndex: nullptr, .ScratchpadWidth: nullptr});

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  // The actions to be performed on the Remote Reduce list is dependent
  // on the algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread local variable Reduce list is mutated in place to host the
  // reduced data, which is the aggregated value produced from local and
  // remote lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first conjunction evaluates to true, making
  // the entire predicate true during compile time.
  // When AlgoVer==1, the second conjunction has only the second part to be
  // evaluated during runtime. Other conjunctions evaluates to false
  // during compile time.
  // When AlgoVer==2, the third conjunction has only the second part to be
  // evaluated during runtime. Other conjunctions evaluates to false
  // during compile time.
  Value *CondAlgo0 = Builder.CreateIsNull(Arg: AlgoVer);
  Value *Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
  Value *LaneComp = Builder.CreateICmpULT(LHS: LaneId, RHS: RemoteLaneOffset);
  Value *CondAlgo1 = Builder.CreateAnd(LHS: Algo1, RHS: LaneComp);
  Value *Algo2 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 2));
  Value *LaneIdAnd1 = Builder.CreateAnd(LHS: LaneId, RHS: Builder.getInt16(C: 1));
  Value *LaneIdComp = Builder.CreateIsNull(Arg: LaneIdAnd1);
  Value *Algo2AndLaneIdComp = Builder.CreateAnd(LHS: Algo2, RHS: LaneIdComp);
  Value *RemoteOffsetComp =
      Builder.CreateICmpSGT(LHS: RemoteLaneOffset, RHS: Builder.getInt16(C: 0));
  Value *CondAlgo2 = Builder.CreateAnd(LHS: Algo2AndLaneIdComp, RHS: RemoteOffsetComp);
  Value *CA0OrCA1 = Builder.CreateOr(LHS: CondAlgo0, RHS: CondAlgo1);
  Value *CondReduce = Builder.CreateOr(LHS: CA0OrCA1, RHS: CondAlgo2);

  BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
  BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
  BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

  // Combine the remote list into the local one via the reduce function.
  Builder.CreateCondBr(Cond: CondReduce, True: ThenBB, False: ElseBB);
  emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceList, DestTy: Builder.getPtrTy());
  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteListAddrCast, DestTy: Builder.getPtrTy());
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListPtr, RemoteReduceListPtr})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateBr(Dest: MergeBB);

  emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(Dest: MergeBB);

  emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());

  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
  // Reduce list.
  Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LHS: LaneId, RHS: RemoteLaneOffset);
  Value *CondCopy = Builder.CreateAnd(LHS: Algo1, RHS: LaneIdGtOffset);

  BasicBlock *CpyThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
  BasicBlock *CpyElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
  BasicBlock *CpyMergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
  Builder.CreateCondBr(Cond: CondCopy, True: CpyThenBB, False: CpyElseBB);

  emitBlock(BB: CpyThenBB, CurFn: Builder.GetInsertBlock()->getParent());

  EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, Action: CopyAction::ThreadCopy, ReductionArrayTy: RedListArrayTy, ReductionInfos,
      SrcBase: RemoteListAddrCast, DestBase: ReduceList, IsByRef);

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  Builder.CreateBr(Dest: CpyMergeBB);

  emitBlock(BB: CpyElseBB, CurFn: Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(Dest: CpyMergeBB);

  emitBlock(BB: CpyMergeBB, CurFn: Builder.GetInsertBlock()->getParent());

  Builder.CreateRetVoid();

  return SarFunc;
}
3619
3620OpenMPIRBuilder::InsertPointOrErrorTy
3621OpenMPIRBuilder::generateReductionDescriptor(
3622 Value *DescriptorAddr, Value *DataPtr, Value *SrcDescriptorAddr,
3623 Type *DescriptorType,
3624 function_ref<InsertPointOrErrorTy(InsertPointTy, Value *, Value *&)>
3625 DataPtrPtrGen) {
3626
3627 // Copy the source descriptor to preserve all metadata (rank, extents,
3628 // strides, etc.)
3629 Value *DescriptorSize =
3630 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: DescriptorType));
3631 Builder.CreateMemCpy(
3632 Dst: DescriptorAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
3633 Src: SrcDescriptorAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
3634 Size: DescriptorSize);
3635
3636 // Update the base pointer field to point to the local shuffled data
3637 Value *DataPtrField;
3638 InsertPointOrErrorTy GenResult =
3639 DataPtrPtrGen(Builder.saveIP(), DescriptorAddr, DataPtrField);
3640
3641 if (!GenResult)
3642 return GenResult.takeError();
3643
3644 Builder.CreateStore(Val: Builder.CreatePointerBitCastOrAddrSpaceCast(
3645 V: DataPtr, DestTy: Builder.getPtrTy(), Name: ".ascast"),
3646 Ptr: DataPtrField);
3647
3648 return Builder.saveIP();
3649}
3650
3651Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
3652 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3653 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3654 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3655 LLVMContext &Ctx = M.getContext();
3656 FunctionType *FuncTy = FunctionType::get(
3657 Result: Builder.getVoidTy(),
3658 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3659 /* IsVarArg */ isVarArg: false);
3660 Function *LtGCFunc =
3661 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3662 N: "_omp_reduction_list_to_global_copy_func", M: &M);
3663 LtGCFunc->setAttributes(FuncAttrs);
3664 LtGCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3665 LtGCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3666 LtGCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3667
3668 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGCFunc);
3669 Builder.SetInsertPoint(EntryBlock);
3670
3671 // Buffer: global reduction buffer.
3672 Argument *BufferArg = LtGCFunc->getArg(i: 0);
3673 // Idx: index of the buffer.
3674 Argument *IdxArg = LtGCFunc->getArg(i: 1);
3675 // ReduceList: thread local Reduce list.
3676 Argument *ReduceListArg = LtGCFunc->getArg(i: 2);
3677
3678 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3679 Name: BufferArg->getName() + ".addr");
3680 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3681 Name: IdxArg->getName() + ".addr");
3682 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3683 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3684 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3685 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3686 Name: BufferArgAlloca->getName() + ".ascast");
3687 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3688 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3689 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3690 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3691 Name: ReduceListArgAlloca->getName() + ".ascast");
3692
3693 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3694 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3695 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3696
3697 Value *LocalReduceList =
3698 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3699 Value *BufferArgVal =
3700 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3701 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3702 Type *IndexTy = Builder.getIndexTy(
3703 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3704 for (auto En : enumerate(First&: ReductionInfos)) {
3705 const ReductionInfo &RI = En.value();
3706 auto *RedListArrayTy =
3707 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3708 // Reduce element = LocalReduceList[i]
3709 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3710 Ty: RedListArrayTy, Ptr: LocalReduceList,
3711 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3712 // elemptr = ((CopyType*)(elemptrptr)) + I
3713 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3714
3715 // Global = Buffer.VD[Idx];
3716 Value *BufferVD =
3717 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferArgVal, IdxList: Idxs);
3718 Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3719 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3720
3721 switch (RI.EvaluationKind) {
3722 case EvalKind::Scalar: {
3723 Value *TargetElement;
3724
3725 if (IsByRef.empty() || !IsByRef[En.index()]) {
3726 TargetElement = Builder.CreateLoad(Ty: RI.ElementType, Ptr: ElemPtr);
3727 } else {
3728 InsertPointOrErrorTy GenResult =
3729 RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
3730
3731 if (!GenResult)
3732 return GenResult.takeError();
3733
3734 ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
3735 TargetElement = Builder.CreateLoad(Ty: RI.ByRefElementType, Ptr: ElemPtr);
3736 }
3737
3738 Builder.CreateStore(Val: TargetElement, Ptr: GlobVal);
3739 break;
3740 }
3741 case EvalKind::Complex: {
3742 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3743 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
3744 Value *SrcReal = Builder.CreateLoad(
3745 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
3746 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3747 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
3748 Value *SrcImg = Builder.CreateLoad(
3749 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
3750
3751 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3752 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 0, Name: ".realp");
3753 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3754 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 1, Name: ".imagp");
3755 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
3756 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
3757 break;
3758 }
3759 case EvalKind::Aggregate: {
3760 Value *SizeVal =
3761 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
3762 Builder.CreateMemCpy(
3763 Dst: GlobVal, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Src: ElemPtr,
3764 SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Size: SizeVal, isVolatile: false);
3765 break;
3766 }
3767 }
3768 }
3769
3770 Builder.CreateRetVoid();
3771 Builder.restoreIP(IP: OldIP);
3772 return LtGCFunc;
3773}
3774
// Emit the "_omp_reduction_list_to_global_reduce_func" helper used by GPU
// reduction lowering. The generated function has the signature
// void(ptr Buffer, i32 Idx, ptr ReduceList): it builds a temporary reduce
// list whose slots point at the elements of Buffer.VD[Idx] (for by-ref
// reductions a descriptor copy is materialized whose data pointer is
// redirected into the global buffer) and then calls \p ReduceFn to combine
// the thread-local ReduceList into the global buffer.
Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // Remember the caller's insertion point; it is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // void(ptr, i32, ptr) — Buffer, Idx, ReduceList.
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *LtGRFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_list_to_global_reduce_func", M: &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  LtGRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  LtGRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(i: 2);

  // Spill the incoming arguments to stack slots (mirrors clang's -O0 codegen
  // for these runtime helpers).
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  auto *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");

  // Insertion point at the top of the entry block: the by-ref descriptor
  // allocas created in the loop below are emitted here so all allocas stay in
  // the entry block.
  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LocalReduceList, DestTy: Builder.getPtrTy(),
      Name: LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Populate LocalReduceList[i] with a pointer to the i-th global buffer
  // element (or to a freshly built descriptor for by-ref reductions).
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    // Only initialized and used on the by-ref path below.
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Emit the descriptor alloca at the top of the entry block, then resume
      // emitting at the current position.
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(IP: OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: LocalReduceListAddrCast,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    // Global = Buffer.VD[Idx];
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Get source descriptor from the reduce list argument
      Value *ReduceList =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
      Value *SrcElementPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *SrcDescriptorAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);

      // Copy descriptor from source and update base_ptr to global buffer data
      InsertPointOrErrorTy GenResult =
          generateReductionDescriptor(DescriptorAddr: ByRefAlloc, DataPtr: GlobValPtr, SrcDescriptorAddr,
                                      DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);

      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
    } else {
      Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
    }
  }

  // Call reduce_function(GlobalReduceList, ReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListAddrCast, ReduceList})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return LtGRFunc;
}
3896
// Emit the "_omp_reduction_global_to_list_copy_func" helper used by GPU
// reduction lowering. The generated function has the signature
// void(ptr Buffer, i32 Idx, ptr ReduceList) and copies each reduction
// element from the global buffer slot Buffer.VD[Idx] back into the
// thread-local reduce list, dispatching on the element's evaluation kind
// (scalar / complex / aggregate). For by-ref scalars the destination pointer
// is obtained through the reduction's data-pointer generator callback.
Expected<Function *> OpenMPIRBuilder::emitGlobalToListCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // Remember the caller's insertion point; it is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // void(ptr, i32, ptr) — Buffer, Idx, ReduceList.
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *GtLCFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_global_to_list_copy_func", M: &M);
  GtLCFunc->setAttributes(FuncAttrs);
  GtLCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  GtLCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  GtLCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLCFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLCFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLCFunc->getArg(i: 2);

  // Spill the incoming arguments to stack slots (mirrors clang's -O0 codegen
  // for these runtime helpers).
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Copy every reduction element global-buffer -> thread-local list.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: LocalReduceList,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Type *ElemType = RI.ElementType;

      if (!IsByRef.empty() && IsByRef[En.index()]) {
        // For by-ref reductions, follow the descriptor's data pointer to find
        // the actual destination storage.
        ElemType = RI.ByRefElementType;
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenResult)
          return GenResult.takeError();

        ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
      }

      Value *TargetElement = Builder.CreateLoad(Ty: ElemType, Ptr: GlobValPtr);
      Builder.CreateStore(Val: TargetElement, Ptr: ElemPtr);
      break;
    }
    case EvalKind::Complex: {
      // Copy real and imaginary parts field-by-field.
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 0, Name: ".realp");
      Value *SrcReal = Builder.CreateLoad(
          Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
      Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
      Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      // Bulk copy the whole aggregate.
      Value *SizeVal =
          Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
      Builder.CreateMemCpy(
          Dst: ElemPtr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
          Src: GlobValPtr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
          Size: SizeVal, isVolatile: false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return GtLCFunc;
}
4017
// Emit the "_omp_reduction_global_to_list_reduce_func" helper used by GPU
// reduction lowering. The generated function has the signature
// void(ptr Buffer, i32 Idx, ptr ReduceList): it builds a temporary reduce
// list whose slots point at the elements of Buffer.VD[Idx] (materializing
// descriptor copies for by-ref reductions) and then calls \p ReduceFn as
// reduce_function(ReduceList, GlobalReduceList), i.e. reducing the global
// buffer values into the thread-local list. This is the mirror of
// emitListToGlobalReduceFunction, with the operand order of the final call
// swapped.
Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // Remember the caller's insertion point; it is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  // void(ptr, i32, ptr) — Buffer, Idx, ReduceList.
  auto *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *GtLRFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_global_to_list_reduce_func", M: &M);
  GtLRFunc->setAttributes(FuncAttrs);
  GtLRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  GtLRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  GtLRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLRFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLRFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLRFunc->getArg(i: 2);

  // Spill the incoming arguments to stack slots (mirrors clang's -O0 codegen
  // for these runtime helpers).
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");

  // Insertion point at the top of the entry block: the by-ref descriptor
  // allocas created in the loop below are emitted here so all allocas stay in
  // the entry block.
  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LocalReduceList, DestTy: Builder.getPtrTy(),
      Name: LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Populate ReductionList[i] with a pointer to the i-th global buffer
  // element (or to a freshly built descriptor for by-ref reductions).
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    // Only initialized and used on the by-ref path below.
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Emit the descriptor alloca at the top of the entry block, then resume
      // emitting at the current position.
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(IP: OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: ReductionList,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Get source descriptor from the reduce list
      Value *ReduceListVal =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
      Value *SrcElementPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceListVal,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *SrcDescriptorAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);

      // Copy descriptor from source and update base_ptr to global buffer data
      InsertPointOrErrorTy GenResult =
          generateReductionDescriptor(DescriptorAddr: ByRefAlloc, DataPtr: GlobValPtr, SrcDescriptorAddr,
                                      DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);
      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
    } else {
      Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
    }
  }

  // Call reduce_function(ReduceList, GlobalReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {ReduceList, ReductionList})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return GtLRFunc;
}
4138
4139std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
4140 std::string Suffix =
4141 createPlatformSpecificName(Parts: {"omp", "reduction", "reduction_func"});
4142 return (Name + Suffix).str();
4143}
4144
// Create the per-construct reduction function void(ptr LHSList, ptr RHSList)
// that combines the elements of two reduce lists pairwise. For the Clang
// callback flavor (ReductionGenCBKind::Clang) the callback emits the combiner
// bodies and placeholder pointers are patched afterwards; otherwise each
// ReductionInfo's ReductionGen callback is invoked inline (loading/storing
// scalar values for by-value reductions, passing pointers for by-ref ones).
Expected<Function *> OpenMPIRBuilder::createReductionFunction(
    StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
    ArrayRef<bool> IsByRef, ReductionGenCBKind ReductionGenCBKind,
    AttributeList FuncAttrs) {
  // void(ptr, ptr) — LHS reduce list, RHS reduce list.
  auto *FuncTy = FunctionType::get(Result: Builder.getVoidTy(),
                                   Params: {Builder.getPtrTy(), Builder.getPtrTy()},
                                   /* IsVarArg */ isVarArg: false);
  std::string Name = getReductionFuncName(Name: ReducerName);
  Function *ReductionFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage, N: Name, M: &M);
  ReductionFunc->setAttributes(FuncAttrs);
  ReductionFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  ReductionFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  BasicBlock *EntryBB =
      BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: ReductionFunc);
  Builder.SetInsertPoint(EntryBB);

  // Need to alloca memory here and deal with the pointers before getting
  // LHS/RHS pointers out
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  Argument *Arg0 = ReductionFunc->getArg(i: 0);
  Argument *Arg1 = ReductionFunc->getArg(i: 1);
  Type *Arg0Type = Arg0->getType();
  Type *Arg1Type = Arg1->getType();

  Value *LHSAlloca =
      Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
  Value *RHSAlloca =
      Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
  Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LHSAlloca, DestTy: Arg0Type, Name: LHSAlloca->getName() + ".ascast");
  Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RHSAlloca, DestTy: Arg1Type, Name: RHSAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
  Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
  LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
  RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);

  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  SmallVector<Value *> LHSPtrs, RHSPtrs;
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    // Extract the i-th element pointer from each reduce list and cast it to
    // the type the reduction variable expects.
    Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedArrayTy, Ptr: RHSArrayPtr,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType(),
        Name: RHSI8Ptr->getName() + ".ascast");

    Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedArrayTy, Ptr: LHSArrayPtr,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: LHSI8Ptr, DestTy: RI.Variable->getType(), Name: LHSI8Ptr->getName() + ".ascast");

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      // Defer to the post-loop fixup pass; just remember the pointers.
      LHSPtrs.emplace_back(Args&: LHSPtr);
      RHSPtrs.emplace_back(Args&: RHSPtr);
    } else {
      Value *LHS = LHSPtr;
      Value *RHS = RHSPtr;

      // By-value reductions combine loaded scalars; by-ref reductions receive
      // the pointers themselves.
      if (!IsByRef.empty() && !IsByRef[En.index()]) {
        LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
        RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
      }

      Value *Reduced;
      InsertPointOrErrorTy AfterIP =
          RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
      if (!AfterIP)
        return AfterIP.takeError();
      // The callback may clear the insertion point to signal early bailout.
      if (!Builder.GetInsertBlock())
        return ReductionFunc;

      Builder.restoreIP(IP: *AfterIP);

      if (!IsByRef.empty() && !IsByRef[En.index()])
        Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
    }
  }

  if (ReductionGenCBKind == ReductionGenCBKind::Clang)
    for (auto En : enumerate(First&: ReductionInfos)) {
      unsigned Index = En.index();
      const ReductionInfo &RI = En.value();
      Value *LHSFixupPtr, *RHSFixupPtr;
      Builder.restoreIP(IP: RI.ReductionGenClang(
          Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));

      // Fix the CallBack code generated to use the correct Values for the LHS
      // and RHS; only rewrite uses inside ReductionFunc itself.
      LHSFixupPtr->replaceUsesWithIf(
          New: LHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
            return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
      RHSFixupPtr->replaceUsesWithIf(
          New: RHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
            return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
    }

  Builder.CreateRetVoid();
  // Compiling with `-O0`, `alloca`s emitted in non-entry blocks are not hoisted
  // to the entry block (this is done for higher opt levels by later passes in
  // the pipeline). This has caused issues because non-entry `alloca`s force the
  // function to use dynamic stack allocations and we might run out of scratch
  // memory.
  hoistNonEntryAllocasToEntryBlock(Func: ReductionFunc);

  return ReductionFunc;
}
4264
4265static void
4266checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4267 bool IsGPU) {
4268 for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
4269 (void)RI;
4270 assert(RI.Variable && "expected non-null variable");
4271 assert(RI.PrivateVariable && "expected non-null private variable");
4272 assert((RI.ReductionGen || RI.ReductionGenClang) &&
4273 "expected non-null reduction generator callback");
4274 if (!IsGPU) {
4275 assert(
4276 RI.Variable->getType() == RI.PrivateVariable->getType() &&
4277 "expected variables and their private equivalents to have the same "
4278 "type");
4279 }
4280 assert(RI.Variable->getType()->isPointerTy() &&
4281 "expected variables to be pointers");
4282 }
4283}
4284
4285OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
4286 const LocationDescription &Loc, InsertPointTy AllocaIP,
4287 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
4288 ArrayRef<bool> IsByRef, bool IsNoWait, bool IsTeamsReduction,
4289 ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
4290 unsigned ReductionBufNum, Value *SrcLocInfo) {
4291 if (!updateToLocation(Loc))
4292 return InsertPointTy();
4293 Builder.restoreIP(IP: CodeGenIP);
4294 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
4295 LLVMContext &Ctx = M.getContext();
4296
4297 // Source location for the ident struct
4298 if (!SrcLocInfo) {
4299 uint32_t SrcLocStrSize;
4300 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4301 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4302 }
4303
4304 if (ReductionInfos.size() == 0)
4305 return Builder.saveIP();
4306
4307 BasicBlock *ContinuationBlock = nullptr;
4308 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
4309 // Copied code from createReductions
4310 BasicBlock *InsertBlock = Loc.IP.getBlock();
4311 ContinuationBlock =
4312 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4313 InsertBlock->getTerminator()->eraseFromParent();
4314 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4315 }
4316
4317 Function *CurFunc = Builder.GetInsertBlock()->getParent();
4318 AttributeList FuncAttrs;
4319 AttrBuilder AttrBldr(Ctx);
4320 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
4321 AttrBldr.addAttribute(A: Attr);
4322 AttrBldr.removeAttribute(Val: Attribute::OptimizeNone);
4323 FuncAttrs = FuncAttrs.addFnAttributes(C&: Ctx, B: AttrBldr);
4324
4325 CodeGenIP = Builder.saveIP();
4326 Expected<Function *> ReductionResult = createReductionFunction(
4327 ReducerName: Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos, IsByRef,
4328 ReductionGenCBKind, FuncAttrs);
4329 if (!ReductionResult)
4330 return ReductionResult.takeError();
4331 Function *ReductionFunc = *ReductionResult;
4332 Builder.restoreIP(IP: CodeGenIP);
4333
4334 // Set the grid value in the config needed for lowering later on
4335 if (GridValue.has_value())
4336 Config.setGridValue(GridValue.value());
4337 else
4338 Config.setGridValue(getGridValue(T, Kernel: ReductionFunc));
4339
4340 // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
4341 // RedList, shuffle_reduce_func, interwarp_copy_func);
4342 // or
4343 // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
4344 Value *Res;
4345
4346 // 1. Build a list of reduction variables.
4347 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4348 auto Size = ReductionInfos.size();
4349 Type *PtrTy = PointerType::get(C&: Ctx, AddressSpace: Config.getDefaultTargetAS());
4350 Type *FuncPtrTy =
4351 Builder.getPtrTy(AddrSpace: M.getDataLayout().getProgramAddressSpace());
4352 Type *RedArrayTy = ArrayType::get(ElementType: PtrTy, NumElements: Size);
4353 CodeGenIP = Builder.saveIP();
4354 Builder.restoreIP(IP: AllocaIP);
4355 Value *ReductionListAlloca =
4356 Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4357 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
4358 V: ReductionListAlloca, DestTy: PtrTy, Name: ReductionListAlloca->getName() + ".ascast");
4359 Builder.restoreIP(IP: CodeGenIP);
4360 Type *IndexTy = Builder.getIndexTy(
4361 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4362 for (auto En : enumerate(First&: ReductionInfos)) {
4363 const ReductionInfo &RI = En.value();
4364 Value *ElemPtr = Builder.CreateInBoundsGEP(
4365 Ty: RedArrayTy, Ptr: ReductionList,
4366 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4367
4368 Value *PrivateVar = RI.PrivateVariable;
4369 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
4370 if (IsByRefElem)
4371 PrivateVar = Builder.CreateLoad(Ty: RI.ElementType, Ptr: PrivateVar);
4372
4373 Value *CastElem =
4374 Builder.CreatePointerBitCastOrAddrSpaceCast(V: PrivateVar, DestTy: PtrTy);
4375 Builder.CreateStore(Val: CastElem, Ptr: ElemPtr);
4376 }
4377 CodeGenIP = Builder.saveIP();
4378 Expected<Function *> SarFunc = emitShuffleAndReduceFunction(
4379 ReductionInfos, ReduceFn: ReductionFunc, FuncAttrs, IsByRef);
4380
4381 if (!SarFunc)
4382 return SarFunc.takeError();
4383
4384 Expected<Function *> CopyResult =
4385 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs, IsByRef);
4386 if (!CopyResult)
4387 return CopyResult.takeError();
4388 Function *WcFunc = *CopyResult;
4389 Builder.restoreIP(IP: CodeGenIP);
4390
4391 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(V: ReductionList, DestTy: PtrTy);
4392
4393 // NOTE: ReductionDataSize is passed as the reduce_data_size
4394 // argument to __kmpc_nvptx_{parallel,teams}_reduce_nowait_v2, but
4395 // the runtime implementations do not currently use it. The teams
4396 // runtime reads ReductionDataSize from KernelEnvironmentTy instead
4397 // (set separately via TargetKernelDefaultAttrs). It is computed
4398 // here conservatively as max(element sizes) * N rather than the
4399 // exact sum, which over-calculates the size for mixed reduction
4400 // types but is harmless given the argument is unused.
4401 // TODO: Consider dropping this computation if the runtime API is
4402 // ever revised to remove the unused parameter.
4403 unsigned MaxDataSize = 0;
4404 SmallVector<Type *> ReductionTypeArgs;
4405 for (auto En : enumerate(First&: ReductionInfos)) {
4406 // Use ByRefElementType for by-ref reductions so that MaxDataSize matches
4407 // the actual data size stored in the global reduction buffer, consistent
4408 // with the ReductionsBufferTy struct used for GEP offsets below.
4409 Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
4410 ? En.value().ByRefElementType
4411 : En.value().ElementType;
4412 auto Size = M.getDataLayout().getTypeStoreSize(Ty: RedTypeArg);
4413 if (Size > MaxDataSize)
4414 MaxDataSize = Size;
4415 ReductionTypeArgs.emplace_back(Args&: RedTypeArg);
4416 }
4417 Value *ReductionDataSize =
4418 Builder.getInt64(C: MaxDataSize * ReductionInfos.size());
4419 if (!IsTeamsReduction) {
4420 Value *SarFuncCast =
4421 Builder.CreatePointerBitCastOrAddrSpaceCast(V: *SarFunc, DestTy: FuncPtrTy);
4422 Value *WcFuncCast =
4423 Builder.CreatePointerBitCastOrAddrSpaceCast(V: WcFunc, DestTy: FuncPtrTy);
4424 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
4425 WcFuncCast};
4426 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
4427 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
4428 Res = createRuntimeFunctionCall(Callee: Pv2Ptr, Args);
4429 } else {
4430 CodeGenIP = Builder.saveIP();
4431 StructType *ReductionsBufferTy = StructType::create(
4432 Context&: Ctx, Elements: ReductionTypeArgs, Name: "struct._globalized_locals_ty");
4433 Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
4434 FnID: RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
4435
4436 Expected<Function *> LtGCFunc = emitListToGlobalCopyFunction(
4437 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4438 if (!LtGCFunc)
4439 return LtGCFunc.takeError();
4440
4441 Expected<Function *> LtGRFunc = emitListToGlobalReduceFunction(
4442 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4443 if (!LtGRFunc)
4444 return LtGRFunc.takeError();
4445
4446 Expected<Function *> GtLCFunc = emitGlobalToListCopyFunction(
4447 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4448 if (!GtLCFunc)
4449 return GtLCFunc.takeError();
4450
4451 Expected<Function *> GtLRFunc = emitGlobalToListReduceFunction(
4452 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4453 if (!GtLRFunc)
4454 return GtLRFunc.takeError();
4455
4456 Builder.restoreIP(IP: CodeGenIP);
4457
4458 Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
4459 Callee: RedFixedBufferFn, Args: {}, Name: "_openmp_teams_reductions_buffer_$_$ptr");
4460
4461 Value *Args3[] = {SrcLocInfo,
4462 KernelTeamsReductionPtr,
4463 Builder.getInt32(C: ReductionBufNum),
4464 ReductionDataSize,
4465 RL,
4466 *SarFunc,
4467 WcFunc,
4468 *LtGCFunc,
4469 *LtGRFunc,
4470 *GtLCFunc,
4471 *GtLRFunc};
4472
4473 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
4474 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
4475 Res = createRuntimeFunctionCall(Callee: TeamsReduceFn, Args: Args3);
4476 }
4477
4478 // 5. Build if (res == 1)
4479 BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.done");
4480 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.then");
4481 Value *Cond = Builder.CreateICmpEQ(LHS: Res, RHS: Builder.getInt32(C: 1));
4482 Builder.CreateCondBr(Cond, True: ThenBB, False: ExitBB);
4483
4484 // 6. Build then branch: where we have reduced values in the master
4485 // thread in each team.
4486 // __kmpc_end_reduce{_nowait}(<gtid>);
4487 // break;
4488 emitBlock(BB: ThenBB, CurFn: CurFunc);
4489
4490 // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
4491 for (auto En : enumerate(First&: ReductionInfos)) {
4492 const ReductionInfo &RI = En.value();
4493 Type *ValueType = RI.ElementType;
4494 Value *RedValue = RI.Variable;
4495 Value *RHS =
4496 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
4497
4498 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4499 Value *LHSPtr, *RHSPtr;
4500 Builder.restoreIP(IP: RI.ReductionGenClang(Builder.saveIP(), En.index(),
4501 &LHSPtr, &RHSPtr, CurFunc));
4502
4503 // Fix the CallBack code genereated to use the correct Values for the LHS
4504 // and RHS
4505 LHSPtr->replaceUsesWithIf(New: RedValue, ShouldReplace: [ReductionFunc](const Use &U) {
4506 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4507 ReductionFunc;
4508 });
4509 RHSPtr->replaceUsesWithIf(New: RHS, ShouldReplace: [ReductionFunc](const Use &U) {
4510 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4511 ReductionFunc;
4512 });
4513 } else {
4514 if (IsByRef.empty() || !IsByRef[En.index()]) {
4515 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
4516 Name: "red.value." + Twine(En.index()));
4517 }
4518 Value *PrivateRedValue = Builder.CreateLoad(
4519 Ty: ValueType, Ptr: RHS, Name: "red.private.value" + Twine(En.index()));
4520 Value *Reduced;
4521 InsertPointOrErrorTy AfterIP =
4522 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
4523 if (!AfterIP)
4524 return AfterIP.takeError();
4525 Builder.restoreIP(IP: *AfterIP);
4526
4527 if (!IsByRef.empty() && !IsByRef[En.index()])
4528 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
4529 }
4530 }
4531 emitBlock(BB: ExitBB, CurFn: CurFunc);
4532 if (ContinuationBlock) {
4533 Builder.CreateBr(Dest: ContinuationBlock);
4534 Builder.SetInsertPoint(ContinuationBlock);
4535 }
4536 Config.setEmitLLVMUsed();
4537
4538 return Builder.saveIP();
4539}
4540
4541static Function *getFreshReductionFunc(Module &M) {
4542 Type *VoidTy = Type::getVoidTy(C&: M.getContext());
4543 Type *Int8PtrTy = PointerType::getUnqual(C&: M.getContext());
4544 auto *FuncTy =
4545 FunctionType::get(Result: VoidTy, Params: {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ isVarArg: false);
4546 return Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4547 N: ".omp.reduction.func", M: &M);
4548}
4549
/// Emit the body of \p ReductionFunc, the type-erased elementwise reduction
/// function handed to the reduction runtime. Its two pointer arguments are
/// arrays of pointers: element i of the first array points at the i-th
/// "accumulator" (LHS) variable, element i of the second at the i-th private
/// (RHS) copy. Each pair is loaded, combined via the per-variable
/// ReductionGen callback, and (for non-by-ref variables) the result is stored
/// back through the LHS pointer.
///
/// On the GPU path the incoming arguments are first spilled to allocas and
/// reloaded through address-space casts before being indexed.
static Error populateReductionFunction(
    Function *ReductionFunc,
    ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
  Module *Module = ReductionFunc->getParent();
  // The function body is a single fresh block; all IR below is emitted there
  // (plus whatever blocks the ReductionGen callbacks create).
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "", Parent: ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  if (IsGPU) {
    // Need to alloca memory here and deal with the pointers before getting
    // LHS/RHS pointers out
    //
    Argument *Arg0 = ReductionFunc->getArg(i: 0);
    Argument *Arg1 = ReductionFunc->getArg(i: 1);
    Type *Arg0Type = Arg0->getType();
    Type *Arg1Type = Arg1->getType();

    // Spill the two array pointers to stack slots, casting the slot address
    // to the argument's pointer type (this may be an addrspace cast on GPU
    // targets), then reload to get generic pointers to index below.
    Value *LHSAlloca =
        Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
    Value *RHSAlloca =
        Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
    Value *LHSAddrCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(V: LHSAlloca, DestTy: Arg0Type);
    Value *RHSAddrCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHSAlloca, DestTy: Arg1Type);
    Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
    Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
    LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
    RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
  } else {
    // Host path: the arguments can be indexed directly.
    LHSArrayPtr = ReductionFunc->getArg(i: 0);
    RHSArrayPtr = ReductionFunc->getArg(i: 1);
  }

  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);

  for (auto En : enumerate(First&: ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    // Load the i-th LHS pointer out of the type-erased array and cast it to
    // the pointer type of the original variable before loading the value.
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: LHSArrayPtr, Idx0: 0, Idx1: En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: LHSI8Ptr, DestTy: RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
    // Same for the i-th RHS (private copy) pointer.
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: RHSArrayPtr, Idx0: 0, Idx1: En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
    Value *Reduced;
    // Let the frontend callback emit LHS <op>= RHS; it may move the insert
    // point, which we adopt below.
    OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();

    Builder.restoreIP(IP: *AfterIP);
    // TODO: Consider flagging an error.
    if (!Builder.GetInsertBlock())
      return Error::success();

    // store is inside of the reduction region when using by-ref
    if (!IsByRef[En.index()])
      Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
  }
  Builder.CreateRetVoid();
  return Error::success();
}
4621
// Host lowering of the reductions: gather pointers to the private copies into
// a type-erased array, call __kmpc_reduce{_nowait}, and emit the non-atomic
// and atomic finalization paths selected by the runtime's return value. GPU
// configurations are dispatched to createReductionsGPU instead.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
    bool IsNoWait, bool IsTeamsReduction) {
  // One by-ref flag per reduction variable is required.
  assert(ReductionInfos.size() == IsByRef.size());
  if (Config.isGPU())
    return createReductionsGPU(Loc, AllocaIP, CodeGenIP: Builder.saveIP(), ReductionInfos,
                               IsByRef, IsNoWait, IsTeamsReduction);

  checkReductionInfos(ReductionInfos, /*IsGPU*/ false);

  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Nothing to reduce.
  if (ReductionInfos.size() == 0)
    return Builder.saveIP();

  // Split off everything after the insertion point into a continuation block;
  // the reduction control flow (switch + per-case blocks) is emitted in
  // between and rejoins there.
  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
  Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
  Value *RedArray = Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: "red.array");

  Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());

  for (auto En : enumerate(First&: ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: RedArray, Idx0: 0, Idx1: Index, Name: "red.array.elem." + Twine(Index));
    Builder.CreateStore(Val: RI.PrivateVariable, Ptr: RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // The atomic fallback case is only usable if every reduction variable
  // provides an atomic generator; advertise that via the ident flags.
  bool CanGenerateAtomic = all_of(Range&: ReductionInfos, P: [](const ReductionInfo &RI) {
    return RI.AtomicReductionGen;
  });
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                  LocFlags: CanGenerateAtomic
                                      ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                      : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(C: NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(Ty: RedArrayTy);
  Constant *RedArraySize = ConstantInt::get(Ty: IndexTy, V: RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(M&: *Module);
  Value *Lock = getOMPCriticalRegionLock(CriticalName: ".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
                : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      createRuntimeFunctionCall(Callee: ReduceFunc,
                                Args: {Ident, ThreadId, NumVariables, RedArraySize,
                                 RedArray, ReductionFunc, Lock},
                                Name: "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.nonatomic", Parent: Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.atomic", Parent: Func);
  // Return value 1 selects the non-atomic path, 2 the atomic path; any other
  // value falls through to the continuation.
  SwitchInst *Switch =
      Builder.CreateSwitch(V: ReduceCall, Dest: ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(OnVal: Builder.getInt32(C: 1), Dest: NonAtomicRedBlock);
  Switch->addCase(OnVal: Builder.getInt32(C: 2), Dest: AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    // We have one less load for by-ref case because that load is now inside of
    // the reduction region
    Value *RedValue = RI.Variable;
    if (!IsByRef[En.index()]) {
      RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
                                    Name: "red.value." + Twine(En.index()));
    }
    Value *PrivateRedValue =
        Builder.CreateLoad(Ty: ValueType, Ptr: RI.PrivateVariable,
                           Name: "red.private.value." + Twine(En.index()));
    Value *Reduced;
    InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);

    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // for by-ref case, the load is inside of the reduction region
    if (!IsByRef[En.index()])
      Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  createRuntimeFunctionCall(Callee: EndReduceFunc, Args: {Ident, ThreadId, Lock});
  Builder.CreateBr(Dest: ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic && llvm::none_of(Range&: IsByRef, P: [](bool P) { return P; })) {
    for (const ReductionInfo &RI : ReductionInfos) {
      InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
          Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(Dest: ContinuationBlock);
  } else {
    // The atomic path cannot be taken (no atomic generators, or by-ref
    // reductions present); make that explicit in the IR.
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
                                        IsByRef, /*isGPU=*/IsGPU: false);
  if (Err)
    return Err;

  if (!Builder.GetInsertBlock())
    return InsertPointTy();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}
4773
4774OpenMPIRBuilder::InsertPointOrErrorTy
4775OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
4776 BodyGenCallbackTy BodyGenCB,
4777 FinalizeCallbackTy FiniCB) {
4778 if (!updateToLocation(Loc))
4779 return Loc.IP;
4780
4781 Directive OMPD = Directive::OMPD_master;
4782 uint32_t SrcLocStrSize;
4783 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4784 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4785 Value *ThreadId = getOrCreateThreadID(Ident);
4786 Value *Args[] = {Ident, ThreadId};
4787
4788 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_master);
4789 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4790
4791 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_master);
4792 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
4793
4794 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4795 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4796}
4797
4798OpenMPIRBuilder::InsertPointOrErrorTy
4799OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
4800 BodyGenCallbackTy BodyGenCB,
4801 FinalizeCallbackTy FiniCB, Value *Filter) {
4802 if (!updateToLocation(Loc))
4803 return Loc.IP;
4804
4805 Directive OMPD = Directive::OMPD_masked;
4806 uint32_t SrcLocStrSize;
4807 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4808 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4809 Value *ThreadId = getOrCreateThreadID(Ident);
4810 Value *Args[] = {Ident, ThreadId, Filter};
4811 Value *ArgsEnd[] = {Ident, ThreadId};
4812
4813 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_masked);
4814 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4815
4816 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_masked);
4817 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args: ArgsEnd);
4818
4819 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4820 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4821}
4822
4823static llvm::CallInst *emitNoUnwindRuntimeCall(IRBuilder<> &Builder,
4824 llvm::FunctionCallee Callee,
4825 ArrayRef<llvm::Value *> Args,
4826 const llvm::Twine &Name) {
4827 llvm::CallInst *Call = Builder.CreateCall(
4828 Callee, Args, OpBundles: SmallVector<llvm::OperandBundleDef, 1>(), Name);
4829 Call->setDoesNotThrow();
4830 return Call;
4831}
4832
// Expects input basic block is dominated by BeforeScanBB.
// Once Scan directive is encountered, the code after scan directive should be
// dominated by AfterScanBB. Scan directive splits the code sequence to
// scan and input phase. Based on whether inclusive or exclusive
// clause is used in the scan directive and whether input loop or scan loop
// is lowered, it adds jumps to input and scan phase. First Scan loop is the
// input loop and second is the scan loop. The code generated handles only
// inclusive scans now.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
    bool IsInclusive, ScanInfo *ScanRedInfo) {
  // On the first (input) loop, materialize the per-variable buffer pointers
  // before anything else.
  if (ScanRedInfo->OMPFirstScanLoop) {
    llvm::Error Err = emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars,
                                                    ScanVarsType, ScanRedInfo);
    if (Err)
      return Err;
  }
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Loop induction variable, used to index into the scan buffers.
  llvm::Value *IV = ScanRedInfo->IV;

  if (ScanRedInfo->OMPFirstScanLoop) {
    // Emit buffer[i] = red; at the end of the input phase.
    for (size_t i = 0; i < ScanVars.size(); i++) {
      // ScanBuffPtrs maps each scan variable to the alloca holding its
      // buffer's base pointer; load the base, index by IV, and copy the
      // variable's current value in.
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = ScanVarsType[i];
      Value *Val = Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: ScanVars[i]);

      Builder.CreateStore(Val: Src, Ptr: Val);
    }
  }
  Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
  emitBlock(BB: ScanRedInfo->OMPScanDispatch,
            CurFn: Builder.GetInsertBlock()->getParent());

  if (!ScanRedInfo->OMPFirstScanLoop) {
    IV = ScanRedInfo->IV;
    // Emit red = buffer[i]; at the entrance to the scan phase.
    // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = ScanVarsType[i];
      Value *SrcPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: SrcPtr);
      Builder.CreateStore(Val: Src, Ptr: ScanVars[i]);
    }
  }

  // TODO: Update it to CreateBr and remove dead blocks
  // The branch condition is a constant true; which successor comes first
  // depends on whether this is the input loop and whether the scan is
  // inclusive.
  llvm::Value *CmpI = Builder.getInt1(V: true);
  if (ScanRedInfo->OMPFirstScanLoop == IsInclusive) {
    Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPBeforeScanBlock,
                         False: ScanRedInfo->OMPAfterScanBlock);
  } else {
    Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPAfterScanBlock,
                         False: ScanRedInfo->OMPBeforeScanBlock);
  }
  emitBlock(BB: ScanRedInfo->OMPAfterScanBlock,
            CurFn: Builder.GetInsertBlock()->getParent());
  Builder.SetInsertPoint(ScanRedInfo->OMPAfterScanBlock);
  return Builder.saveIP();
}
4901
/// Emit the allocations backing scan reductions: one pointer-sized alloca per
/// scan variable (recorded in ScanBuffPtrs), plus a masked region in which
/// thread 0 mallocs the actual buffers of Span + 1 elements, followed by a
/// barrier so the buffers are visible to all threads.
Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
    InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
    ArrayRef<Type *> ScanVarsType, ScanInfo *ScanRedInfo) {

  Builder.restoreIP(IP: AllocaIP);
  // Create the shared pointer at alloca IP.
  for (size_t i = 0; i < ScanVars.size(); i++) {
    llvm::Value *BuffPtr =
        Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: "vla");
    (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]] = BuffPtr;
  }

  // Allocate temporary buffer by master thread
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    // Allocate Span + 1 elements; the extra slot is used by the finalization
    // code (see emitScanBasedDirectiveFinalsIR, which reads buffer[Span]).
    Value *AllocSpan =
        Builder.CreateAdd(LHS: ScanRedInfo->Span, RHS: Builder.getInt32(C: 1));
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Type *IntPtrTy = Builder.getInt32Ty();
      Constant *Allocsize = ConstantExpr::getSizeOf(Ty: ScanVarsType[i]);
      Allocsize = ConstantExpr::getTruncOrBitCast(C: Allocsize, Ty: IntPtrTy);
      Value *Buff = Builder.CreateMalloc(IntPtrTy, AllocTy: ScanVarsType[i], AllocSize: Allocsize,
                                         ArraySize: AllocSpan, MallocF: nullptr, Name: "arr");
      Builder.CreateStore(Val: Buff, Ptr: (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]]);
    }
    return Error::success();
  };
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Run the allocation under "masked" with filter 0, i.e. only thread 0
  // executes it.
  Builder.SetInsertPoint(ScanRedInfo->OMPScanInit->getTerminator());
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  BasicBlock *InputBB = Builder.GetInsertBlock();
  // If the block already ends in a terminator, emit the barrier before it.
  if (InputBB->hasTerminator())
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);

  return Error::success();
}
4952
/// Emit the finalization for scan reductions: under a masked region (thread 0
/// only), copy the final reduced value buffer[Span] back into each original
/// reduction variable and free the malloc'ed buffer, then emit a barrier.
Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
    ArrayRef<ReductionInfo> ReductionInfos, ScanInfo *ScanRedInfo) {
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    for (ReductionInfo RedInfo : ReductionInfos) {
      Value *PrivateVar = RedInfo.PrivateVariable;
      Value *OrigVar = RedInfo.Variable;
      // The buffer base pointer was recorded per private variable in
      // emitScanBasedDirectiveDeclsIR.
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[PrivateVar];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);

      // buffer[Span] holds the total; store it to the original variable.
      Type *SrcTy = RedInfo.ElementType;
      Value *Val = Builder.CreateInBoundsGEP(Ty: SrcTy, Ptr: Buff, IdxList: ScanRedInfo->Span,
                                             Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: SrcTy, Ptr: Val);

      Builder.CreateStore(Val: Src, Ptr: OrigVar);
      // Release the temporary buffer allocated in the decls phase.
      Builder.CreateFree(Source: Buff);
    }
    return Error::success();
  };
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Insert before the terminator of OMPScanFinish if it has one, otherwise at
  // the (unterminated) block's current end.
  if (Instruction *TI = ScanRedInfo->OMPScanFinish->getTerminatorOrNull())
    Builder.SetInsertPoint(TI);
  else
    Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish);

  // Filter 0: only thread 0 performs the copy-out and free.
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  BasicBlock *InputBB = Builder.GetInsertBlock();
  if (InputBB->hasTerminator())
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  return Error::success();
}
4999
/// Emit the in-place prefix-scan over the temporary buffers, performed by a
/// single thread (masked with filter 0) and followed by a barrier. The loop
/// nest emitted matches the doubling scheme sketched in the inline comments:
/// an outer loop over k = 0 .. ceil(log2(Span)) and an inner loop that
/// combines tmp[i] with tmp[i - 2^k]. Afterwards the finalization IR (copy
/// out + free) is emitted via emitScanBasedDirectiveFinalsIR.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
    const LocationDescription &Loc,
    ArrayRef<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    ScanInfo *ScanRedInfo) {

  if (!updateToLocation(Loc))
    return Loc.IP;
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    Function *CurFn = Builder.GetInsertBlock()->getParent();
    // for (int k = 0; k <= ceil(log2(n)); ++k)
    llvm::BasicBlock *LoopBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.outer.log.scan.body");
    llvm::BasicBlock *ExitBB =
        splitBB(Builder, CreateBranch: false, Name: "omp.outer.log.scan.exit");
    // Compute the outer trip count: ceil(log2(Span)), via the llvm.log2 and
    // llvm.ceil double intrinsics, truncated back to i32.
    llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
        M: Builder.GetInsertBlock()->getModule(),
        id: (llvm::Intrinsic::ID)llvm::Intrinsic::log2, OverloadTys: Builder.getDoubleTy());
    llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
    llvm::Value *Arg =
        Builder.CreateUIToFP(V: ScanRedInfo->Span, DestTy: Builder.getDoubleTy());
    llvm::Value *LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: Arg, Name: "");
    F = llvm::Intrinsic::getOrInsertDeclaration(
        M: Builder.GetInsertBlock()->getModule(),
        id: (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, OverloadTys: Builder.getDoubleTy());
    LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: LogVal, Name: "");
    LogVal = Builder.CreateFPToUI(V: LogVal, DestTy: Builder.getInt32Ty());
    llvm::Value *NMin1 = Builder.CreateNUWSub(
        LHS: ScanRedInfo->Span,
        RHS: llvm::ConstantInt::get(Ty: ScanRedInfo->Span->getType(), V: 1));
    Builder.SetInsertPoint(InputBB);
    Builder.CreateBr(Dest: LoopBB);
    emitBlock(BB: LoopBB, CurFn);
    Builder.SetInsertPoint(LoopBB);

    // Outer-loop PHIs: the iteration counter k and the stride 2^k.
    PHINode *Counter = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    // size pow2k = 1;
    PHINode *Pow2K = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    Counter->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
                         BB: InputBB);
    Pow2K->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1),
                       BB: InputBB);
    // for (size i = n - 1; i >= 2 ^ k; --i)
    //   tmp[i] op= tmp[i-pow2k];
    llvm::BasicBlock *InnerLoopBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.body");
    llvm::BasicBlock *InnerExitBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.exit");
    llvm::Value *CmpI = Builder.CreateICmpUGE(LHS: NMin1, RHS: Pow2K);
    Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
    emitBlock(BB: InnerLoopBB, CurFn);
    Builder.SetInsertPoint(InnerLoopBB);
    // Inner-loop PHI: the descending index i, starting at n - 1.
    PHINode *IVal = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    IVal->addIncoming(V: NMin1, BB: LoopBB);
    for (ReductionInfo RedInfo : ReductionInfos) {
      Value *ReductionVal = RedInfo.PrivateVariable;
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ReductionVal];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = RedInfo.ElementType;
      // Combine buffer[i+1] with buffer[i+1 - pow2k], storing into the former.
      Value *IV = Builder.CreateAdd(LHS: IVal, RHS: Builder.getInt32(C: 1));
      Value *LHSPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *OffsetIval = Builder.CreateNUWSub(LHS: IV, RHS: Pow2K);
      Value *RHSPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: OffsetIval, Name: "arrayOffset");
      Value *LHS = Builder.CreateLoad(Ty: DestTy, Ptr: LHSPtr);
      Value *RHS = Builder.CreateLoad(Ty: DestTy, Ptr: RHSPtr);
      llvm::Value *Result;
      InsertPointOrErrorTy AfterIP =
          RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.CreateStore(Val: Result, Ptr: LHSPtr);
    }
    // i decrements towards 2^k; continue while next i >= pow2k.
    llvm::Value *NextIVal = Builder.CreateNUWSub(
        LHS: IVal, RHS: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1));
    IVal->addIncoming(V: NextIVal, BB: Builder.GetInsertBlock());
    CmpI = Builder.CreateICmpUGE(LHS: NextIVal, RHS: Pow2K);
    Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
    emitBlock(BB: InnerExitBB, CurFn);
    llvm::Value *Next = Builder.CreateNUWAdd(
        LHS: Counter, RHS: llvm::ConstantInt::get(Ty: Counter->getType(), V: 1));
    Counter->addIncoming(V: Next, BB: Builder.GetInsertBlock());
    // pow2k <<= 1;
    llvm::Value *NextPow2K = Builder.CreateShl(LHS: Pow2K, RHS: 1, Name: "", /*HasNUW=*/true);
    Pow2K->addIncoming(V: NextPow2K, BB: Builder.GetInsertBlock());
    llvm::Value *Cmp = Builder.CreateICmpNE(LHS: Next, RHS: LogVal);
    Builder.CreateCondBr(Cond: Cmp, True: LoopBB, False: ExitBB);
    Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
    return Error::success();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Filter 0: the scan itself is executed by a single thread.
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  // Emit the copy-out of the final values and the buffer frees.
  Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos, ScanRedInfo);
  if (Err)
    return Err;

  return AfterIP;
}
5115
5116Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
5117 llvm::function_ref<Error()> InputLoopGen,
5118 llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen,
5119 ScanInfo *ScanRedInfo) {
5120
5121 {
5122 // Emit loop with input phase:
5123 // for (i: 0..<num_iters>) {
5124 // <input phase>;
5125 // buffer[i] = red;
5126 // }
5127 ScanRedInfo->OMPFirstScanLoop = true;
5128 Error Err = InputLoopGen();
5129 if (Err)
5130 return Err;
5131 }
5132 {
5133 // Emit loop with scan phase:
5134 // for (i: 0..<num_iters>) {
5135 // red = buffer[i];
5136 // <scan phase>;
5137 // }
5138 ScanRedInfo->OMPFirstScanLoop = false;
5139 Error Err = ScanLoopGen(Builder.saveIP());
5140 if (Err)
5141 return Err;
5142 }
5143 return Error::success();
5144}
5145
5146void OpenMPIRBuilder::createScanBBs(ScanInfo *ScanRedInfo) {
5147 Function *Fun = Builder.GetInsertBlock()->getParent();
5148 ScanRedInfo->OMPScanDispatch =
5149 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.inscan.dispatch");
5150 ScanRedInfo->OMPAfterScanBlock =
5151 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.after.scan.bb");
5152 ScanRedInfo->OMPBeforeScanBlock =
5153 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.before.scan.bb");
5154 ScanRedInfo->OMPScanLoopExit =
5155 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.scan.loop.exit");
5156}
/// Create the CFG skeleton of a canonical loop — a chain of
/// preheader -> header -> cond -> body -> latch -> exit -> after blocks —
/// counting an induction variable from 0 up to (excluding) \p TripCount.
/// The skeleton is NOT connected to the surrounding CFG; callers wire in the
/// preheader/after blocks afterwards. "Pre" blocks are inserted before
/// \p PreInsertBefore and "post" blocks before \p PostInsertBefore so that
/// skeletons of nested loops interleave correctly.
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
    DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
    BasicBlock *PostInsertBefore, const Twine &Name) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  // The induction variable uses the same integer type as the trip count.
  Type *IndVarTy = TripCount->getType();

  // Create the basic block structure.
  BasicBlock *Preheader =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".preheader", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Header =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".header", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Cond =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".cond", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Body =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".body", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Latch =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".inc", Parent: F, InsertBefore: PostInsertBefore);
  BasicBlock *Exit =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".exit", Parent: F, InsertBefore: PostInsertBefore);
  BasicBlock *After =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".after", Parent: F, InsertBefore: PostInsertBefore);

  // Use specified DebugLoc for new instructions.
  Builder.SetCurrentDebugLocation(DL);

  Builder.SetInsertPoint(Preheader);
  Builder.CreateBr(Dest: Header);

  // Header: IV phi, 0 on entry from the preheader, incremented value from the
  // latch on the back edge.
  Builder.SetInsertPoint(Header);
  PHINode *IndVarPHI = Builder.CreatePHI(Ty: IndVarTy, NumReservedValues: 2, Name: "omp_" + Name + ".iv");
  IndVarPHI->addIncoming(V: ConstantInt::get(Ty: IndVarTy, V: 0), BB: Preheader);
  Builder.CreateBr(Dest: Cond);

  // Cond: continue while IV < TripCount; the IV is always interpreted as
  // unsigned, hence the ULT comparison.
  Builder.SetInsertPoint(Cond);
  Value *Cmp =
      Builder.CreateICmpULT(LHS: IndVarPHI, RHS: TripCount, Name: "omp_" + Name + ".cmp");
  Builder.CreateCondBr(Cond: Cmp, True: Body, False: Exit);

  Builder.SetInsertPoint(Body);
  Builder.CreateBr(Dest: Latch);

  // Latch: IV += 1. NUW holds because the increment only executes when
  // IV < TripCount, so IV + 1 <= TripCount cannot wrap.
  Builder.SetInsertPoint(Latch);
  Value *Next = Builder.CreateAdd(LHS: IndVarPHI, RHS: ConstantInt::get(Ty: IndVarTy, V: 1),
                                  Name: "omp_" + Name + ".next", /*HasNUW=*/true);
  Builder.CreateBr(Dest: Header);
  IndVarPHI->addIncoming(V: Next, BB: Latch);

  Builder.SetInsertPoint(Exit);
  Builder.CreateBr(Dest: After);

  // Remember and return the canonical control flow.
  LoopInfos.emplace_front();
  CanonicalLoopInfo *CL = &LoopInfos.front();

  CL->Header = Header;
  CL->Cond = Cond;
  CL->Latch = Latch;
  CL->Exit = Exit;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
5222
/// Create a canonical loop with a precomputed trip count at \p Loc, invoking
/// \p BodyGenCB once to emit the loop body with the body insertion point and
/// the induction variable (counting 0 .. TripCount-1).
Expected<CanonicalLoopInfo *>
OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
                                     LoopBodyGenCallbackTy BodyGenCB,
                                     Value *TripCount, const Twine &Name) {
  BasicBlock *BB = Loc.IP.getBlock();
  BasicBlock *NextBB = BB->getNextNode();

  // Insert all skeleton blocks between the insertion block and its successor
  // in the function's block list.
  CanonicalLoopInfo *CL = createLoopSkeleton(DL: Loc.DL, TripCount, F: BB->getParent(),
                                             PreInsertBefore: NextBB, PostInsertBefore: NextBB, Name);
  BasicBlock *After = CL->getAfter();

  // If location is not set, don't connect the loop.
  if (updateToLocation(Loc)) {
    // Split the loop at the insertion point: Branch to the preheader and move
    // every following instruction to after the loop (the After BB). Also, the
    // new successor is the loop's after block.
    spliceBB(Builder, New: After, /*CreateBranch=*/false);
    Builder.CreateBr(Dest: CL->getPreheader());
  }

  // Emit the body content. We do it after connecting the loop to the CFG to
  // avoid that the callback encounters degenerate BBs.
  if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
    return Err;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
5253
5254Expected<ScanInfo *> OpenMPIRBuilder::scanInfoInitialize() {
5255 ScanInfos.emplace_front();
5256 ScanInfo *Result = &ScanInfos.front();
5257 return Result;
5258}
5259
/// Create the pair of canonical loops used to lower a loop with an inscan
/// reduction: one loop running the input phase and a second, identical-bounds
/// loop running the scan phase. Both use the same body callback; the shared
/// \p ScanRedInfo records the trip count, init/finish blocks and per-iteration
/// scan blocks that the scan lowering needs across the two emissions. Returns
/// the two CanonicalLoopInfos (input loop first).
Expected<SmallVector<llvm::CanonicalLoopInfo *>>
OpenMPIRBuilder::createCanonicalScanLoops(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name, ScanInfo *ScanRedInfo) {
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
  updateToLocation(Loc: ComputeLoc);

  SmallVector<CanonicalLoopInfo *> Result;

  // The trip count ("span") is computed once and shared by both loops and by
  // the scan buffer allocation.
  Value *TripCount = calculateCanonicalLoopTripCount(
      Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
  ScanRedInfo->Span = TripCount;
  ScanRedInfo->OMPScanInit = splitBB(Builder, CreateBranch: true, Name: "scan.init");
  Builder.SetInsertPoint(ScanRedInfo->OMPScanInit);

  // Shared body generator: reroutes the body's single successor through the
  // scan dispatch/before-scan/loop-exit blocks, then runs the user callback
  // inside the before-scan block.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(IP: CodeGenIP);
    ScanRedInfo->IV = IV;
    createScanBBs(ScanRedInfo);
    BasicBlock *InputBlock = Builder.GetInsertBlock();
    Instruction *Terminator = InputBlock->getTerminator();
    assert(Terminator->getNumSuccessors() == 1);
    BasicBlock *ContinueBlock = Terminator->getSuccessor(Idx: 0);
    // Divert the body's fall-through edge into the scan dispatch block.
    Terminator->setSuccessor(Idx: 0, BB: ScanRedInfo->OMPScanDispatch);
    emitBlock(BB: ScanRedInfo->OMPBeforeScanBlock,
              CurFn: Builder.GetInsertBlock()->getParent());
    Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
    emitBlock(BB: ScanRedInfo->OMPScanLoopExit,
              CurFn: Builder.GetInsertBlock()->getParent());
    Builder.CreateBr(Dest: ContinueBlock);
    Builder.SetInsertPoint(
        ScanRedInfo->OMPBeforeScanBlock->getFirstInsertionPt());
    return BodyGenCB(Builder.saveIP(), IV);
  };

  // Generator for the first (input-phase) loop.
  const auto &&InputLoopGen = [&]() -> Error {
    Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
        Loc: Builder.saveIP(), BodyGenCB: BodyGen, Start, Stop, Step, IsSigned, InclusiveStop,
        ComputeIP, Name, InScan: true, ScanRedInfo);
    if (!LoopInfo)
      return LoopInfo.takeError();
    Result.push_back(Elt: *LoopInfo);
    Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
    return Error::success();
  };
  // Generator for the second (scan-phase) loop; also records the block where
  // codegen continues after the whole scan construct.
  const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
    Expected<CanonicalLoopInfo *> LoopInfo =
        createCanonicalLoop(Loc, BodyGenCB: BodyGen, Start, Stop, Step, IsSigned,
                            InclusiveStop, ComputeIP, Name, InScan: true, ScanRedInfo);
    if (!LoopInfo)
      return LoopInfo.takeError();
    Result.push_back(Elt: *LoopInfo);
    Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
    ScanRedInfo->OMPScanFinish = Builder.GetInsertBlock();
    return Error::success();
  };
  Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen, ScanRedInfo);
  if (Err)
    return Err;
  return Result;
}
5323
/// Compute the number of iterations of the loop "for (I = Start; I relop Stop;
/// I += Step)" as an unsigned value of the induction-variable type, handling
/// both loop directions and both exclusive/inclusive bounds.
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
    bool IsSigned, bool InclusiveStop, const Twine &Name) {

  // Consider the following difficulties (assuming 8-bit signed integers):
  // * Adding \p Step to the loop counter which passes \p Stop may overflow:
  //     DO I = 1, 100, 50
  // * A \p Step of INT_MIN cannot be normalized to a positive direction:
  //     DO I = 100, 0, -128

  // Start, Stop and Step must be of the same integer type.
  auto *IndVarTy = cast<IntegerType>(Val: Start->getType());
  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
  assert(IndVarTy == Step->getType() && "Step type mismatch");

  updateToLocation(Loc);

  ConstantInt *Zero = ConstantInt::get(Ty: IndVarTy, V: 0);
  ConstantInt *One = ConstantInt::get(Ty: IndVarTy, V: 1);

  // Like Step, but always positive.
  Value *Incr = Step;

  // Distance between Start and Stop; always positive.
  Value *Span;

  // Condition for whether no iterations are executed at all, e.g. because
  // UB < LB.
  Value *ZeroCmp;

  if (IsSigned) {
    // Ensure that increment is positive. If not, negate and invert LB and UB.
    Value *IsNeg = Builder.CreateICmpSLT(LHS: Step, RHS: Zero);
    Incr = Builder.CreateSelect(C: IsNeg, True: Builder.CreateNeg(V: Step), False: Step);
    Value *LB = Builder.CreateSelect(C: IsNeg, True: Stop, False: Start);
    Value *UB = Builder.CreateSelect(C: IsNeg, True: Start, False: Stop);
    Span = Builder.CreateSub(LHS: UB, RHS: LB, Name: "", HasNUW: false, HasNSW: true);
    ZeroCmp = Builder.CreateICmp(
        P: InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, LHS: UB, RHS: LB);
  } else {
    // Unsigned loops can only count upwards, so no normalization is needed.
    Span = Builder.CreateSub(LHS: Stop, RHS: Start, Name: "", HasNUW: true);
    ZeroCmp = Builder.CreateICmp(
        P: InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, LHS: Stop, RHS: Start);
  }

  Value *CountIfLooping;
  if (InclusiveStop) {
    CountIfLooping = Builder.CreateAdd(LHS: Builder.CreateUDiv(LHS: Span, RHS: Incr), RHS: One);
  } else {
    // Avoid incrementing past stop since it could overflow.
    Value *CountIfTwo = Builder.CreateAdd(
        LHS: Builder.CreateUDiv(LHS: Builder.CreateSub(LHS: Span, RHS: One), RHS: Incr), RHS: One);
    Value *OneCmp = Builder.CreateICmp(P: CmpInst::ICMP_ULE, LHS: Span, RHS: Incr);
    CountIfLooping = Builder.CreateSelect(C: OneCmp, True: One, False: CountIfTwo);
  }

  // Select zero when the loop runs no iterations at all.
  return Builder.CreateSelect(C: ZeroCmp, True: Zero, False: CountIfLooping,
                              Name: "omp_" + Name + ".tripcount");
}
5383
/// Create a canonical loop over the user-visible range [Start, Stop) (or
/// [Start, Stop] when \p InclusiveStop) with stride \p Step. The trip count
/// is computed at \p ComputeIP (or \p Loc if unset); the body callback
/// receives the de-normalized user IV, i.e. Start + IV * Step.
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name, bool InScan,
    ScanInfo *ScanRedInfo) {
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

  Value *TripCount = calculateCanonicalLoopTripCount(
      Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

  // Wrap the user callback: map the normalized IV (0..TripCount-1) back to
  // the user's iteration value Start + IV * Step.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(IP: CodeGenIP);
    Value *Span = Builder.CreateMul(LHS: IV, RHS: Step);
    Value *IndVar = Builder.CreateAdd(LHS: Span, RHS: Start);
    if (InScan)
      ScanRedInfo->IV = IndVar;
    return BodyGenCB(Builder.saveIP(), IndVar);
  };
  // If the trip count was emitted at a dedicated ComputeIP, the loop itself
  // still goes to the original location; otherwise continue where the
  // trip-count computation left off.
  LocationDescription LoopLoc =
      ComputeIP.isSet()
          ? Loc
          : LocationDescription(Builder.saveIP(),
                                Builder.getCurrentDebugLocation());
  return createCanonicalLoop(Loc: LoopLoc, BodyGenCB: BodyGen, TripCount, Name);
}
5410
5411// Returns an LLVM function to call for initializing loop bounds using OpenMP
5412// static scheduling for composite `distribute parallel for` depending on
5413// `type`. Only i32 and i64 are supported by the runtime. Always interpret
5414// integers as unsigned similarly to CanonicalLoopInfo.
5415static FunctionCallee
5416getKmpcDistForStaticInitForType(Type *Ty, Module &M,
5417 OpenMPIRBuilder &OMPBuilder) {
5418 unsigned Bitwidth = Ty->getIntegerBitWidth();
5419 if (Bitwidth == 32)
5420 return OMPBuilder.getOrCreateRuntimeFunction(
5421 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
5422 if (Bitwidth == 64)
5423 return OMPBuilder.getOrCreateRuntimeFunction(
5424 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
5425 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5426}
5427
5428// Returns an LLVM function to call for initializing loop bounds using OpenMP
5429// static scheduling depending on `type`. Only i32 and i64 are supported by the
5430// runtime. Always interpret integers as unsigned similarly to
5431// CanonicalLoopInfo.
5432static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
5433 OpenMPIRBuilder &OMPBuilder) {
5434 unsigned Bitwidth = Ty->getIntegerBitWidth();
5435 if (Bitwidth == 32)
5436 return OMPBuilder.getOrCreateRuntimeFunction(
5437 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
5438 if (Bitwidth == 64)
5439 return OMPBuilder.getOrCreateRuntimeFunction(
5440 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
5441 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5442}
5443
/// Lower \p CLI to an OpenMP statically-scheduled (unchunked) worksharing
/// loop: bounds pointers are allocated at \p AllocaIP, the
/// __kmpc_(dist_)for_static_init runtime call rewrites them to this thread's
/// chunk, the loop's trip count and IV are retargeted to that chunk, and
/// __kmpc_for_static_fini is called in the exit block. \p CLI is invalidated;
/// the returned insert point is after the loop.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
    OMPScheduleType DistScheduleSchedType) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit =
      LoopType == WorksharingLoopType::DistributeForStaticLoop
          ? getKmpcDistForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this)
          : getKmpcForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());

  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(Ty: IVTy, V: 0);
  Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
  Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
  Value *UpperBound = Builder.CreateSub(LHS: CLI->getTripCount(), RHS: One);
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);

  OMPScheduleType SchedType =
      (LoopType == WorksharingLoopType::DistributeStaticLoop)
          ? OMPScheduleType::OrderedDistribute
          : OMPScheduleType::UnorderedStatic;
  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
                        PUpperBound, IVTy, PStride, One, Zero, StaticInit,
                        this](Value *SchedulingType, auto &Builder) {
    SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
                                   PLowerBound, PUpperBound});
    // The "dist" init variant takes an extra distribute-upper-bound pointer
    // between the upper-bound and stride arguments.
    if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
      Value *PDistUpperBound =
          Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
      Args.push_back(Elt: PDistUpperBound);
    }
    Args.append(IL: {PStride, One, Zero});
    createRuntimeFunctionCall(Callee: StaticInit, Args);
  };
  BuildInitCall(SchedulingType, Builder);
  if (HasDistSchedule &&
      LoopType != WorksharingLoopType::DistributeStaticLoop) {
    // NOTE(review): this local shadows the DistScheduleSchedType parameter,
    // which is consequently unused in this function — confirm that
    // hard-coding OrderedDistribute here is intended.
    Constant *DistScheduleSchedType = ConstantInt::get(
        Ty: I32Type, V: static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
    // We want to emit a second init function call for the dist_schedule clause
    // to the Distribute construct. This should only be done however if a
    // Workshare Loop is nested within a Distribute Construct
    BuildInitCall(DistScheduleSchedType, Builder);
  }
  // Reload the bounds written by "init" and shrink the loop's trip count to
  // this thread's chunk: trip-count = ub - lb + 1 (ub is inclusive).
  Value *LowerBound = Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(LHS: InclusiveUpperBound, RHS: LowerBound);
  Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar(Updater: [&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(TheBB: CLI->getBody(),
                           IP: CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    // Each thread's IV starts at its chunk's lower bound.
    return Builder.CreateAdd(LHS: OldIV, RHS: LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(TheBB: CLI->getExit(),
                         IP: CLI->getExit()->getTerminator()->getIterator());
  createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy BarrierIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
                      Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
5563
// Forward declarations of metadata helpers defined later in this file.
static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
                                   LoopInfo &LI);
static void addLoopMetadata(CanonicalLoopInfo *Loop,
                            ArrayRef<Metadata *> Properties);

/// Tag the memory accesses of \p Loop with a fresh, distinct access group and
/// append the matching "llvm.loop.parallel_accesses" property to
/// \p LoopMDList (to be attached to the loop's metadata by the caller).
static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
                                          LLVMContext &Ctx, Loop *Loop,
                                          LoopInfo &LoopInfo,
                                          SmallVector<Metadata *> &LoopMDList) {
  SmallSet<BasicBlock *, 8> Reachable;

  // Get the basic blocks from the loop in which memref instructions
  // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
  // preferably without running any passes.
  // The loop-control blocks (cond and header) are deliberately excluded.
  for (BasicBlock *Block : Loop->getBlocks()) {
    if (Block == CLI->getCond() || Block == CLI->getHeader())
      continue;
    Reachable.insert(Ptr: Block);
  }

  // Add access group metadata to memory-access instructions.
  MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
  for (BasicBlock *BB : Reachable)
    addAccessGroupMetadata(Block: BB, AccessGroup, LI&: LoopInfo);
  // TODO: If the loop has existing parallel access metadata, have
  // to combine two lists.
  LoopMDList.push_back(Elt: MDNode::get(
      Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
}
5594
/// Lower \p CLI to a statically-scheduled chunked worksharing loop: the
/// original loop becomes the "chunk" loop nested inside a newly created
/// "dispatch" loop that enumerates this thread's chunks as handed out by
/// __kmpc_for_static_init. \p CLI is invalidated; the returned insert point
/// is after the dispatch loop.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  // All runtime bookkeeping is done in i32 or i64, whichever fits the IV.
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(C&: Ctx)
                                                        : Type::getInt64Ty(C&: Ctx);
  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Constant *Zero = ConstantInt::get(Ty: InternalIVTy, V: 0);
  Constant *One = ConstantInt::get(Ty: InternalIVTy, V: 1);

  Function *F = CLI->getFunction();
  // Blocks must have terminators.
  // FIXME: Don't run analyses on incomplete/invalid IR.
  // Temporarily plug unterminated blocks with `unreachable` so LoopAnalysis
  // can run; the placeholders are erased right after.
  SmallVector<Instruction *> UIs;
  for (BasicBlock &BB : *F)
    if (!BB.hasTerminator())
      UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
  for (Instruction *I : UIs)
    I->eraseFromParent();
  Loop *L = LI.getLoopFor(BB: CLI->getHeader());
  SmallVector<Metadata *> LoopMDList;
  if (ChunkSize || DistScheduleChunkSize)
    applyParallelAccessesMetadata(CLI, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
  addLoopMetadata(Loop: CLI, Properties: LoopMDList);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(Ty: InternalIVTy, M, OMPBuilder&: *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(IP: AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
      V: ChunkSize ? ChunkSize : Zero, DestTy: InternalIVTy, Name: "chunksize");
  Value *CastedDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
      V: DistScheduleChunkSize ? DistScheduleChunkSize : Zero, DestTy: InternalIVTy,
      Name: "distschedulechunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(V: OrigTripCount, DestTy: InternalIVTy, Name: "tripcount");

  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
  Constant *DistSchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(DistScheduleSchedType));
  // Seed the bounds for "init": lb = 0, ub = tripcount - 1 (inclusive),
  // clamped to 0 when the trip count is zero, stride = 1.
  Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(LHS: CastedTripCount, RHS: One);
  Value *IsTripCountZero = Builder.CreateICmpEQ(LHS: CastedTripCount, RHS: Zero);
  Value *UpperBound =
      Builder.CreateSelect(C: IsTripCountZero, True: Zero, False: OrigUpperBound);
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
  auto BuildInitCall = [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound,
                        PUpperBound, PStride, One,
                        this](Value *SchedulingType, Value *ChunkSize,
                              auto &Builder) {
    createRuntimeFunctionCall(
        Callee: StaticInit, Args: {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                     /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                     /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                     /*pstride=*/PStride, /*incr=*/One,
                     /*chunk=*/ChunkSize});
  };
  BuildInitCall(SchedulingType, CastedChunkSize, Builder);
  if (DistScheduleSchedType != OMPScheduleType::None &&
      SchedType != OMPScheduleType::OrderedDistributeChunked &&
      SchedType != OMPScheduleType::OrderedDistribute) {
    // We want to emit a second init function call for the dist_schedule clause
    // to the Distribute construct. This should only be done however if a
    // Workshare Loop is nested within a Distribute Construct
    BuildInitCall(DistSchedulingType, CastedDistScheduleChunkSize, Builder);
  }

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PLowerBound, Name: "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PUpperBound, Name: "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(LHS: FirstChunkStop, RHS: One);
  Value *ChunkRange =
      Builder.CreateSub(LHS: FirstChunkEnd, RHS: FirstChunkStart, Name: "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PStride, Name: "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, CreateBranch: true);
  Value *DispatchCounter;

  // It is safe to assume this didn't return an error because the callback
  // passed into createCanonicalLoop is the only possible error source, and it
  // always returns success.
  CanonicalLoopInfo *DispatchCLI = cantFail(ValOrErr: createCanonicalLoop(
      Loc: {Builder.saveIP(), DL},
      BodyGenCB: [&](InsertPointTy BodyIP, Value *Counter) {
        DispatchCounter = Counter;
        return Error::success();
      },
      Start: FirstChunkStart, Stop: CastedTripCount, Step: NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      Name: "dispatch"));

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
  // not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(Source: DispatchAfter, Target: CLI->getAfter(), DL);
  redirectTo(Source: CLI->getExit(), Target: DispatchLatch, DL);
  redirectTo(Source: DispatchBody, Target: DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop: the full chunk range,
  // except for the last chunk which may be cut short by the trip count.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(LHS: DispatchCounter, RHS: ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(LHS: ChunkEnd, RHS: CastedTripCount, Name: "omp_chunk.is_last");
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(LHS: CastedTripCount, RHS: DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      C: IsLastChunk, True: CountUntilOrigTripCount, False: ChunkRange, Name: "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(V: ChunkTripCount, DestTy: IVTy, Name: "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(V: DispatchCounter, DestTy: IVTy, Name: "omp_dispatch.iv.trunc");
  CLI->mapIndVar(Updater: [&](Instruction *) -> Value * {
    Builder.restoreIP(IP: CLI->getBodyIP());
    // Each chunk's IV is offset by the chunk's start iteration.
    return Builder.CreateAdd(LHS: IV, RHS: BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(TheBB: DispatchExit, IP: DispatchExit->getFirstInsertionPt());
  createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL), Kind: OMPD_for,
                      /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
    if (!AfterIP)
      return AfterIP.takeError();
  }

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
}
5793
5794// Returns an LLVM function to call for executing an OpenMP static worksharing
5795// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
5796// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
5797static FunctionCallee
5798getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
5799 WorksharingLoopType LoopType) {
5800 unsigned Bitwidth = Ty->getIntegerBitWidth();
5801 Module &M = OMPBuilder->M;
5802 switch (LoopType) {
5803 case WorksharingLoopType::ForStaticLoop:
5804 if (Bitwidth == 32)
5805 return OMPBuilder->getOrCreateRuntimeFunction(
5806 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
5807 if (Bitwidth == 64)
5808 return OMPBuilder->getOrCreateRuntimeFunction(
5809 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
5810 break;
5811 case WorksharingLoopType::DistributeStaticLoop:
5812 if (Bitwidth == 32)
5813 return OMPBuilder->getOrCreateRuntimeFunction(
5814 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
5815 if (Bitwidth == 64)
5816 return OMPBuilder->getOrCreateRuntimeFunction(
5817 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
5818 break;
5819 case WorksharingLoopType::DistributeForStaticLoop:
5820 if (Bitwidth == 32)
5821 return OMPBuilder->getOrCreateRuntimeFunction(
5822 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
5823 if (Bitwidth == 64)
5824 return OMPBuilder->getOrCreateRuntimeFunction(
5825 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
5826 break;
5827 }
5828 if (Bitwidth != 32 && Bitwidth != 64) {
5829 llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
5830 }
5831 llvm_unreachable("Unknown type of OpenMP worksharing loop");
5832}
5833
5834// Inserts a call to proper OpenMP Device RTL function which handles
5835// loop worksharing.
5836static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
5837 WorksharingLoopType LoopType,
5838 BasicBlock *InsertBlock, Value *Ident,
5839 Value *LoopBodyArg, Value *TripCount,
5840 Function &LoopBodyFn, bool NoLoop) {
5841 Type *TripCountTy = TripCount->getType();
5842 Module &M = OMPBuilder->M;
5843 IRBuilder<> &Builder = OMPBuilder->Builder;
5844 FunctionCallee RTLFn =
5845 getKmpcForStaticLoopForType(Ty: TripCountTy, OMPBuilder, LoopType);
5846 SmallVector<Value *, 8> RealArgs;
5847 RealArgs.push_back(Elt: Ident);
5848 RealArgs.push_back(Elt: &LoopBodyFn);
5849 RealArgs.push_back(Elt: LoopBodyArg);
5850 RealArgs.push_back(Elt: TripCount);
5851 if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
5852 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
5853 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
5854 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
5855 OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
5856 return;
5857 }
5858 FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
5859 M, FnID: omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
5860 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
5861 Value *NumThreads = OMPBuilder->createRuntimeFunctionCall(Callee: RTLNumThreads, Args: {});
5862
5863 RealArgs.push_back(
5864 Elt: Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: TripCountTy, Name: "num.threads.cast"));
5865 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
5866 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
5867 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
5868 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: NoLoop));
5869 } else {
5870 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
5871 }
5872
5873 OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
5874}
5875
// Post-outlining callback for target-device workshare loops: once the loop
// body has been extracted into OutlinedFn, this replaces the (now redundant)
// canonical loop with a single call into the OpenMP device runtime, which
// takes over all loop-control logic on the device.
static void workshareLoopTargetCallback(
    OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
    Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
    WorksharingLoopType LoopType, bool NoLoop) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. Firstly, we need to move the setup of the loop body args into
  // the loop preheader (everything except the body's terminator).
  Preheader->splice(ToIt: std::prev(x: Preheader->end()), FromBB: CLI->getBody(),
                    FromBeginIt: CLI->getBody()->begin(), FromEndIt: std::prev(x: CLI->getBody()->end()));

  // The next step is to remove the whole loop. We do not need it anymore.
  // That's why we make an unconditional branch from the loop preheader to the
  // loop exit block.
  Builder.restoreIP(IP: {Preheader, Preheader->end()});
  Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(Dest: CLI->getExit());

  // Delete the dead loop blocks (everything between header and exit that is
  // now unreachable after the preheader was redirected).
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(BlockSet&: RegionBlockSet, BlockVector&: BlocksToBeRemoved);
  DeleteDeadBlocks(BBs: BlocksToBeRemoved);

  // Find the instruction which corresponds to the loop body argument
  // structure and remove the direct call to the loop body function; the
  // runtime will invoke OutlinedFn instead.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(Val: OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed; the runtime call
  // still needs a pointer argument, so pass null then.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(i: 1);
  else
    LoopBodyArg = Constant::getNullValue(Ty: Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  // Emit the call to the OpenMP device runtime that workshares TripCount
  // iterations over OutlinedFn.
  createTargetLoopWorkshareCall(OMPBuilder: OMPIRBuilder, LoopType, InsertBlock: Preheader, Ident,
                                LoopBodyArg, TripCount, LoopBodyFn&: OutlinedFn, NoLoop);

  // Drop the temporary loop-counter instructions and invalidate the CLI:
  // the canonical loop no longer exists.
  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
5932
5933OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
5934 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5935 WorksharingLoopType LoopType, bool NoLoop) {
5936 uint32_t SrcLocStrSize;
5937 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
5938 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5939
5940 OutlineInfo OI;
5941 OI.OuterAllocaBB = CLI->getPreheader();
5942 Function *OuterFn = CLI->getPreheader()->getParent();
5943
5944 // Instructions which need to be deleted at the end of code generation
5945 SmallVector<Instruction *, 4> ToBeDeleted;
5946
5947 OI.OuterAllocaBB = AllocaIP.getBlock();
5948
5949 // Mark the body loop as region which needs to be extracted
5950 OI.EntryBB = CLI->getBody();
5951 OI.ExitBB = CLI->getLatch()->splitBasicBlockBefore(I: CLI->getLatch()->begin(),
5952 BBName: "omp.prelatch");
5953
5954 // Prepare loop body for extraction
5955 Builder.restoreIP(IP: {CLI->getPreheader(), CLI->getPreheader()->begin()});
5956
5957 // Insert new loop counter variable which will be used only in loop
5958 // body.
5959 AllocaInst *NewLoopCnt = Builder.CreateAlloca(Ty: CLI->getIndVarType(), ArraySize: 0, Name: "");
5960 Instruction *NewLoopCntLoad =
5961 Builder.CreateLoad(Ty: CLI->getIndVarType(), Ptr: NewLoopCnt);
5962 // New loop counter instructions are redundant in the loop preheader when
5963 // code generation for workshare loop is finshed. That's why mark them as
5964 // ready for deletion.
5965 ToBeDeleted.push_back(Elt: NewLoopCntLoad);
5966 ToBeDeleted.push_back(Elt: NewLoopCnt);
5967
5968 // Analyse loop body region. Find all input variables which are used inside
5969 // loop body region.
5970 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
5971 SmallVector<BasicBlock *, 32> Blocks;
5972 OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
5973
5974 CodeExtractorAnalysisCache CEAC(*OuterFn);
5975 CodeExtractor Extractor(Blocks,
5976 /* DominatorTree */ nullptr,
5977 /* AggregateArgs */ true,
5978 /* BlockFrequencyInfo */ nullptr,
5979 /* BranchProbabilityInfo */ nullptr,
5980 /* AssumptionCache */ nullptr,
5981 /* AllowVarArgs */ true,
5982 /* AllowAlloca */ true,
5983 /* AllocationBlock */ CLI->getPreheader(),
5984 /* Suffix */ ".omp_wsloop",
5985 /* AggrArgsIn0AddrSpace */ true);
5986
5987 BasicBlock *CommonExit = nullptr;
5988 SetVector<Value *> SinkingCands, HoistingCands;
5989
5990 // Find allocas outside the loop body region which are used inside loop
5991 // body
5992 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
5993
5994 // We need to model loop body region as the function f(cnt, loop_arg).
5995 // That's why we replace loop induction variable by the new counter
5996 // which will be one of loop body function argument
5997 SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
5998 CLI->getIndVar()->user_end());
5999 for (auto Use : Users) {
6000 if (Instruction *Inst = dyn_cast<Instruction>(Val: Use)) {
6001 if (ParallelRegionBlockSet.count(Ptr: Inst->getParent())) {
6002 Inst->replaceUsesOfWith(From: CLI->getIndVar(), To: NewLoopCntLoad);
6003 }
6004 }
6005 }
6006 // Make sure that loop counter variable is not merged into loop body
6007 // function argument structure and it is passed as separate variable
6008 OI.ExcludeArgsFromAggregate.push_back(Elt: NewLoopCntLoad);
6009
6010 // PostOutline CB is invoked when loop body function is outlined and
6011 // loop body is replaced by call to outlined function. We need to add
6012 // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
6013 // function will handle loop control logic.
6014 //
6015 OI.PostOutlineCB = [=, ToBeDeletedVec =
6016 std::move(ToBeDeleted)](Function &OutlinedFn) {
6017 workshareLoopTargetCallback(OMPIRBuilder: this, CLI, Ident, OutlinedFn, ToBeDeleted: ToBeDeletedVec,
6018 LoopType, NoLoop);
6019 };
6020 addOutlineInfo(OI: std::move(OI));
6021 return CLI->getAfterIP();
6022}
6023
// Applies an OpenMP worksharing schedule to a canonical loop. Computes the
// effective schedule type from the clauses, then dispatches to the static,
// static-chunked, or dynamic lowering. Target devices take a separate path
// that outlines the loop body for the device runtime.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
    bool HasSimdModifier, bool HasMonotonicModifier,
    bool HasNonmonotonicModifier, bool HasOrderedClause,
    WorksharingLoopType LoopType, bool NoLoop, bool HasDistSchedule,
    Value *DistScheduleChunkSize) {
  if (Config.isTargetDevice())
    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop);
  OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
      ClauseKind: SchedKind, HasChunks: ChunkSize, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause, HasDistScheduleChunks: DistScheduleChunkSize);

  bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                   OMPScheduleType::ModifierOrdered;
  // dist_schedule uses its own schedule type, chosen by whether a
  // dist_schedule chunk size was given.
  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
  if (HasDistSchedule) {
    DistScheduleSchedType = DistScheduleChunkSize
                                ? OMPScheduleType::OrderedDistributeChunked
                                : OMPScheduleType::OrderedDistribute;
  }
  // Dispatch on the base schedule (modifier bits masked off).
  switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
  case OMPScheduleType::BaseStatic:
  case OMPScheduleType::BaseDistribute:
    // A schedule() chunk and a dist_schedule() chunk cannot both be present
    // on the unchunked-static path.
    assert((!ChunkSize || !DistScheduleChunkSize) &&
           "No chunk size with static-chunked schedule");
    // ordered(static) is lowered through the dynamic-dispatch runtime.
    if (IsOrdered && !HasDistSchedule)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
                                       NeedsBarrier, Chunk: ChunkSize);
    // FIXME: Monotonicity ignored?
    if (DistScheduleChunkSize)
      return applyStaticChunkedWorkshareLoop(
          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
          DistScheduleChunkSize, DistScheduleSchedType);
    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
                                    HasDistSchedule);

  case OMPScheduleType::BaseStaticChunked:
  case OMPScheduleType::BaseDistributeChunked:
    if (IsOrdered && !HasDistSchedule)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
                                       NeedsBarrier, Chunk: ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticChunkedWorkshareLoop(
        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
        DistScheduleChunkSize, DistScheduleSchedType);

  case OMPScheduleType::BaseRuntime:
  case OMPScheduleType::BaseAuto:
  case OMPScheduleType::BaseGreedy:
  case OMPScheduleType::BaseBalanced:
  case OMPScheduleType::BaseSteal:
  case OMPScheduleType::BaseRuntimeSimd:
    // These schedules take no chunk argument from the user; the chunked
    // dynamic variants below do.
    assert(!ChunkSize &&
           "schedule type does not support user-defined chunk sizes");
    [[fallthrough]];
  case OMPScheduleType::BaseGuidedSimd:
  case OMPScheduleType::BaseDynamicChunked:
  case OMPScheduleType::BaseGuidedChunked:
  case OMPScheduleType::BaseGuidedIterativeChunked:
  case OMPScheduleType::BaseGuidedAnalyticalChunked:
  case OMPScheduleType::BaseStaticBalancedChunked:
    return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
                                     NeedsBarrier, Chunk: ChunkSize);

  default:
    llvm_unreachable("Unknown/unimplemented schedule kind");
  }
}
6093
6094/// Returns an LLVM function to call for initializing loop bounds using OpenMP
6095/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6096/// the runtime. Always interpret integers as unsigned similarly to
6097/// CanonicalLoopInfo.
6098static FunctionCallee
6099getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6100 unsigned Bitwidth = Ty->getIntegerBitWidth();
6101 if (Bitwidth == 32)
6102 return OMPBuilder.getOrCreateRuntimeFunction(
6103 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
6104 if (Bitwidth == 64)
6105 return OMPBuilder.getOrCreateRuntimeFunction(
6106 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
6107 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6108}
6109
6110/// Returns an LLVM function to call for updating the next loop using OpenMP
6111/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6112/// the runtime. Always interpret integers as unsigned similarly to
6113/// CanonicalLoopInfo.
6114static FunctionCallee
6115getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6116 unsigned Bitwidth = Ty->getIntegerBitWidth();
6117 if (Bitwidth == 32)
6118 return OMPBuilder.getOrCreateRuntimeFunction(
6119 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
6120 if (Bitwidth == 64)
6121 return OMPBuilder.getOrCreateRuntimeFunction(
6122 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
6123 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6124}
6125
6126/// Returns an LLVM function to call for finalizing the dynamic loop using
6127/// depending on `type`. Only i32 and i64 are supported by the runtime. Always
6128/// interpret integers as unsigned similarly to CanonicalLoopInfo.
6129static FunctionCallee
6130getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6131 unsigned Bitwidth = Ty->getIntegerBitWidth();
6132 if (Bitwidth == 32)
6133 return OMPBuilder.getOrCreateRuntimeFunction(
6134 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
6135 if (Bitwidth == 64)
6136 return OMPBuilder.getOrCreateRuntimeFunction(
6137 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
6138 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6139}
6140
// Lowers a canonical loop to an OpenMP dynamically scheduled worksharing
// loop: wraps the existing loop in an outer dispatch loop that repeatedly
// asks the runtime (__kmpc_dispatch_next_*) for the next chunk of
// iterations until no work remains. The CLI is consumed (invalidated).
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                           InsertPointTy AllocaIP,
                                           OMPScheduleType SchedType,
                                           bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(Ty: IVTy, M, OMPBuilder&: *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(Ty: IVTy, M, OMPBuilder&: *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
  Builder.CreateStore(Val: One, Ptr: PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  // Capture the loop's blocks before the CFG surgery below.
  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));

  // Call the "init" function.
  createRuntimeFunctionCall(Callee: DynamicInit, Args: {SrcLoc, ThreadNum, SchedulingType,
                                                /* LowerBound */ One, UpperBound,
                                                /* step */ One, Chunk});

  // An outer loop around the existing one: each iteration dispatches one
  // chunk of the iteration space.
  BasicBlock *OuterCond = BasicBlock::Create(
      Context&: PreHeader->getContext(), Name: Twine(PreHeader->getName()) + ".outer.cond",
      Parent: PreHeader->getParent());
  // This needs to be 32-bit always, so can't use the IVTy Zero above.
  Builder.SetInsertPoint(TheBB: OuterCond, IP: OuterCond->getFirstInsertionPt());
  Value *Res = createRuntimeFunctionCall(
      Callee: DynamicNext,
      Args: {SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound, PStride});
  Constant *Zero32 = ConstantInt::get(Ty: I32Type, V: 0);
  // "next" returns non-zero while chunks remain.
  Value *MoreWork = Builder.CreateCmp(Pred: CmpInst::ICMP_NE, LHS: Res, RHS: Zero32);
  // Convert the runtime's 1-based inclusive lower bound back to the
  // 0-based counting of the canonical loop.
  Value *LowerBound =
      Builder.CreateSub(LHS: Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound), RHS: One, Name: "lb");
  Builder.CreateCondBr(Cond: MoreWork, True: Header, False: Exit);

  // Change PHI-node in loop header to use outer cond rather than preheader,
  // and set IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Val: Phi);
  PI->setIncomingBlock(i: 0, BB: OuterCond);
  PI->setIncomingValue(i: 0, V: LowerBound);

  // Then set the pre-header to jump to the OuterCond
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<UncondBrInst>(Val: Term);
  Br->setSuccessor(OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * jump to the loop outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(TheBB: Cond, IP: Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound, Name: "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Val: Comp);
  CI->setOperand(i_nocapture: 1, Val_nocapture: UpperBound);
  // Redirect the inner exit to branch to outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<CondBrInst>(Val: Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(idx: 1, NewSucc: OuterCond);

  // Call the "fini" function if "ordered" is present in wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(Ty: IVTy, M, OMPBuilder&: *this);
    createRuntimeFunctionCall(Callee: DynamicFini, Args: {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    InsertPointOrErrorTy BarrierIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
                      Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  CLI->invalidate();
  return AfterIP;
}
6273
6274/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
6275/// after this \p OldTarget will be orphaned.
6276static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
6277 BasicBlock *NewTarget, DebugLoc DL) {
6278 for (BasicBlock *Pred : make_early_inc_range(Range: predecessors(BB: OldTarget)))
6279 redirectTo(Source: Pred, Target: NewTarget, DL);
6280}
6281
6282/// Determine which blocks in \p BBs are reachable from outside and remove the
6283/// ones that are not reachable from the function.
6284static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
6285 SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
6286 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
6287 for (Use &U : BB->uses()) {
6288 auto *UseInst = dyn_cast<Instruction>(Val: U.getUser());
6289 if (!UseInst)
6290 continue;
6291 if (BBsToErase.count(Ptr: UseInst->getParent()))
6292 continue;
6293 return true;
6294 }
6295 return false;
6296 };
6297
6298 while (BBsToErase.remove_if(P: HasRemainingUses)) {
6299 // Try again if anything was removed.
6300 }
6301
6302 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
6303 DeleteDeadBlocks(BBs: BBVec);
6304}
6305
// Collapses a perfect (or near-perfect) loop nest into a single canonical
// loop whose trip count is the product of the input trip counts. The input
// loops' induction variables are recomputed from the collapsed one via a
// divmod scheme; the input CLIs are invalidated and the new CLI is returned.
CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(N: 6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(BBs&: OldControlBBs);

  // Setup the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(IP: ComputeIP);
  else
    Builder.restoreIP(IP: Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count as the product of the input
  // trip counts.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount =
        Builder.CreateNUWMul(LHS: CollapsedTripCount, RHS: OrigTripCount);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, TripCount: CollapsedTripCount, F,
                         PreInsertBefore: OrigPreheader->getNextNode(), PostInsertBefore: OrigAfter, Name: "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(IP: Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(N: NumLoops);
  // From innermost outwards: indvar[i] = Leftover % tripcount[i], then
  // divide the remainder out for the next-outer loop.
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(LHS: Leftover, RHS: OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(LHS: Leftover, RHS: OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following in direction of
  // the control flow, from the leading in-between code, the loop nest body, the
  // trailing in-between code, and rejoining the collapsed loop's latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
  // the ContinueBlock is set, continue with that block. If ContinuePred, use
  // its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(Source: ContinueBlock, Target: Dest, DL);
    else
      redirectAllPredecessorsTo(OldTarget: ContinuePred, NewTarget: Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than in the original loop. More sophisticated schemes could keep track of
  // what the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Source: Outermost->getPreheader(), Target: Result->getPreheader(), DL);
  redirectTo(Source: Result->getAfter(), Target: Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(V: NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(BBs: OldControlBBs);

  // The input loops no longer exist as canonical loops.
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
6433
6434std::vector<CanonicalLoopInfo *>
6435OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6436 ArrayRef<Value *> TileSizes) {
6437 assert(TileSizes.size() == Loops.size() &&
6438 "Must pass as many tile sizes as there are loops");
6439 int NumLoops = Loops.size();
6440 assert(NumLoops >= 1 && "At least one loop to tile required");
6441
6442 CanonicalLoopInfo *OutermostLoop = Loops.front();
6443 CanonicalLoopInfo *InnermostLoop = Loops.back();
6444 Function *F = OutermostLoop->getBody()->getParent();
6445 BasicBlock *InnerEnter = InnermostLoop->getBody();
6446 BasicBlock *InnerLatch = InnermostLoop->getLatch();
6447
6448 // Loop control blocks that may become orphaned later.
6449 SmallVector<BasicBlock *, 12> OldControlBBs;
6450 OldControlBBs.reserve(N: 6 * Loops.size());
6451 for (CanonicalLoopInfo *Loop : Loops)
6452 Loop->collectControlBlocks(BBs&: OldControlBBs);
6453
6454 // Collect original trip counts and induction variable to be accessible by
6455 // index. Also, the structure of the original loops is not preserved during
6456 // the construction of the tiled loops, so do it before we scavenge the BBs of
6457 // any original CanonicalLoopInfo.
6458 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
6459 for (CanonicalLoopInfo *L : Loops) {
6460 assert(L->isValid() && "All input loops must be valid canonical loops");
6461 OrigTripCounts.push_back(Elt: L->getTripCount());
6462 OrigIndVars.push_back(Elt: L->getIndVar());
6463 }
6464
6465 // Collect the code between loop headers. These may contain SSA definitions
6466 // that are used in the loop nest body. To be usable with in the innermost
6467 // body, these BasicBlocks will be sunk into the loop nest body. That is,
6468 // these instructions may be executed more often than before the tiling.
6469 // TODO: It would be sufficient to only sink them into body of the
6470 // corresponding tile loop.
6471 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
6472 for (int i = 0; i < NumLoops - 1; ++i) {
6473 CanonicalLoopInfo *Surrounding = Loops[i];
6474 CanonicalLoopInfo *Nested = Loops[i + 1];
6475
6476 BasicBlock *EnterBB = Surrounding->getBody();
6477 BasicBlock *ExitBB = Nested->getHeader();
6478 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
6479 }
6480
6481 // Compute the trip counts of the floor loops.
6482 Builder.SetCurrentDebugLocation(DL);
6483 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
6484 SmallVector<Value *, 4> FloorCompleteCount, FloorCount, FloorRems;
6485 for (int i = 0; i < NumLoops; ++i) {
6486 Value *TileSize = TileSizes[i];
6487 Value *OrigTripCount = OrigTripCounts[i];
6488 Type *IVType = OrigTripCount->getType();
6489
6490 Value *FloorCompleteTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
6491 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
6492
6493 // 0 if tripcount divides the tilesize, 1 otherwise.
6494 // 1 means we need an additional iteration for a partial tile.
6495 //
6496 // Unfortunately we cannot just use the roundup-formula
6497 // (tripcount + tilesize - 1)/tilesize
6498 // because the summation might overflow. We do not want introduce undefined
6499 // behavior when the untiled loop nest did not.
6500 Value *FloorTripOverflow =
6501 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
6502
6503 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
6504 Value *FloorTripCount =
6505 Builder.CreateAdd(LHS: FloorCompleteTripCount, RHS: FloorTripOverflow,
6506 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
6507
6508 // Remember some values for later use.
6509 FloorCompleteCount.push_back(Elt: FloorCompleteTripCount);
6510 FloorCount.push_back(Elt: FloorTripCount);
6511 FloorRems.push_back(Elt: FloorTripRem);
6512 }
6513
6514 // Generate the new loop nest, from the outermost to the innermost.
6515 std::vector<CanonicalLoopInfo *> Result;
6516 Result.reserve(n: NumLoops * 2);
6517
6518 // The basic block of the surrounding loop that enters the nest generated
6519 // loop.
6520 BasicBlock *Enter = OutermostLoop->getPreheader();
6521
6522 // The basic block of the surrounding loop where the inner code should
6523 // continue.
6524 BasicBlock *Continue = OutermostLoop->getAfter();
6525
6526 // Where the next loop basic block should be inserted.
6527 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
6528
6529 auto EmbeddNewLoop =
6530 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
6531 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
6532 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
6533 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
6534 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
6535 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
6536
6537 // Setup the position where the next embedded loop connects to this loop.
6538 Enter = EmbeddedLoop->getBody();
6539 Continue = EmbeddedLoop->getLatch();
6540 OutroInsertBefore = EmbeddedLoop->getLatch();
6541 return EmbeddedLoop;
6542 };
6543
6544 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
6545 const Twine &NameBase) {
6546 for (auto P : enumerate(First&: TripCounts)) {
6547 CanonicalLoopInfo *EmbeddedLoop =
6548 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
6549 Result.push_back(x: EmbeddedLoop);
6550 }
6551 };
6552
6553 EmbeddNewLoops(FloorCount, "floor");
6554
6555 // Within the innermost floor loop, emit the code that computes the tile
6556 // sizes.
6557 Builder.SetInsertPoint(Enter->getTerminator());
6558 SmallVector<Value *, 4> TileCounts;
6559 for (int i = 0; i < NumLoops; ++i) {
6560 CanonicalLoopInfo *FloorLoop = Result[i];
6561 Value *TileSize = TileSizes[i];
6562
6563 Value *FloorIsEpilogue =
6564 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorCompleteCount[i]);
6565 Value *TileTripCount =
6566 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
6567
6568 TileCounts.push_back(Elt: TileTripCount);
6569 }
6570
6571 // Create the tile loops.
6572 EmbeddNewLoops(TileCounts, "tile");
6573
6574 // Insert the inbetween code into the body.
6575 BasicBlock *BodyEnter = Enter;
6576 BasicBlock *BodyEntered = nullptr;
6577 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
6578 BasicBlock *EnterBB = P.first;
6579 BasicBlock *ExitBB = P.second;
6580
6581 if (BodyEnter)
6582 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
6583 else
6584 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
6585
6586 BodyEnter = nullptr;
6587 BodyEntered = ExitBB;
6588 }
6589
6590 // Append the original loop nest body into the generated loop nest body.
6591 if (BodyEnter)
6592 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
6593 else
6594 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
6595 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
6596
6597 // Replace the original induction variable with an induction variable computed
6598 // from the tile and floor induction variables.
6599 Builder.restoreIP(IP: Result.back()->getBodyIP());
6600 for (int i = 0; i < NumLoops; ++i) {
6601 CanonicalLoopInfo *FloorLoop = Result[i];
6602 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
6603 Value *OrigIndVar = OrigIndVars[i];
6604 Value *Size = TileSizes[i];
6605
6606 Value *Scale =
6607 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6608 Value *Shift =
6609 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6610 OrigIndVar->replaceAllUsesWith(V: Shift);
6611 }
6612
6613 // Remove unused parts of the original loops.
6614 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6615
6616 for (CanonicalLoopInfo *L : Loops)
6617 L->invalidate();
6618
6619#ifndef NDEBUG
6620 for (CanonicalLoopInfo *GenL : Result)
6621 GenL->assertOK();
6622#endif
6623 return Result;
6624}
6625
6626/// Attach metadata \p Properties to the basic block described by \p BB. If the
6627/// basic block already has metadata, the basic block properties are appended.
6628static void addBasicBlockMetadata(BasicBlock *BB,
6629 ArrayRef<Metadata *> Properties) {
6630 // Nothing to do if no property to attach.
6631 if (Properties.empty())
6632 return;
6633
6634 LLVMContext &Ctx = BB->getContext();
6635 SmallVector<Metadata *> NewProperties;
6636 NewProperties.push_back(Elt: nullptr);
6637
6638 // If the basic block already has metadata, prepend it to the new metadata.
6639 MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
6640 if (Existing)
6641 append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));
6642
6643 append_range(C&: NewProperties, R&: Properties);
6644 MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
6645 BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);
6646
6647 BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
6648}
6649
6650/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
6651/// loop already has metadata, the loop properties are appended.
6652static void addLoopMetadata(CanonicalLoopInfo *Loop,
6653 ArrayRef<Metadata *> Properties) {
6654 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
6655
6656 // Attach metadata to the loop's latch
6657 BasicBlock *Latch = Loop->getLatch();
6658 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
6659 addBasicBlockMetadata(BB: Latch, Properties);
6660}
6661
6662/// Attach llvm.access.group metadata to the memref instructions of \p Block
6663static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
6664 LoopInfo &LI) {
6665 for (Instruction &I : *Block) {
6666 if (I.mayReadOrWriteMemory()) {
6667 // TODO: This instruction may already have access group from
6668 // other pragmas e.g. #pragma clang loop vectorize. Append
6669 // so that the existing metadata is not overwritten.
6670 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
6671 }
6672 }
6673}
6674
/// Fuse the given sibling canonical loops into a single canonical loop.
///
/// The fused loop iterates from 0 to the maximum of the original trip counts;
/// each original loop body is guarded by a comparison of the fused induction
/// variable against that loop's original trip count, so loops with smaller
/// trip counts simply skip their body on the extra iterations. All input
/// CanonicalLoopInfos are invalidated; the returned CanonicalLoopInfo
/// describes the fused loop.
CanonicalLoopInfo *
OpenMPIRBuilder::fuseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops) {
  CanonicalLoopInfo *firstLoop = Loops.front();
  CanonicalLoopInfo *lastLoop = Loops.back();
  // All loops live in the same function; take it from the first preheader.
  Function *F = firstLoop->getPreheader()->getParent();

  // Loop control blocks that will become orphaned later
  SmallVector<BasicBlock *> oldControlBBs;
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(BBs&: oldControlBBs);

  // Collect original trip counts
  SmallVector<Value *> origTripCounts;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    origTripCounts.push_back(Elt: L->getTripCount());
  }

  Builder.SetCurrentDebugLocation(DL);

  // Compute max trip count.
  // The fused loop will be from 0 to max(origTripCounts)
  // NOTE(review): the maximum is computed with a signed compare (ICmpSGT /
  // ICmpSLT below) — presumably trip counts are non-negative here; confirm
  // against the callers.
  BasicBlock *TCBlock = BasicBlock::Create(Context&: F->getContext(), Name: "omp.fuse.comp.tc",
                                           Parent: F, InsertBefore: firstLoop->getHeader());
  Builder.SetInsertPoint(TCBlock);
  Value *fusedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All loops to fuse must be valid canonical loops");
    Value *origTripCount = L->getTripCount();
    // First loop seeds the running maximum; no select needed.
    if (!fusedTripCount) {
      fusedTripCount = origTripCount;
      continue;
    }
    Value *condTP = Builder.CreateICmpSGT(LHS: fusedTripCount, RHS: origTripCount);
    fusedTripCount = Builder.CreateSelect(C: condTP, True: fusedTripCount, False: origTripCount,
                                          Name: ".omp.fuse.tc");
  }

  // Generate new loop
  CanonicalLoopInfo *fused =
      createLoopSkeleton(DL, TripCount: fusedTripCount, F, PreInsertBefore: firstLoop->getBody(),
                         PostInsertBefore: lastLoop->getLatch(), Name: "fused");

  // Replace original loops with the fused loop
  // Preheader and After are not considered inside the CLI.
  // These are used to compute the individual TCs of the loops
  // so they have to be put before the resulting fused loop.
  // Moving them up for readability.
  for (size_t i = 0; i < Loops.size() - 1; ++i) {
    Loops[i]->getPreheader()->moveBefore(MovePos: TCBlock);
    Loops[i]->getAfter()->moveBefore(MovePos: TCBlock);
  }
  lastLoop->getPreheader()->moveBefore(MovePos: TCBlock);

  // Chain preheader -> after -> next preheader so any trip-count computation
  // in those blocks still executes, then fall through into TCBlock and the
  // fused loop itself. The fused loop exits into the last loop's after block.
  for (size_t i = 0; i < Loops.size() - 1; ++i) {
    redirectTo(Source: Loops[i]->getPreheader(), Target: Loops[i]->getAfter(), DL);
    redirectTo(Source: Loops[i]->getAfter(), Target: Loops[i + 1]->getPreheader(), DL);
  }
  redirectTo(Source: lastLoop->getPreheader(), Target: TCBlock, DL);
  redirectTo(Source: TCBlock, Target: fused->getPreheader(), DL);
  redirectTo(Source: fused->getAfter(), Target: lastLoop->getAfter(), DL);

  // Build the fused body
  // Create new Blocks with conditions that jump to the original loop bodies
  SmallVector<BasicBlock *> condBBs;
  SmallVector<Value *> condValues;
  for (size_t i = 0; i < Loops.size(); ++i) {
    BasicBlock *condBlock = BasicBlock::Create(
        Context&: F->getContext(), Name: "omp.fused.inner.cond", Parent: F, InsertBefore: Loops[i]->getBody());
    Builder.SetInsertPoint(condBlock);
    // Guard: run body i only while the fused IV is below loop i's trip count.
    Value *condValue =
        Builder.CreateICmpSLT(LHS: fused->getIndVar(), RHS: origTripCounts[i]);
    condBBs.push_back(Elt: condBlock);
    condValues.push_back(Elt: condValue);
  }
  // Join the condition blocks with the bodies of the original loops
  redirectTo(Source: fused->getBody(), Target: condBBs[0], DL);
  for (size_t i = 0; i < Loops.size() - 1; ++i) {
    Builder.SetInsertPoint(condBBs[i]);
    // Whether body i runs or not, control continues with the next guard.
    Builder.CreateCondBr(Cond: condValues[i], True: Loops[i]->getBody(), False: condBBs[i + 1]);
    redirectAllPredecessorsTo(OldTarget: Loops[i]->getLatch(), NewTarget: condBBs[i + 1], DL);
    // Replace the IV with the fused IV
    Loops[i]->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());
  }
  // Last body jumps to the created end body block
  Builder.SetInsertPoint(condBBs.back());
  Builder.CreateCondBr(Cond: condValues.back(), True: lastLoop->getBody(),
                       False: fused->getLatch());
  redirectAllPredecessorsTo(OldTarget: lastLoop->getLatch(), NewTarget: fused->getLatch(), DL);
  // Replace the IV with the fused IV
  lastLoop->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());

  // The loop latch must have only one predecessor. Currently it is branched to
  // from both the last condition block and the last loop body
  fused->getLatch()->splitBasicBlockBefore(I: fused->getLatch()->begin(),
                                           BBName: "omp.fused.pre_latch");

  // Remove unused parts
  removeUnusedBlocksFromParent(BBs: oldControlBBs);

  // Invalidate old CLIs
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  fused->assertOK();
#endif
  return fused;
}
6784
6785void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
6786 LLVMContext &Ctx = Builder.getContext();
6787 addLoopMetadata(
6788 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6789 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
6790}
6791
6792void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
6793 LLVMContext &Ctx = Builder.getContext();
6794 addLoopMetadata(
6795 Loop, Properties: {
6796 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6797 });
6798}
6799
/// Version the body of \p CanonicalLoop on \p IfCond: inside the (single)
/// loop, branch between the original (to-be-vectorized) body and a cloned
/// copy of it, so the CanonicalLoopInfo still describes exactly one loop.
/// The clone mapping is recorded in \p VMap; \p L is the llvm::Loop that
/// corresponds to \p CanonicalLoop and is updated via \p LI.
void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
                                      Value *IfCond, ValueToValueMapTy &VMap,
                                      LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
                                      const Twine &NamePrefix) {
  Function *F = CanonicalLoop->getFunction();

  // We can't do
  // if (cond) {
  //   simd_loop;
  // } else {
  //   non_simd_loop;
  // }
  // because then the CanonicalLoopInfo would only point to one of the loops:
  // leading to other constructs operating on the same loop to malfunction.
  // Instead generate
  // while (...) {
  //   if (cond) {
  //     simd_body;
  //   } else {
  //     not_simd_body;
  //   }
  // }
  // At least for simple loops, LLVM seems able to hoist the if out of the loop
  // body at -O3

  // Define where if branch should be inserted
  auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();

  // Create additional blocks for the if statement
  BasicBlock *Cond = SplitBeforeIt->getParent();
  llvm::LLVMContext &C = Cond->getContext();
  llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
      Context&: C, Name: NamePrefix + ".if.then", Parent: Cond->getParent(), InsertBefore: Cond->getNextNode());
  llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
      Context&: C, Name: NamePrefix + ".if.else", Parent: Cond->getParent(), InsertBefore: CanonicalLoop->getExit());

  // Create if condition branch.
  Builder.SetInsertPoint(SplitBeforeIt);
  Instruction *BrInstr =
      Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
  InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
  // Then block contains branch to omp loop body which needs to be vectorized
  spliceBB(IP, New: ThenBlock, CreateBranch: false, DL: Builder.getCurrentDebugLocation());
  ThenBlock->replaceSuccessorsPhiUsesWith(Old: Cond, New: ThenBlock);

  Builder.SetInsertPoint(ElseBlock);

  // Clone loop for the else branch
  SmallVector<BasicBlock *, 8> NewBlocks;

  SmallVector<BasicBlock *, 8> ExistingBlocks;
  ExistingBlocks.reserve(N: L->getNumBlocks() + 1);
  ExistingBlocks.push_back(Elt: ThenBlock);
  ExistingBlocks.append(in_start: L->block_begin(), in_end: L->block_end());
  // Cond is the block that has the if clause condition
  // LoopCond is omp_loop.cond
  // LoopHeader is omp_loop.header
  // NOTE(review): LoopCond is dereferenced on the next line before the assert
  // below checks it for null — if Cond ever has multiple predecessors this
  // crashes before the assert fires; consider asserting LoopCond first.
  BasicBlock *LoopCond = Cond->getUniquePredecessor();
  BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
  assert(LoopCond && LoopHeader && "Invalid loop structure");
  for (BasicBlock *Block : ExistingBlocks) {
    // Skip the loop's control skeleton; only the body blocks are duplicated.
    if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
        Block == LoopHeader || Block == LoopCond || Block == Cond) {
      continue;
    }
    BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);

    // fix name not to be omp.if.then
    if (Block == ThenBlock)
      NewBB->setName(NamePrefix + ".if.else");

    NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
    VMap[Block] = NewBB;
    NewBlocks.push_back(Elt: NewBB);
  }
  // Rewrite operands in the cloned blocks to their cloned counterparts.
  remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
  Builder.CreateBr(Dest: NewBlocks.front());

  // The loop latch must have only one predecessor. Currently it is branched to
  // from both the 'then' and 'else' branches.
  L->getLoopLatch()->splitBasicBlockBefore(I: L->getLoopLatch()->begin(),
                                           BBName: NamePrefix + ".pre_latch");

  // Ensure that the then block is added to the loop so we add the attributes in
  // the next step
  L->addBasicBlockToLoop(NewBB: ThenBlock, LI);
}
6887
6888unsigned
6889OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
6890 const StringMap<bool> &Features) {
6891 if (TargetTriple.isX86()) {
6892 if (Features.lookup(Key: "avx512f"))
6893 return 512;
6894 else if (Features.lookup(Key: "avx"))
6895 return 256;
6896 return 128;
6897 }
6898 if (TargetTriple.isPPC())
6899 return 128;
6900 if (TargetTriple.isWasm())
6901 return 128;
6902 return 0;
6903}
6904
/// Annotate the loop described by \p CanonicalLoop for vectorization: emit
/// alignment assumptions for \p AlignedVars, optionally version the body on
/// \p IfCond, and attach llvm.loop vectorization metadata derived from
/// \p Order, \p Simdlen and \p Safelen.
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                MapVector<Value *, Value *> AlignedVars,
                                Value *IfCond, OrderKind Order,
                                ConstantInt *Simdlen, ConstantInt *Safelen) {
  LLVMContext &Ctx = Builder.getContext();

  Function *F = CanonicalLoop->getFunction();

  // Blocks must have terminators.
  // FIXME: Don't run analyses on incomplete/invalid IR.
  SmallVector<Instruction *> UIs;
  for (BasicBlock &BB : *F)
    if (!BB.hasTerminator())
      UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));

  // TODO: We should not rely on pass manager. Currently we use pass manager
  // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
  // object. We should have a method which returns all blocks between
  // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });

  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);

  // Remove the placeholder terminators inserted above.
  for (Instruction *I : UIs)
    I->eraseFromParent();

  Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
  if (AlignedVars.size()) {
    InsertPointTy IP = Builder.saveIP();
    for (auto &AlignedItem : AlignedVars) {
      Value *AlignedPtr = AlignedItem.first;
      Value *Alignment = AlignedItem.second;
      // NOTE(review): dyn_cast returns null if AlignedPtr is not an
      // Instruction (e.g. a function Argument); the unconditional deref on
      // the next line would then crash. Presumably the frontend always passes
      // a load of the pointer variable here — confirm against callers.
      Instruction *loadInst = dyn_cast<Instruction>(Val: AlignedPtr);
      Builder.SetInsertPoint(loadInst->getNextNode());
      Builder.CreateAlignmentAssumption(DL: F->getDataLayout(), PtrValue: AlignedPtr,
                                        Alignment);
    }
    Builder.restoreIP(IP);
  }

  if (IfCond) {
    ValueToValueMapTy VMap;
    createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, NamePrefix: "simd");
  }

  SmallPtrSet<BasicBlock *, 8> Reachable;

  // Get the basic blocks from the loop in which memref instructions
  // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
  // preferably without running any passes.
  for (BasicBlock *Block : L->getBlocks()) {
    if (Block == CanonicalLoop->getCond() ||
        Block == CanonicalLoop->getHeader())
      continue;
    Reachable.insert(Ptr: Block);
  }

  SmallVector<Metadata *> LoopMDList;

  // In presence of finite 'safelen', it may be unsafe to mark all
  // the memory instructions parallel, because loop-carried
  // dependences of 'safelen' iterations are possible.
  // If clause order(concurrent) is specified then the memory instructions
  // are marked parallel even if 'safelen' is finite.
  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
    applyParallelAccessesMetadata(CLI: CanonicalLoop, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);

  // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
  // versions so we can't add the loop attributes in that case.
  if (IfCond) {
    // we can still add llvm.loop.parallel_access
    addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
    return;
  }

  // Use the above access group metadata to create loop level
  // metadata, which should be distinct for each loop.
  ConstantAsMetadata *BoolConst =
      ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
  LoopMDList.push_back(Elt: MDNode::get(
      Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));

  if (Simdlen || Safelen) {
    // If both simdlen and safelen clauses are specified, the value of the
    // simdlen parameter must be less than or equal to the value of the safelen
    // parameter. Therefore, use safelen only in the absence of simdlen.
    ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
    LoopMDList.push_back(
        Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
                          ConstantAsMetadata::get(C: VectorizeWidth)}));
  }

  addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
}
7004
7005/// Create the TargetMachine object to query the backend for optimization
7006/// preferences.
7007///
7008/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
7009/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
7010/// needed for the LLVM pass pipline. We use some default options to avoid
7011/// having to pass too many settings from the frontend that probably do not
7012/// matter.
7013///
7014/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
7015/// method. If we are going to use TargetMachine for more purposes, especially
7016/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
7017/// might become be worth requiring front-ends to pass on their TargetMachine,
7018/// or at least cache it between methods. Note that while fontends such as Clang
7019/// have just a single main TargetMachine per translation unit, "target-cpu" and
7020/// "target-features" that determine the TargetMachine are per-function and can
7021/// be overrided using __attribute__((target("OPTIONS"))).
7022static std::unique_ptr<TargetMachine>
7023createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
7024 Module *M = F->getParent();
7025
7026 StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
7027 StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
7028 const llvm::Triple &Triple = M->getTargetTriple();
7029
7030 std::string Error;
7031 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(TheTriple: Triple, Error);
7032 if (!TheTarget)
7033 return {};
7034
7035 llvm::TargetOptions Options;
7036 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
7037 TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
7038 /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
7039}
7040
/// Heuristically determine the best-performant unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
///
/// Returns 1 to signal that the loop should not be unrolled.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Blocks must have terminators.
  // FIXME: Don't run analyses on incomplete/invalid IR.
  SmallVector<Instruction *> UIs;
  for (BasicBlock &BB : *F)
    if (!BB.hasTerminator())
      UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));

  // Register every analysis the unrolling heuristics below depend on.
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
  // Use target-specific cost modelling when a TargetMachine is available;
  // otherwise TIRA falls back to the default (target-independent) TTI.
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass(PassBuilder: [&]() { return TIRA; });

  TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(F&: *F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(F&: *F, FAM);
  OptimizationRemarkEmitter ORE{F};

  // Remove the placeholder terminators inserted above.
  for (Instruction *I : UIs)
    I->eraseFromParent();

  Loop *L = LI.getLoopFor(BB: CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Start from the same preferences the LoopUnrollPass would use.
  TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
      L, SE, TTI,
      /*BlockFrequencyInfo=*/BFI: nullptr,
      /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
      /*UserThreshold=*/std::nullopt,
      /*UserCount=*/std::nullopt,
      /*UserAllowPartial=*/true,
      /*UserAllowRuntime=*/UserRuntime: true,
      /*UserUpperBound=*/std::nullopt,
      /*UserFullUnrollMaxCount=*/std::nullopt);

  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      // Only entry-block allocas are candidates for Mem2Reg/SROA.
      if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(Ptr: &I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // Loop is not unrollable if the loop contains certain instructions.
  if (!UCE.canUnroll()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
7179
7180void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
7181 int32_t Factor,
7182 CanonicalLoopInfo **UnrolledCLI) {
7183 assert(Factor >= 0 && "Unroll factor must not be negative");
7184
7185 Function *F = Loop->getFunction();
7186 LLVMContext &Ctx = F->getContext();
7187
7188 // If the unrolled loop is not used for another loop-associated directive, it
7189 // is sufficient to add metadata for the LoopUnrollPass.
7190 if (!UnrolledCLI) {
7191 SmallVector<Metadata *, 2> LoopMetadata;
7192 LoopMetadata.push_back(
7193 Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));
7194
7195 if (Factor >= 1) {
7196 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
7197 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
7198 LoopMetadata.push_back(Elt: MDNode::get(
7199 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
7200 }
7201
7202 addLoopMetadata(Loop, Properties: LoopMetadata);
7203 return;
7204 }
7205
7206 // Heuristically determine the unroll factor.
7207 if (Factor == 0)
7208 Factor = computeHeuristicUnrollFactor(CLI: Loop);
7209
7210 // No change required with unroll factor 1.
7211 if (Factor == 1) {
7212 *UnrolledCLI = Loop;
7213 return;
7214 }
7215
7216 assert(Factor >= 2 &&
7217 "unrolling only makes sense with a factor of 2 or larger");
7218
7219 Type *IndVarTy = Loop->getIndVarType();
7220
7221 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
7222 // unroll the inner loop.
7223 Value *FactorVal =
7224 ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
7225 /*isSigned=*/false));
7226 std::vector<CanonicalLoopInfo *> LoopNest =
7227 tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
7228 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
7229 *UnrolledCLI = LoopNest[0];
7230 CanonicalLoopInfo *InnerLoop = LoopNest[1];
7231
7232 // LoopUnrollPass can only fully unroll loops with constant trip count.
7233 // Unroll by the unroll factor with a fallback epilog for the remainder
7234 // iterations if necessary.
7235 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
7236 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
7237 addLoopMetadata(
7238 Loop: InnerLoop,
7239 Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
7240 MDNode::get(
7241 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});
7242
7243#ifndef NDEBUG
7244 (*UnrolledCLI)->assertOK();
7245#endif
7246}
7247
7248OpenMPIRBuilder::InsertPointTy
7249OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
7250 llvm::Value *BufSize, llvm::Value *CpyBuf,
7251 llvm::Value *CpyFn, llvm::Value *DidIt) {
7252 if (!updateToLocation(Loc))
7253 return Loc.IP;
7254
7255 uint32_t SrcLocStrSize;
7256 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7257 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7258 Value *ThreadId = getOrCreateThreadID(Ident);
7259
7260 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
7261
7262 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
7263
7264 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
7265 createRuntimeFunctionCall(Callee: Fn, Args);
7266
7267 return Builder.saveIP();
7268}
7269
// Emit an `omp single` construct: the body is executed by exactly one thread
// of the team. If copyprivate variables are supplied, their values are
// broadcast to the other threads via __kmpc_copyprivate; otherwise a barrier
// is emitted unless `IsNowait` is set.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
    ArrayRef<llvm::Function *> CPFuncs) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // If needed allocate and initialize `DidIt` with 0.
  // DidIt: flag variable: 1=single thread; 0=not single thread.
  llvm::Value *DidIt = nullptr;
  if (!CPVars.empty()) {
    DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
    Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
  }

  Directive OMPD = Directive::OMPD_single;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  // __kmpc_single returns non-zero on the thread chosen to run the region;
  // EmitOMPInlinedRegion turns that into the conditional branch below.
  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
  Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
  Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);

  // Wrap the user's finalization so the single thread also records that it
  // executed the region (needed by the copyprivate calls emitted below).
  auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
    if (Error Err = FiniCB(IP))
      return Err;

    // The thread that executes the single region must set `DidIt` to 1.
    // This is used by __kmpc_copyprivate, to know if the caller is the
    // single thread or not.
    if (DidIt)
      Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);

    return Error::success();
  };

  // generates the following:
  // if (__kmpc_single()) {
  //		.... single region ...
  // 		__kmpc_end_single
  // }
  // __kmpc_copyprivate
  // __kmpc_barrier

  InsertPointOrErrorTy AfterIP =
      EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
                           /*Conditional*/ true,
                           /*hasFinalize*/ HasFinalize: true);
  if (!AfterIP)
    return AfterIP.takeError();

  if (DidIt) {
    // One copyprivate call per variable; each pairs the value with the
    // caller-provided copy function.
    for (size_t I = 0, E = CPVars.size(); I < E; ++I)
      // NOTE BufSize is currently unused, so just pass 0.
      createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
                        CpyFn: CPFuncs[I], DidIt);
    // NOTE __kmpc_copyprivate already inserts a barrier
  } else if (!IsNowait) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                      Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!AfterIP)
      return AfterIP.takeError();
  }
  return Builder.saveIP();
}
7344
7345OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
7346 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7347 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
7348
7349 if (!updateToLocation(Loc))
7350 return Loc.IP;
7351
7352 Directive OMPD = Directive::OMPD_critical;
7353 uint32_t SrcLocStrSize;
7354 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7355 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7356 Value *ThreadId = getOrCreateThreadID(Ident);
7357 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
7358 Value *Args[] = {Ident, ThreadId, LockVar};
7359
7360 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
7361 Function *RTFn = nullptr;
7362 if (HintInst) {
7363 // Add Hint to entry Args and create call
7364 EnterArgs.push_back(Elt: HintInst);
7365 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
7366 } else {
7367 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
7368 }
7369 Instruction *EntryCall = createRuntimeFunctionCall(Callee: RTFn, Args: EnterArgs);
7370
7371 Function *ExitRTLFn =
7372 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
7373 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7374
7375 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7376 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7377}
7378
// Emit an `ordered depend(source)` / `ordered depend(sink)` construct: the
// per-loop iteration indices are materialized into an i64 vector on the
// stack and handed to __kmpc_doacross_post (source) or __kmpc_doacross_wait
// (sink).
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
                                     InsertPointTy AllocaIP, unsigned NumLoops,
                                     ArrayRef<llvm::Value *> StoreValues,
                                     const Twine &Name, bool IsDependSource) {
  assert(
      llvm::all_of(StoreValues,
                   [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
      "OpenMP runtime requires depend vec with i64 type");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Allocate space for vector and generate alloc instruction.
  // The alloca is placed at AllocaIP (the function-entry alloca point), then
  // the builder returns to the directive's location for the stores/call.
  auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
  ArgsBase->setAlignment(Align(8));
  updateToLocation(Loc);

  // Store the index value with offset in depend vector.
  for (unsigned I = 0; I < NumLoops; ++I) {
    Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
        Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
    StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
    STInst->setAlignment(Align(8));
  }

  // The runtime takes a pointer to the first element of the vector.
  Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
      Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};

  // depend(source) posts the iteration; depend(sink) waits for it.
  Function *RTLFn = nullptr;
  if (IsDependSource)
    RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
  createRuntimeFunctionCall(Callee: RTLFn, Args);

  return Builder.saveIP();
}
7425
7426OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
7427 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7428 FinalizeCallbackTy FiniCB, bool IsThreads) {
7429 if (!updateToLocation(Loc))
7430 return Loc.IP;
7431
7432 Directive OMPD = Directive::OMPD_ordered;
7433 Instruction *EntryCall = nullptr;
7434 Instruction *ExitCall = nullptr;
7435
7436 if (IsThreads) {
7437 uint32_t SrcLocStrSize;
7438 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7439 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7440 Value *ThreadId = getOrCreateThreadID(Ident);
7441 Value *Args[] = {Ident, ThreadId};
7442
7443 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
7444 EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7445
7446 Function *ExitRTLFn =
7447 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
7448 ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7449 }
7450
7451 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7452 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7453}
7454
// Emit the CFG skeleton for an inlined directive region:
//   entry -> [conditional branch on EntryCall] -> body -> finalize -> exit
// The body is produced by BodyGenCB; finalization (and the ExitCall) is
// placed in a dedicated finalize block so cancellation points can branch to
// it. On success, returns an insertion point after the region.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  // Matching pop happens in emitCommonDirectiveExit when HasFinalize is set.
  if (HasFinalize)
    FinalizationStack.push_back(Elt: {FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation
  // If the entry block lacks a suitable terminator, a temporary unreachable
  // is inserted so the block can be split; it is erased again at the end.
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminatorOrNull();
  if (!isa_and_nonnull<UncondBrInst, CondBrInst>(Val: SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");

  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // generate body
  if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
                            /* CodeGenIP */ Builder.saveIP()))
    return Err;

  // emit exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  InsertPointOrErrorTy AfterIP =
      emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  if (!AfterIP)
    return AfterIP.takeError();

  // If we are skipping the region of a non conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(BB: ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  // Drop the temporary unreachable created above (real branch terminators
  // from the original block are kept).
  if (!isa_and_nonnull<UncondBrInst, CondBrInst>(Val: SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
7504
// For a conditional directive entry (e.g. `single`), turn the entry call's
// non-zero result into `if (result) then-body else exit`, leaving the
// builder positioned inside the then-block for body generation.
// For unconditional directives this is a no-op.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // if nothing to do, Return current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
  auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
  // Temporary terminator so ThenBB is well-formed until the entry branch is
  // moved into it; erased below.
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit thenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);

  // Move Entry branch to end of ThenBB, and replace with conditional
  // branch (If-stmt)
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(I: EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}
7534
// Emit finalization for a directive region and place the exit runtime call
// (e.g. __kmpc_end_single) as the last instruction before the finalization
// block's terminator. Pops the finalization stack entry pushed by
// EmitOMPInlinedRegion when HasFinalize is set.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(IP: FinIP);

  // If there is finalization to do, emit it before the exit call
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    // The popped entry must belong to this directive; a mismatch means
    // push/pop pairs got out of sync.
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    if (Error Err = Fi.mergeFiniBB(Builder, OtherFiniBB: FinIP.getBlock()))
      return std::move(Err);

    // Exit condition: insertion point is before the terminator of the new Fini
    // block
    Builder.SetInsertPoint(FinIP.getBlock()->getTerminator());
  }

  if (!ExitCall)
    return Builder.saveIP();

  // place the Exitcall as last instruction before Finalization block terminator
  // (ExitCall was created earlier at the directive's location and is moved
  // here rather than re-created).
  ExitCall->removeFromParent();
  Builder.Insert(I: ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}
7567
// Build the control flow for a `copyin` clause: compare the master and
// private addresses and branch into a "not master" block where the caller
// can emit the actual copy. Returns an insertion point inside that block
// (or after its branch to the end block when BranchtoEnd is set).
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // creates the following CFG structure
  //	   OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |     copyin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //		        |
  //         		v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If entry block is terminated, split to preserve the branch to following
  // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
  if (isa_and_nonnull<CondBrInst>(Val: OMP_Entry->getTerminatorOrNull())) {
    CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
                                         BBName: "copyin.not.master.end");
    // Drop the unconditional branch splitBasicBlock created; it is replaced
    // by the conditional branch emitted below.
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
  }

  // Compare the two addresses as integers; equal addresses mean this thread
  // is the master and no copy is needed.
  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
  Builder.CreateCondBr(Cond: cmp, True: CopyBegin, False: CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));

  return Builder.saveIP();
}
7617
7618CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
7619 Value *Size, Value *Allocator,
7620 std::string Name) {
7621 IRBuilder<>::InsertPointGuard IPG(Builder);
7622 updateToLocation(Loc);
7623
7624 uint32_t SrcLocStrSize;
7625 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7626 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7627 Value *ThreadId = getOrCreateThreadID(Ident);
7628 Value *Args[] = {ThreadId, Size, Allocator};
7629
7630 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
7631
7632 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7633}
7634
7635CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
7636 Value *Addr, Value *Allocator,
7637 std::string Name) {
7638 IRBuilder<>::InsertPointGuard IPG(Builder);
7639 updateToLocation(Loc);
7640
7641 uint32_t SrcLocStrSize;
7642 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7643 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7644 Value *ThreadId = getOrCreateThreadID(Ident);
7645 Value *Args[] = {ThreadId, Addr, Allocator};
7646 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
7647 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7648}
7649
7650CallInst *OpenMPIRBuilder::createOMPInteropInit(
7651 const LocationDescription &Loc, Value *InteropVar,
7652 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
7653 Value *DependenceAddress, bool HaveNowaitClause) {
7654 IRBuilder<>::InsertPointGuard IPG(Builder);
7655 updateToLocation(Loc);
7656
7657 uint32_t SrcLocStrSize;
7658 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7659 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7660 Value *ThreadId = getOrCreateThreadID(Ident);
7661 if (Device == nullptr)
7662 Device = Constant::getAllOnesValue(Ty: Int32);
7663 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
7664 if (NumDependences == nullptr) {
7665 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7666 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7667 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7668 }
7669 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7670 Value *Args[] = {
7671 Ident, ThreadId, InteropVar, InteropTypeVal,
7672 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
7673
7674 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
7675
7676 return createRuntimeFunctionCall(Callee: Fn, Args);
7677}
7678
7679CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
7680 const LocationDescription &Loc, Value *InteropVar, Value *Device,
7681 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
7682 IRBuilder<>::InsertPointGuard IPG(Builder);
7683 updateToLocation(Loc);
7684
7685 uint32_t SrcLocStrSize;
7686 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7687 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7688 Value *ThreadId = getOrCreateThreadID(Ident);
7689 if (Device == nullptr)
7690 Device = Constant::getAllOnesValue(Ty: Int32);
7691 if (NumDependences == nullptr) {
7692 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7693 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7694 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7695 }
7696 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7697 Value *Args[] = {
7698 Ident, ThreadId, InteropVar, Device,
7699 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7700
7701 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
7702
7703 return createRuntimeFunctionCall(Callee: Fn, Args);
7704}
7705
7706CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
7707 Value *InteropVar, Value *Device,
7708 Value *NumDependences,
7709 Value *DependenceAddress,
7710 bool HaveNowaitClause) {
7711 IRBuilder<>::InsertPointGuard IPG(Builder);
7712 updateToLocation(Loc);
7713 uint32_t SrcLocStrSize;
7714 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7715 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7716 Value *ThreadId = getOrCreateThreadID(Ident);
7717 if (Device == nullptr)
7718 Device = Constant::getAllOnesValue(Ty: Int32);
7719 if (NumDependences == nullptr) {
7720 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7721 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7722 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7723 }
7724 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7725 Value *Args[] = {
7726 Ident, ThreadId, InteropVar, Device,
7727 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7728
7729 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
7730
7731 return createRuntimeFunctionCall(Callee: Fn, Args);
7732}
7733
7734CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
7735 const LocationDescription &Loc, llvm::Value *Pointer,
7736 llvm::ConstantInt *Size, const llvm::Twine &Name) {
7737 IRBuilder<>::InsertPointGuard IPG(Builder);
7738 updateToLocation(Loc);
7739
7740 uint32_t SrcLocStrSize;
7741 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7742 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7743 Value *ThreadId = getOrCreateThreadID(Ident);
7744 Constant *ThreadPrivateCache =
7745 getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
7746 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
7747
7748 Function *Fn =
7749 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);
7750
7751 return createRuntimeFunctionCall(Callee: Fn, Args);
7752}
7753
// Emit the device-side prologue of a target kernel: materialize the kernel
// and dynamic environment globals, call __kmpc_target_init, and branch to
// either the user-code entry block (when the call returns -1) or a
// return-void worker exit block. Also records team/thread launch bounds as
// kernel attributes/metadata.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
    const LocationDescription &Loc,
    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
         "expected num_threads and num_teams to be specified");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The generic state machine is only needed for non-SPMD execution.
  Constant *IsSPMDVal = ConstantInt::getSigned(Ty: Int8, V: Attrs.ExecFlags);
  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
      Ty: Int8, V: Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
  Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
  Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);

  Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
  Function *Kernel = DebugKernelWrapper;

  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix)) {
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
    Kernel = M.getFunction(Name: KernelName);
    assert(Kernel && "Expected the real kernel to exist");
  }

  // Manifest the launch configuration in the metadata matching the kernel
  // environment.
  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
    writeTeamsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinTeams, UB: Attrs.MaxTeams.front());

  // If MaxThreads not set, select the maximum between the default workgroup
  // size and the MinThreads value.
  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
  if (MaxThreadsVal < 0) {
    if (hasGridValue(T)) {
      MaxThreadsVal =
          std::max(a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size),
                   b: Attrs.MinThreads);
    } else {
      MaxThreadsVal = Attrs.MinThreads;
    }
  }

  if (MaxThreadsVal > 0)
    writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinThreads, UB: MaxThreadsVal);

  Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinThreads);
  Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
  Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinTeams);
  Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MaxTeams.front());
  Constant *ReductionDataSize =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionDataSize);
  Constant *ReductionBufferLength =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionBufferLength);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
  const DataLayout &DL = Fn->getDataLayout();

  // Per-kernel dynamic environment global (currently only the debug
  // indentation level).
  Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
  Constant *DynamicEnvironmentInitializer =
      ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
  GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
      M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
      DynamicEnvironmentInitializer, DynamicEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  // Address-space-cast the global if its type differs from the expected
  // pointer type (non-default globals address space).
  Constant *DynamicEnvironment =
      DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
          ? DynamicEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
                                           Ty: DynamicEnvironmentPtr);

  // Constant kernel environment: configuration (execution mode and launch
  // bounds), source-location ident, and the dynamic environment pointer.
  // Field order here must match the indices used in createTargetDeinit.
  Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
      T: ConfigurationEnvironment, V: {
          UseGenericStateMachineVal,
          MayUseNestedParallelismVal,
          IsSPMDVal,
          MinThreads,
          MaxThreads,
          MinTeams,
          MaxTeams,
          ReductionDataSize,
          ReductionBufferLength,
      });
  Constant *KernelEnvironmentInitializer = ConstantStruct::get(
      T: KernelEnvironment, V: {
          ConfigurationEnvironmentInitializer,
          Ident,
          DynamicEnvironment,
      });
  std::string KernelEnvironmentName =
      (KernelName + "_kernel_environment").str();
  GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
      M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
      KernelEnvironmentInitializer, KernelEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *KernelEnvironment =
      KernelEnvironmentGV->getType() == KernelEnvironmentPtr
          ? KernelEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
                                           Ty: KernelEnvironmentPtr);
  // The launch environment is passed as the last argument of the (debug)
  // kernel wrapper; cast it to the runtime's expected parameter type.
  Value *KernelLaunchEnvironment =
      DebugKernelWrapper->getArg(i: DebugKernelWrapper->arg_size() - 1);
  Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(i: 1);
  KernelLaunchEnvironment =
      KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
          ? KernelLaunchEnvironment
          : Builder.CreateAddrSpaceCast(V: KernelLaunchEnvironment,
                                        DestTy: KernelLaunchEnvParamTy);
  CallInst *ThreadKind = createRuntimeFunctionCall(
      Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      LHS: ThreadKind, RHS: Constant::getAllOnesValue(Ty: ThreadKind->getType()),
      Name: "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  // Temporary unreachable to split the current block; replaced by the
  // conditional branch below and erased afterwards.
  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
7907
// Emit the device-side epilogue of a target kernel: call
// __kmpc_target_deinit and, when teams-reduction sizes are known, patch the
// kernel environment global created by createTargetInit with the actual
// reduction data size and buffer length.
void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
                                         int32_t TeamsReductionDataSize,
                                         int32_t TeamsReductionBufferLength) {
  if (!updateToLocation(Loc))
    return;

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

  createRuntimeFunctionCall(Callee: Fn, Args: {});

  // Nothing to patch if either reduction dimension is zero.
  if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
    return;

  Function *Kernel = Builder.GetInsertBlock()->getParent();
  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix))
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
  auto *KernelEnvironmentGV =
      M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
  assert(KernelEnvironmentGV && "Expected kernel environment global\n");
  auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
  // Indices {0, 7} and {0, 8} address ReductionDataSize and
  // ReductionBufferLength within the configuration environment (field 0 of
  // the kernel environment) — they must match the initializer field order
  // built in createTargetInit.
  auto *NewInitializer = ConstantFoldInsertValueInstruction(
      Agg: KernelEnvironmentInitializer,
      Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
  NewInitializer = ConstantFoldInsertValueInstruction(
      Agg: NewInitializer, Val: ConstantInt::get(Ty: Int32, V: TeamsReductionBufferLength),
      Idxs: {0, 8});
  KernelEnvironmentGV->setInitializer(NewInitializer);
}
7940
7941static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
7942 bool Min) {
7943 if (Kernel.hasFnAttribute(Kind: Name)) {
7944 int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Kind: Name);
7945 Value = Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value);
7946 }
7947 Kernel.addFnAttr(Kind: Name, Val: llvm::utostr(X: Value));
7948}
7949
7950std::pair<int32_t, int32_t>
7951OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
7952 int32_t ThreadLimit =
7953 Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");
7954
7955 if (T.isAMDGPU()) {
7956 const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
7957 if (!Attr.isValid() || !Attr.isStringAttribute())
7958 return {0, ThreadLimit};
7959 auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
7960 int32_t LB, UB;
7961 if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
7962 return {0, ThreadLimit};
7963 UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
7964 if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
7965 return {0, UB};
7966 return {LB, UB};
7967 }
7968
7969 if (Kernel.hasFnAttribute(Kind: NVVMAttr::MaxNTID)) {
7970 int32_t UB = Kernel.getFnAttributeAsParsedInteger(Kind: NVVMAttr::MaxNTID);
7971 return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
7972 }
7973 return {0, ThreadLimit};
7974}
7975
7976void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
7977 Function &Kernel, int32_t LB,
7978 int32_t UB) {
7979 Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));
7980
7981 if (T.isAMDGPU()) {
7982 Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
7983 Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
7984 return;
7985 }
7986
7987 updateNVPTXAttr(Kernel, Name: NVVMAttr::MaxNTID, Value: UB, Min: true);
7988}
7989
7990std::pair<int32_t, int32_t>
7991OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
7992 // TODO: Read from backend annotations if available.
7993 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
7994}
7995
7996void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
7997 int32_t LB, int32_t UB) {
7998 if (T.isNVPTX())
7999 if (UB > 0)
8000 Kernel.addFnAttr(Kind: NVVMAttr::MaxClusterRank, Val: llvm::utostr(X: UB));
8001 if (T.isAMDGPU())
8002 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: LB) + ",1,1");
8003
8004 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
8005}
8006
8007void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
8008 Function *OutlinedFn) {
8009 if (Config.isTargetDevice()) {
8010 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
8011 // TODO: Determine if DSO local can be set to true.
8012 OutlinedFn->setDSOLocal(false);
8013 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
8014 if (T.isAMDGCN())
8015 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
8016 else if (T.isNVPTX())
8017 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
8018 else if (T.isSPIRV())
8019 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
8020 }
8021}
8022
8023Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
8024 StringRef EntryFnIDName) {
8025 if (Config.isTargetDevice()) {
8026 assert(OutlinedFn && "The outlined function must exist if embedded");
8027 return OutlinedFn;
8028 }
8029
8030 return new GlobalVariable(
8031 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
8032 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
8033}
8034
8035Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
8036 StringRef EntryFnName) {
8037 if (OutlinedFn)
8038 return OutlinedFn;
8039
8040 assert(!M.getGlobalVariable(EntryFnName, true) &&
8041 "Named kernel already exists?");
8042 return new GlobalVariable(
8043 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
8044 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
8045}
8046
8047Error OpenMPIRBuilder::emitTargetRegionFunction(
8048 TargetRegionEntryInfo &EntryInfo,
8049 FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
8050 Function *&OutlinedFn, Constant *&OutlinedFnID) {
8051
8052 SmallString<64> EntryFnName;
8053 OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);
8054
8055 if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
8056 Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
8057 if (!CBResult)
8058 return CBResult.takeError();
8059 OutlinedFn = *CBResult;
8060 } else {
8061 OutlinedFn = nullptr;
8062 }
8063
8064 // If this target outline function is not an offload entry, we don't need to
8065 // register it. This may be in the case of a false if clause, or if there are
8066 // no OpenMP targets.
8067 if (!IsOffloadEntry)
8068 return Error::success();
8069
8070 std::string EntryFnIDName =
8071 Config.isTargetDevice()
8072 ? std::string(EntryFnName)
8073 : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});
8074
8075 OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
8076 EntryFnName, EntryFnIDName);
8077 return Error::success();
8078}
8079
8080Constant *OpenMPIRBuilder::registerTargetRegionFunction(
8081 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
8082 StringRef EntryFnName, StringRef EntryFnIDName) {
8083 if (OutlinedFn)
8084 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
8085 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
8086 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
8087 OffloadInfoManager.registerTargetRegionEntryInfo(
8088 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
8089 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
8090 return OutlinedFnID;
8091}
8092
// Emit IR for an OpenMP data environment. With a BodyGenCB this is a "target
// data" region (begin/end runtime calls bracketing the user body); without
// one it is a standalone construct (e.g. target enter/exit data or target
// update) that issues a single mapper runtime call, possibly wrapped in a
// target task when nowait is requested. Returns the builder's final insert
// point, or the error produced by one of the callbacks.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
    TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
    CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
    function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                      BodyGenTy BodyGenType)>
        BodyGenCB,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  // No body callback means a standalone construct (single runtime call).
  bool IsStandAlone = !BodyGenCB;
  // Set by BeginThenGen and read again by EndThenGen, so BeginThenGen must
  // always run before EndThenGen does.
  MapInfosTy *MapInfo;
  // Generate the code for the opening of the data environment. Capture all the
  // arguments of the runtime call by reference because they are used in the
  // closing of the region.
  auto BeginThenGen = [&](InsertPointTy AllocaIP,
                          InsertPointTy CodeGenIP) -> Error {
    MapInfo = &GenMapInfoCB(Builder.saveIP());
    if (Error Err = emitOffloadingArrays(
            AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info, CustomMapperCB,
            /*IsNonContiguous=*/true, DeviceAddrCB))
      return Err;

    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    SmallVector<llvm::Value *, 13> OffloadingArgs = {
        SrcLocInfo, DeviceID,
        PointerNum, RTArgs.BasePointersArray,
        RTArgs.PointersArray, RTArgs.SizesArray,
        RTArgs.MapTypesArray, RTArgs.MapNamesArray,
        RTArgs.MappersArray};

    if (IsStandAlone) {
      assert(MapperFunc && "MapperFunc missing for standalone target data");

      auto TaskBodyCB = [&](Value *, Value *,
                            IRBuilderBase::InsertPoint) -> Error {
        if (Info.HasNoWait) {
          // The nowait variants of the mapper runtime entries take four
          // extra (here unused) dependency arguments.
          OffloadingArgs.append(IL: {llvm::Constant::getNullValue(Ty: Int32),
                                 llvm::Constant::getNullValue(Ty: VoidPtr),
                                 llvm::Constant::getNullValue(Ty: Int32),
                                 llvm::Constant::getNullValue(Ty: VoidPtr)});
        }

        createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
                                  Args: OffloadingArgs);

        if (Info.HasNoWait) {
          BasicBlock *OffloadContBlock =
              BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
          Function *CurFn = Builder.GetInsertBlock()->getParent();
          emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
          // NOTE(review): restoreIP(saveIP()) is a no-op; presumably emitBlock
          // already leaves the builder positioned in OffloadContBlock — confirm
          // whether an explicit IP adjustment was intended here.
          Builder.restoreIP(IP: Builder.saveIP());
        }
        return Error::success();
      };

      // Only a nowait standalone construct needs to be wrapped in an
      // enclosing target task; otherwise run the body callback inline.
      bool RequiresOuterTargetTask = Info.HasNoWait;
      if (!RequiresOuterTargetTask)
        cantFail(Err: TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
                             /*TargetTaskAllocaIP=*/{}));
      else
        cantFail(ValOrErr: emitTargetTask(TaskBodyCB, DeviceID, RTLoc: SrcLocInfo, AllocaIP,
                                   /*Dependencies=*/{}, RTArgs, HasNoWait: Info.HasNoWait));
    } else {
      Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
          FnID: omp::OMPRTL___tgt_target_data_begin_mapper);

      createRuntimeFunctionCall(Callee: BeginMapperFunc, Args: OffloadingArgs);

      // Propagate device pointers produced by the begin call into the
      // allocas recorded for use_device_ptr/use_device_addr values.
      for (auto DeviceMap : Info.DevicePtrInfoMap) {
        if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
          auto *LI =
              Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
          Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
        }
      }

      // If device pointer privatization is required, emit the body of the
      // region here. It will have to be duplicated: with and without
      // privatization.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
    }
    return Error::success();
  };

  // If we need device pointer privatization, we need to emit the body of the
  // region with no privatization in the 'else' branch of the conditional.
  // Otherwise, we don't have to do anything.
  auto BeginElseGen = [&](InsertPointTy AllocaIP,
                          InsertPointTy CodeGenIP) -> Error {
    InsertPointOrErrorTy AfterIP =
        BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    return Error::success();
  };

  // Generate code for the closing of the data region.
  auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    TargetDataRTArgs RTArgs;
    // MapInfo was populated by BeginThenGen; debug names are only emitted
    // when the map clauses carried any.
    Info.EmitDebug = !MapInfo->Names.empty();
    emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
                               PointerNum, RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};
    Function *EndMapperFunc =
        getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);

    createRuntimeFunctionCall(Callee: EndMapperFunc, Args: OffloadingArgs);
    return Error::success();
  };

  // We don't have to do anything to close the region if the if clause evaluates
  // to false.
  auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    return Error::success();
  };

  // Dispatch: a data region emits begin + body + end (each guarded by the if
  // clause when present); a standalone construct emits only the begin part.
  Error Err = [&]() -> Error {
    if (BodyGenCB) {
      Error Err = [&]() {
        if (IfCond)
          return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
        return BeginThenGen(AllocaIP, Builder.saveIP());
      }();

      if (Err)
        return Err;

      // If we don't require privatization of device pointers, we emit the body
      // in between the runtime calls. This avoids duplicating the body code.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
      if (!AfterIP)
        return AfterIP.takeError();
      restoreIPandDebugLoc(Builder, IP: *AfterIP);

      if (IfCond)
        return emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
      return EndThenGen(AllocaIP, Builder.saveIP());
    }
    if (IfCond)
      return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
    return BeginThenGen(AllocaIP, Builder.saveIP());
  }();

  if (Err)
    return Err;

  return Builder.saveIP();
}
8278
8279FunctionCallee
8280OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
8281 bool IsGPUDistribute) {
8282 assert((IVSize == 32 || IVSize == 64) &&
8283 "IV size is not compatible with the omp runtime");
8284 RuntimeFunction Name;
8285 if (IsGPUDistribute)
8286 Name = IVSize == 32
8287 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
8288 : omp::OMPRTL___kmpc_distribute_static_init_4u)
8289 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
8290 : omp::OMPRTL___kmpc_distribute_static_init_8u);
8291 else
8292 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
8293 : omp::OMPRTL___kmpc_for_static_init_4u)
8294 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
8295 : omp::OMPRTL___kmpc_for_static_init_8u);
8296
8297 return getOrCreateRuntimeFunction(M, FnID: Name);
8298}
8299
8300FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
8301 bool IVSigned) {
8302 assert((IVSize == 32 || IVSize == 64) &&
8303 "IV size is not compatible with the omp runtime");
8304 RuntimeFunction Name = IVSize == 32
8305 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
8306 : omp::OMPRTL___kmpc_dispatch_init_4u)
8307 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
8308 : omp::OMPRTL___kmpc_dispatch_init_8u);
8309
8310 return getOrCreateRuntimeFunction(M, FnID: Name);
8311}
8312
8313FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
8314 bool IVSigned) {
8315 assert((IVSize == 32 || IVSize == 64) &&
8316 "IV size is not compatible with the omp runtime");
8317 RuntimeFunction Name = IVSize == 32
8318 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
8319 : omp::OMPRTL___kmpc_dispatch_next_4u)
8320 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
8321 : omp::OMPRTL___kmpc_dispatch_next_8u);
8322
8323 return getOrCreateRuntimeFunction(M, FnID: Name);
8324}
8325
8326FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
8327 bool IVSigned) {
8328 assert((IVSize == 32 || IVSize == 64) &&
8329 "IV size is not compatible with the omp runtime");
8330 RuntimeFunction Name = IVSize == 32
8331 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
8332 : omp::OMPRTL___kmpc_dispatch_fini_4u)
8333 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
8334 : omp::OMPRTL___kmpc_dispatch_fini_8u);
8335
8336 return getOrCreateRuntimeFunction(M, FnID: Name);
8337}
8338
// Return the dispatch-deinit runtime entry; unlike init/next/fini it is not
// specialized on the induction variable's width or signedness.
FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
  return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
}
8342
/// Fix up debug records in the outlined target-region function \p Func.
/// Records copied from the parent function still reference the parent's
/// values; remap their location operands via \p ValueReplacementMap and
/// recreate each affected DILocalVariable with the outlined function's
/// argument numbering. On the device, additionally synthesize debug info for
/// the implicit trailing "dyn_ptr" kernel argument.
static void FixupDebugInfoForOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
    DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {

  // Nothing to fix up if the outlined function carries no debug info.
  DISubprogram *NewSP = Func->getSubprogram();
  if (!NewSP)
    return;

  SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;

  auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
    DILocalVariable *&NewVar = RemappedVariables[OldVar];
    // Only use cached variable if the arg number matches. This is important
    // so that DIVariable created for privatized variables are not discarded.
    if (NewVar && (arg == NewVar->getArg()))
      return NewVar;

    // Clone the variable, changing only the argument number.
    NewVar = llvm::DILocalVariable::get(
        Context&: Builder.getContext(), Scope: OldVar->getScope(), Name: OldVar->getName(),
        File: OldVar->getFile(), Line: OldVar->getLine(), Type: OldVar->getType(), Arg: arg,
        Flags: OldVar->getFlags(), AlignInBits: OldVar->getAlignInBits(), Annotations: OldVar->getAnnotations());
    return NewVar;
  };

  // Remap a single debug record: rewrite each location operand that appears
  // in ValueReplacementMap, and if any operand was rewritten, point the
  // record at a variable carrying the new argument number (map index + 1,
  // since DI argument numbers are 1-based).
  auto UpdateDebugRecord = [&](auto *DR) {
    DILocalVariable *OldVar = DR->getVariable();
    unsigned ArgNo = 0;
    for (auto Loc : DR->location_ops()) {
      auto Iter = ValueReplacementMap.find(Loc);
      if (Iter != ValueReplacementMap.end()) {
        DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
        ArgNo = std::get<1>(Iter->second) + 1;
      }
    }
    if (ArgNo != 0)
      DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
  };

  // The location and scope of variable intrinsics and records still point to
  // the parent function of the target region. Update them.
  for (Instruction &I : instructions(F: Func)) {
    assert(!isa<llvm::DbgVariableIntrinsic>(&I) &&
           "Unexpected debug intrinsic");
    for (DbgVariableRecord &DVR : filterDbgVars(R: I.getDbgRecordRange()))
      UpdateDebugRecord(&DVR);
  }
  // An extra argument is passed to the device. Create the debug data for it.
  if (OMPBuilder.Config.isTargetDevice()) {
    DICompileUnit *CU = NewSP->getUnit();
    Module *M = Func->getParent();
    DIBuilder DB(*M, true, CU);
    DIType *VoidPtrTy =
        DB.createQualifiedType(Tag: dwarf::DW_TAG_pointer_type, FromTy: nullptr);
    unsigned ArgNo = Func->arg_size();
    DILocalVariable *Var = DB.createParameterVariable(
        Scope: NewSP, Name: "dyn_ptr", ArgNo, File: NewSP->getFile(), /*LineNo=*/0, Ty: VoidPtrTy,
        /*AlwaysPreserve=*/false, Flags: DINode::DIFlags::FlagArtificial);
    auto Loc = DILocation::get(Context&: Func->getContext(), Line: 0, Column: 0, Scope: NewSP, InlinedAt: 0);
    // dyn_ptr is the last parameter; declare it at the function entry.
    Argument *LastArg = Func->getArg(i: Func->arg_size() - 1);
    DB.insertDeclare(Storage: LastArg, VarInfo: Var, Expr: DB.createExpression(), DL: Loc,
                     InsertAtEnd: &(*Func->begin()));
  }
}
8406
8407static Value *removeASCastIfPresent(Value *V) {
8408 if (Operator::getOpcode(V) == Instruction::AddrSpaceCast)
8409 return cast<Operator>(Val: V)->getOperand(i: 0);
8410 return V;
8411}
8412
/// Create the outlined function for a target region named \p FuncName. The
/// body is produced by \p CBFunc; \p ArgAccessorFuncCB maps each host value
/// in \p Inputs to its in-kernel replacement, after which all uses inside the
/// new function are rewritten. Returns the new function, or the first error
/// raised by a callback.
static Expected<Function *> createOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
  SmallVector<Type *> ParameterTypes;
  if (OMPBuilder.Config.isTargetDevice()) {
    // All parameters to target devices are passed as pointers
    // or i64. This assumes 64-bit address spaces/pointers.
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
                                   ? Arg->getType()
                                   : Type::getInt64Ty(C&: Builder.getContext()));
  } else {
    // Host fallback: keep the original types.
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Elt: Arg->getType());
  }

  // The implicit dyn_ptr argument is always the last parameter on both host
  // and device so the argument counts match without runtime manipulation.
  auto *PtrTy = PointerType::getUnqual(C&: Builder.getContext());
  ParameterTypes.push_back(Elt: PtrTy);

  auto BB = Builder.GetInsertBlock();
  auto M = BB->getModule();
  auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
                                    /*isVarArg*/ false);
  auto Func =
      Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);

  // Forward target-cpu and target-features function attributes from the
  // original function to the new outlined function.
  Function *ParentFn = Builder.GetInsertBlock()->getParent();

  auto TargetCpuAttr = ParentFn->getFnAttribute(Kind: "target-cpu");
  if (TargetCpuAttr.isStringAttribute())
    Func->addFnAttr(Attr: TargetCpuAttr);

  auto TargetFeaturesAttr = ParentFn->getFnAttribute(Kind: "target-features");
  if (TargetFeaturesAttr.isStringAttribute())
    Func->addFnAttr(Attr: TargetFeaturesAttr);

  if (OMPBuilder.Config.isTargetDevice()) {
    // Record the kernel's execution mode and keep the marker alive through
    // llvm.compiler.used.
    Value *ExecMode =
        OMPBuilder.emitKernelExecutionMode(KernelName: FuncName, Mode: DefaultAttrs.ExecFlags);
    OMPBuilder.emitUsed(Name: "llvm.compiler.used", List: {ExecMode});
  }

  // Save insert point.
  IRBuilder<>::InsertPointGuard IPG(Builder);
  // We will generate the entries in the outlined function but the debug
  // location may still be pointing to the parent function. Reset it now.
  Builder.SetCurrentDebugLocation(llvm::DebugLoc());

  // Generate the region into the function.
  BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
  Builder.SetInsertPoint(EntryBB);

  // Insert target init call in the device compilation pass.
  if (OMPBuilder.Config.isTargetDevice())
    Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, Attrs: DefaultAttrs));

  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();

  // As we embed the user code in the middle of our target region after we
  // generate entry code, we must move what allocas we can into the entry
  // block to avoid possible breaking optimisations for device
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);

  // Insert target deinit call in the device compilation pass.
  BasicBlock *OutlinedBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "outlined.body");
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
      Builder.saveIP(),
      OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.createTargetDeinit(Loc: Builder);

  // Insert return instruction.
  Builder.CreateRetVoid();

  // New Alloca IP at entry point of created device function.
  Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
  auto AllocaIP = Builder.saveIP();

  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());

  // Do not include the artificial dyn_ptr argument.
  const auto &ArgRange = make_range(x: Func->arg_begin(), y: Func->arg_end() - 1);

  // Maps each original input value to (its in-kernel replacement, argument
  // index); later consumed by FixupDebugInfoForOutlinedFunction.
  DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;

  auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
    // Things like GEP's can come in the form of Constants. Constants and
    // ConstantExpr's do not have access to the knowledge of what they're
    // contained in, so we must dig a little to find an instruction so we
    // can tell if they're used inside of the function we're outlining. We
    // also replace the original constant expression with a new instruction
    // equivalent; an instruction as it allows easy modification in the
    // following loop, as we can now know the constant (instruction) is
    // owned by our target function and replaceUsesOfWith can now be invoked
    // on it (cannot do this with constants it seems). A brand new one also
    // allows us to be cautious as it is perhaps possible the old expression
    // was used inside of the function but exists and is used externally
    // (unlikely by the nature of a Constant, but still).
    // NOTE: We cannot remove dead constants that have been rewritten to
    // instructions at this stage, we run the risk of breaking later lowering
    // by doing so as we could still be in the process of lowering the module
    // from MLIR to LLVM-IR and the MLIR lowering may still require the original
    // constants we have created rewritten versions of.
    if (auto *Const = dyn_cast<Constant>(Val: Input))
      convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);

    // Collect users before iterating over them to avoid invalidating the
    // iteration in case a user uses Input more than once (e.g. a call
    // instruction).
    SetVector<User *> Users(Input->users().begin(), Input->users().end());
    // Collect all the instructions
    for (User *User : make_early_inc_range(Range&: Users))
      if (auto *Instr = dyn_cast<Instruction>(Val: User))
        if (Instr->getFunction() == Func)
          Instr->replaceUsesOfWith(From: Input, To: InputCopy);
  };

  SmallVector<std::pair<Value *, Value *>> DeferredReplacement;

  // Rewrite uses of input values to parameters.
  for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
    Value *Input = std::get<0>(t&: InArg);
    Argument &Arg = std::get<1>(t&: InArg);
    Value *InputCopy = nullptr;

    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    ValueReplacementMap[Input] = std::make_tuple(args&: InputCopy, args: Arg.getArgNo());

    // In certain cases a Global may be set up for replacement, however, this
    // Global may be used in multiple arguments to the kernel, just segmented
    // apart, for example, if we have a global array, that is sectioned into
    // multiple mappings (technically not legal in OpenMP, but there is a case
    // in Fortran for Common Blocks where this is necessary), we will end up
    // with GEP's into this array inside the kernel, that refer to the Global
    // but are technically separate arguments to the kernel for all intents and
    // purposes. If we have mapped a segment that requires a GEP into the 0-th
    // index, it will fold into an referal to the Global, if we then encounter
    // this folded GEP during replacement all of the references to the
    // Global in the kernel will be replaced with the argument we have generated
    // that corresponds to it, including any other GEP's that refer to the
    // Global that may be other arguments. This will invalidate all of the other
    // preceding mapped arguments that refer to the same global that may be
    // separate segments. To prevent this, we defer global processing until all
    // other processing has been performed.
    if (llvm::isa<llvm::GlobalValue, llvm::GlobalObject, llvm::GlobalVariable>(
            Val: removeASCastIfPresent(V: Input))) {
      DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
      continue;
    }

    // ConstantData (e.g. literal constants) has no function-local uses to
    // rewrite.
    if (isa<ConstantData>(Val: Input))
      continue;

    ReplaceValue(Input, InputCopy, Func);
  }

  // Replace all of our deferred Input values, currently just Globals.
  for (auto Deferred : DeferredReplacement)
    ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);

  FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
                                    ValueReplacementMap);
  return Func;
}
8593/// Given a task descriptor, TaskWithPrivates, return the pointer to the block
8594/// of pointers containing shared data between the parent task and the created
8595/// task.
8596static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
8597 IRBuilderBase &Builder,
8598 Value *TaskWithPrivates,
8599 Type *TaskWithPrivatesTy) {
8600
8601 Type *TaskTy = OMPIRBuilder.Task;
8602 LLVMContext &Ctx = Builder.getContext();
8603 Value *TaskT =
8604 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 0);
8605 Value *Shareds = TaskT;
8606 // TaskWithPrivatesTy can be one of the following
8607 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8608 // %struct.privates }
8609 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
8610 //
8611 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
8612 // its first member has to be the task descriptor. TaskTy is the type of the
8613 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
8614 // first member of TaskT, gives us the pointer to shared data.
8615 if (TaskWithPrivatesTy != TaskTy)
8616 Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
8617 return Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
8618}
8619/// Create an entry point for a target task with the following.
8620/// It'll have the following signature
8621/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
8622/// This function is called from emitTargetTask once the
8623/// code to launch the target kernel has been outlined already.
8624/// NumOffloadingArrays is the number of offloading arrays that we need to copy
8625/// into the task structure so that the deferred target task can access this
8626/// data even after the stack frame of the generating task has been rolled
8627/// back. Offloading arrays contain base pointers, pointers, sizes etc
8628/// of the data that the target kernel will access. These in effect are the
8629/// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
static Function *emitTargetTaskProxyFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
    StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
    const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {

  // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
  // This is because PrivatesTy is the type of the structure in which
  // we pass the offloading arrays to the deferred target task.
  assert((!NumOffloadingArrays || PrivatesTy) &&
         "PrivatesTy cannot be nullptr when there are offloadingArrays"
         "to privatize");

  Module &M = OMPBuilder.M;
  // KernelLaunchFunction is the target launch function, i.e.
  // the function that sets up kernel arguments and calls
  // __tgt_target_kernel to launch the kernel on the device.
  //
  Function *KernelLaunchFunction = StaleCI->getCalledFunction();

  // StaleCI is the CallInst which is the call to the outlined
  // target kernel launch function. If there are local live-in values
  // that the outlined function uses then these are aggregated into a structure
  // which is passed as the second argument. If there are no local live-in
  // values or if all values used by the outlined kernel are global variables,
  // then there's only one argument, the threadID. So, StaleCI can be
  //
  // %structArg = alloca { ptr, ptr }, align 8
  // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
  // store ptr %20, ptr %gep_, align 8
  // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
  // store ptr %21, ptr %gep_8, align 8
  // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
  //
  // OR
  //
  // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
  OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
                                    StaleCI->getIterator());

  LLVMContext &Ctx = StaleCI->getParent()->getContext();

  // The proxy has the fixed task-entry signature: (i32 thread-id, ptr task).
  Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
  Type *TaskPtrTy = OMPBuilder.TaskPtr;
  [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;

  auto ProxyFnTy =
      FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
                        /* isVarArg */ false);
  auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
                                  N: ".omp_target_task_proxy_func",
                                  M: Builder.GetInsertBlock()->getModule());
  Value *ThreadId = ProxyFn->getArg(i: 0);
  Value *TaskWithPrivates = ProxyFn->getArg(i: 1);
  ThreadId->setName("thread.id");
  TaskWithPrivates->setName("task");

  // SharedArgsOperandNo > 0 means StaleCI carries a struct of shared
  // live-in values at that operand position.
  bool HasShareds = SharedArgsOperandNo > 0;
  bool HasOffloadingArrays = NumOffloadingArrays > 0;
  BasicBlock *EntryBB =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
  Builder.SetInsertPoint(EntryBB);

  SmallVector<Value *> KernelLaunchArgs;
  KernelLaunchArgs.reserve(N: StaleCI->arg_size());
  KernelLaunchArgs.push_back(Elt: ThreadId);

  if (HasOffloadingArrays) {
    // Forward pointers to the privatized offloading arrays stored in the
    // second member of the task-with-privates struct.
    assert(TaskTy != TaskWithPrivatesTy &&
           "If there are offloading arrays to pass to the target"
           "TaskTy cannot be the same as TaskWithPrivatesTy");
    (void)TaskTy;
    Value *Privates =
        Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 1);
    for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
      KernelLaunchArgs.push_back(
          Elt: Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i));
  }

  if (HasShareds) {
    auto *ArgStructAlloca =
        dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgsOperandNo));
    assert(ArgStructAlloca &&
           "Unable to find the alloca instruction corresponding to arguments "
           "for extracted function");
    auto *ArgStructType = cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
    std::optional<TypeSize> ArgAllocSize =
        ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
    assert(ArgStructType && ArgAllocSize &&
           "Unable to determine size of arguments for extracted function");
    uint64_t StructSize = ArgAllocSize->getFixedValue();

    // Copy the shared-data block out of the task descriptor into a local
    // struct so the launch function sees the same layout the extractor
    // created.
    AllocaInst *NewArgStructAlloca =
        Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");

    Value *SharedsSize = Builder.getInt64(C: StructSize);

    LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
        OMPIRBuilder&: OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);

    Builder.CreateMemCpy(
        Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
        SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
    KernelLaunchArgs.push_back(Elt: NewArgStructAlloca);
  }
  OMPBuilder.createRuntimeFunctionCall(Callee: KernelLaunchFunction, Args: KernelLaunchArgs);
  Builder.CreateRetVoid();
  return ProxyFn;
}
8738static Type *getOffloadingArrayType(Value *V) {
8739
8740 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V))
8741 return GEP->getSourceElementType();
8742 if (auto *Alloca = dyn_cast<AllocaInst>(Val: V))
8743 return Alloca->getAllocatedType();
8744
8745 llvm_unreachable("Unhandled Instruction type");
8746 return nullptr;
8747}
8748// This function returns a struct that has at most two members.
8749// The first member is always %struct.kmp_task_ompbuilder_t, that is the task
8750// descriptor. The second member, if needed, is a struct containing arrays
8751// that need to be passed to the offloaded target kernel. For example,
8752// if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
8753// the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
8754// respectively, then the types created by this function are
8755//
8756// %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
8757// %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8758// %struct.privates }
8759// %struct.task_with_privates is returned by this function.
8760// If there aren't any offloading arrays to pass to the target kernel,
8761// %struct.kmp_task_ompbuilder_t is returned.
8762static StructType *
8763createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
8764 ArrayRef<Value *> OffloadingArraysToPrivatize) {
8765
8766 if (OffloadingArraysToPrivatize.empty())
8767 return OMPIRBuilder.Task;
8768
8769 SmallVector<Type *, 4> StructFieldTypes;
8770 for (Value *V : OffloadingArraysToPrivatize) {
8771 assert(V->getType()->isPointerTy() &&
8772 "Expected pointer to array to privatize. Got a non-pointer value "
8773 "instead");
8774 Type *ArrayTy = getOffloadingArrayType(V);
8775 assert(ArrayTy && "ArrayType cannot be nullptr");
8776 StructFieldTypes.push_back(Elt: ArrayTy);
8777 }
8778 StructType *PrivatesStructTy =
8779 StructType::create(Elements: StructFieldTypes, Name: "struct.privates");
8780 return StructType::create(Elements: {OMPIRBuilder.Task, PrivatesStructTy},
8781 Name: "struct.task_with_privates");
8782}
8783static Error emitTargetOutlinedFunction(
8784 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
8785 TargetRegionEntryInfo &EntryInfo,
8786 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
8787 Function *&OutlinedFn, Constant *&OutlinedFnID,
8788 SmallVectorImpl<Value *> &Inputs,
8789 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
8790 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
8791
8792 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
8793 [&](StringRef EntryFnName) {
8794 return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
8795 FuncName: EntryFnName, Inputs, CBFunc,
8796 ArgAccessorFuncCB);
8797 };
8798
8799 return OMPBuilder.emitTargetRegionFunction(
8800 EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
8801 OutlinedFnID);
8802}
8803
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
    TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
    OpenMPIRBuilder::InsertPointTy AllocaIP,
    const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    const TargetDataRTArgs &RTArgs, bool HasNoWait) {

  // The following explains the code-gen scenario for the `target` directive. A
  // similar scenario is followed for other device-related directives (e.g.
  // `target enter data`), since we only need to emit a task
  // that encapsulates the proper runtime call.
  //
  // When we arrive at this function, the target region itself has been
  // outlined into the function OutlinedFn.
  // So at this point, for
  // --------------------------------------------------------------
  // void user_code_that_offloads(...) {
  //   omp target depend(..) map(from:a) map(to:b) private(i)
  //     do i = 1, 10
  //       a(i) = b(i) + n
  // }
  //
  // --------------------------------------------------------------
  //
  // we have
  //
  // --------------------------------------------------------------
  //
  // void user_code_that_offloads(...) {
  //   %.offload_baseptrs = alloca [2 x ptr], align 8
  //   %.offload_ptrs = alloca [2 x ptr], align 8
  //   %.offload_mappers = alloca [2 x ptr], align 8
  //   ;; target region has been outlined and now we need to
  //   ;; offload to it via a target task.
  // }
  // void outlined_device_function(ptr a, ptr b, ptr n) {
  //   n = *n_ptr;
  //   do i = 1, 10
  //     a(i) = b(i) + n
  // }
  //
  // We have to now do the following
  // (i)   Make an offloading call to outlined_device_function using the OpenMP
  //       RTL. See 'kernel_launch_function' in the pseudo code below. This is
  //       emitted by emitKernelLaunch
  // (ii)  Create a task entry point function that calls kernel_launch_function
  //       and is the entry point for the target task. See
  //       '@.omp_target_task_proxy_func in the pseudocode below.
  // (iii) Create a task with the task entry point created in (ii)
  //
  // That is we create the following
  // struct task_with_privates {
  //    struct kmp_task_ompbuilder_t task_struct;
  //    struct privates {
  //       [2 x ptr] ; baseptrs
  //       [2 x ptr] ; ptrs
  //       [2 x i64] ; sizes
  //    }
  // }
  // void user_code_that_offloads(...) {
  //   %.offload_baseptrs = alloca [2 x ptr], align 8
  //   %.offload_ptrs = alloca [2 x ptr], align 8
  //   %.offload_sizes = alloca [2 x i64], align 8
  //
  //   %structArg = alloca { ptr, ptr, ptr }, align 8
  //   %strucArg[0] = a
  //   %strucArg[1] = b
  //   %strucArg[2] = &n
  //
  //   target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
  //                                             sizeof(kmp_task_ompbuilder_t),
  //                                             sizeof(structArg),
  //                                             @.omp_target_task_proxy_func,
  //                                             ...)
  //   memcpy(target_task_with_privates->task_struct->shareds, %structArg,
  //          sizeof(structArg))
  //   memcpy(target_task_with_privates->privates->baseptrs,
  //          offload_baseptrs, sizeof(offload_baseptrs)
  //   memcpy(target_task_with_privates->privates->ptrs,
  //          offload_ptrs, sizeof(offload_ptrs)
  //   memcpy(target_task_with_privates->privates->sizes,
  //          offload_sizes, sizeof(offload_sizes)
  //   dependencies_array = ...
  //   ;; if nowait not present
  //   call @__kmpc_omp_wait_deps(..., dependencies_array)
  //   call @__kmpc_omp_task_begin_if0(...)
  //   call @.omp_target_task_proxy_func(i32 thread_id, ptr
  //                                     %target_task_with_privates)
  //   call @__kmpc_omp_task_complete_if0(...)
  // }
  //
  // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
  //                                                   ptr %task) {
  //   %structArg = alloca {ptr, ptr, ptr}
  //   %task_ptr = getelementptr(%task, 0, 0)
  //   %shared_data = load (getelementptr %task_ptr, 0, 0)
  //   memcpy(%structArg, %shared_data, sizeof(%structArg))
  //
  //   %offloading_arrays = getelementptr(%task, 0, 1)
  //   %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
  //   %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
  //   %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
  //   kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
  //                          %offload_sizes, %structArg)
  // }
  //
  // We need the proxy function because the signature of the task entry point
  // expected by kmpc_omp_task is always the same and will be different from
  // that of the kernel_launch function.
  //
  // kernel_launch_function is generated by emitKernelLaunch and has the
  // always_inline attribute. For this example, it'll look like so:
  // void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
  //                             %offload_sizes, %structArg) alwaysinline {
  //   %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
  //   ; load aggregated data from %structArg
  //   ; setup kernel_args using offload_baseptrs, offload_ptrs and
  //   ; offload_sizes
  //   call i32 @__tgt_target_kernel(...,
  //                                 outlined_device_function,
  //                                 ptr %kernel_args)
  // }
  // void outlined_device_function(ptr a, ptr b, ptr n) {
  //   n = *n_ptr;
  //   do i = 1, 10
  //     a(i) = b(i) + n
  // }
  //
  BasicBlock *TargetTaskBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
  BasicBlock *TargetTaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");

  InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
                                   TargetTaskAllocaBB->begin());
  InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());

  OutlineInfo OI;
  OI.EntryBB = TargetTaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));

  // Generate the task body which will subsequently be outlined.
  Builder.restoreIP(IP: TargetTaskBodyIP);
  if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
    return Err;

  // The outliner (CodeExtractor) extracts a sequence or vector of blocks that
  // it is given. These blocks are enumerated by
  // OpenMPIRBuilder::OutlineInfo::collectBlocks which expects the OI.ExitBlock
  // to be outside the region. In other words, OI.ExitBlock is expected to be
  // the start of the region after the outlining. We used to set OI.ExitBlock
  // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
  // except when the task body is a single basic block. In that case,
  // OI.ExitBlock is set to the single task body block and will get left out of
  // the outlining process. So, simply create a new empty block to which we
  // unconditionally branch from where TaskBodyCB left off.
  OI.ExitBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "target.task.cont");
  emitBlock(BB: OI.ExitBB, CurFn: Builder.GetInsertBlock()->getParent(),
            /*IsFinished=*/true);

  // A "target task" (deferred, possibly dependent task) is only needed when
  // there is a nowait clause and an actual device to offload to. In that case
  // the offloading arrays must be copied into the task so they outlive the
  // encountering function's frame; exclude them from the aggregated argument
  // struct so they can be privatized individually.
  SmallVector<Value *, 2> OffloadingArraysToPrivatize;
  bool NeedsTargetTask = HasNoWait && DeviceID;
  if (NeedsTargetTask) {
    for (auto *V :
         {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
          RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
          RTArgs.SizesArray}) {
      // Globals and null pointers need no privatization; they are valid in
      // any frame.
      if (V && !isa<ConstantPointerNull, GlobalVariable>(Val: V)) {
        OffloadingArraysToPrivatize.push_back(Elt: V);
        OI.ExcludeArgsFromAggregate.push_back(Elt: V);
      }
    }
  }
  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
                      DeviceID, OffloadingArraysToPrivatize](
                         Function &OutlinedFn) mutable {
    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");

    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());

    // The first argument of StaleCI is always the thread id.
    // The next few arguments are the pointers to offloading arrays
    // if any. (see OffloadingArraysToPrivatize)
    // Finally, all other local values that are live-in into the outlined region
    // end up in a structure whose pointer is passed as the last argument. This
    // piece of data is passed in the "shared" field of the task structure. So,
    // we know we have to pass shareds to the task if the number of arguments is
    // greater than OffloadingArraysToPrivatize.size() + 1. The 1 is for the
    // thread id. Further, for safety, we assert that the number of arguments of
    // StaleCI is exactly OffloadingArraysToPrivatize.size() + 2.
    const unsigned int NumStaleCIArgs = StaleCI->arg_size();
    bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
    assert((!HasShareds ||
            NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
           "Wrong number of arguments for StaleCI when shareds are present");
    int SharedArgOperandNo =
        HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;

    StructType *TaskWithPrivatesTy =
        createTaskWithPrivatesTy(OMPIRBuilder&: *this, OffloadingArraysToPrivatize);
    StructType *PrivatesTy = nullptr;

    if (!OffloadingArraysToPrivatize.empty())
      PrivatesTy =
          static_cast<StructType *>(TaskWithPrivatesTy->getElementType(N: 1));

    Function *ProxyFn = emitTargetTaskProxyFunction(
        OMPBuilder&: *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
        NumOffloadingArrays: OffloadingArraysToPrivatize.size(), SharedArgsOperandNo: SharedArgOperandNo);

    LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
                      << "\n");

    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call.
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr =
        getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

    // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
    //
    // If `HasNoWait == true`, we call @__kmpc_omp_target_task_alloc to provide
    // the DeviceID to the deferred task and also since
    // @__kmpc_omp_target_task_alloc creates an untied/async task.
    Function *TaskAllocFn =
        !NeedsTargetTask
            ? getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc)
            : getOrCreateRuntimeFunctionPtr(
                  FnID: OMPRTL___kmpc_omp_target_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // Tasksize refers to the size in bytes of kmp_task_t data structure
    // plus any other data to be passed to the target task, if any, which
    // is packed into a struct. kmp_task_t and the struct so created are
    // packed into a wrapper struct whose type is TaskWithPrivatesTy.
    Value *TaskSize = Builder.getInt64(
        C: M.getDataLayout().getTypeStoreSize(Ty: TaskWithPrivatesTy));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(C: 0);
    if (HasShareds) {
      auto *ArgStructAlloca =
          dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgOperandNo));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      std::optional<TypeSize> ArgAllocSize =
          ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
      assert(ArgAllocSize &&
             "Unable to determine size of arguments for extracted function");
      SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
    }

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // A target task is not final and is untied.
    Value *Flags = Builder.getInt32(C: 0);

    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = nullptr;

    SmallVector<llvm::Value *> TaskAllocArgs = {
        /*loc_ref=*/Ident, /*gtid=*/ThreadID,
        /*flags=*/Flags,
        /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
        /*task_func=*/ProxyFn};

    if (NeedsTargetTask) {
      assert(DeviceID && "Expected non-empty device ID.");
      TaskAllocArgs.push_back(Elt: DeviceID);
    }

    TaskData = createRuntimeFunctionCall(Callee: TaskAllocFn, Args: TaskAllocArgs);

    Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
    if (HasShareds) {
      // Copy the aggregated live-in values into the shareds area of the task.
      Value *Shareds = StaleCI->getArgOperand(i: SharedArgOperandNo);
      Value *TaskShareds = loadSharedDataFromTaskDescriptor(
          OMPIRBuilder&: *this, Builder, TaskWithPrivates: TaskData, TaskWithPrivatesTy);
      Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
                           Size: SharedsSize);
    }
    if (!OffloadingArraysToPrivatize.empty()) {
      // Copy each offloading array into its slot in the task's privates
      // struct so it survives until the deferred task runs.
      Value *Privates =
          Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskData, Idx: 1);
      for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
        Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
        [[maybe_unused]] Type *ArrayType =
            getOffloadingArrayType(V: PtrToPrivatize);
        assert(ArrayType && "ArrayType cannot be nullptr");

        Type *ElementType = PrivatesTy->getElementType(N: i);
        assert(ElementType == ArrayType &&
               "ElementType should match ArrayType");
        (void)ArrayType;

        Value *Dst = Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i);
        Builder.CreateMemCpy(
            Dst, DstAlign: Alignment, Src: PtrToPrivatize, SrcAlign: Alignment,
            Size: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ElementType)));
      }
    }

    Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);

    // ---------------------------------------------------------------
    // V5.2 13.8 target construct
    // If the nowait clause is present, execution of the target task
    // may be deferred. If the nowait clause is not present, the target task is
    // an included task.
    // ---------------------------------------------------------------
    // The above means that the lack of a nowait on the target construct
    // translates to '#pragma omp task if(0)'
    if (!NeedsTargetTask) {
      if (DepArray) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
        createRuntimeFunctionCall(
            Callee: TaskWaitFn,
            Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
             /*ndeps=*/Builder.getInt32(C: Dependencies.size()),
             /*dep_list=*/DepArray,
             /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
             /*noalias_dep_list=*/
             ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
      }
      // Included task.
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
      createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
      CallInst *CI = createRuntimeFunctionCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
    } else if (DepArray) {
      // HasNoWait - meaning the task may be deferred. Call
      // __kmpc_omp_task_with_deps if there are dependencies,
      // else call __kmpc_omp_task
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
      createRuntimeFunctionCall(
          Callee: TaskFn,
          Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
           DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
           ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
      createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
    }

    // The placeholder call to the outlined function (and any placeholder
    // values created for it) are no longer needed.
    StaleCI->eraseFromParent();
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };
  addOutlineInfo(OI: std::move(OI));

  LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()) << "\n");
  LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()->getParent()->getParent())
                    << "\n");
  return Builder.saveIP();
}
9187
9188Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
9189 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
9190 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
9191 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
9192 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
9193 if (Error Err =
9194 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
9195 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
9196 return Err;
9197 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
9198 return Error::success();
9199}
9200
// Emits the host-side call sequence for a target construct: either a direct
// kernel launch, a target task wrapping the launch (nowait/depend), the host
// fallback, or an if-clause guarded combination of these.
static void emitTargetCall(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    OpenMPIRBuilder::InsertPointTy AllocaIP,
    OpenMPIRBuilder::TargetDataInfo &Info,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
    Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
    SmallVectorImpl<Value *> &Args,
    OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
    const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    bool HasNoWait, Value *DynCGroupMem,
    OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
  // Generate a function call to the host fallback implementation of the target
  // region. This is called by the host when no offload entry was generated for
  // the target region and when the offloading call fails at runtime.
  auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
      -> OpenMPIRBuilder::InsertPointOrErrorTy {
    Builder.restoreIP(IP);
    // Ensure the host fallback has the same dyn_ptr ABI as the device.
    SmallVector<Value *> FallbackArgs(Args.begin(), Args.end());
    FallbackArgs.push_back(
        Elt: Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext())));
    OMPBuilder.createRuntimeFunctionCall(Callee: OutlinedFn, Args: FallbackArgs);
    return Builder.saveIP();
  };

  bool HasDependencies = Dependencies.size() > 0;
  bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

  // Filled in by EmitTargetCallThen below; captured by reference in TaskBodyCB
  // so the task body sees the final kernel arguments.
  OpenMPIRBuilder::TargetKernelArgs KArgs;

  auto TaskBodyCB =
      [&](Value *DeviceID, Value *RTLoc,
          IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
    // Assume no error was returned because EmitTargetCallFallbackCB doesn't
    // produce any.
    llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      // emitKernelLaunch makes the necessary runtime call to offload the
      // kernel. We then outline all that code into a separate function
      // ('kernel_launch_function' in the pseudo code above). This function is
      // then called by the target task proxy function (see
      // '@.omp_target_task_proxy_func' in the pseudo code above)
      // "@.omp_target_task_proxy_func' is generated by
      // emitTargetTaskProxyFunction.
      if (OutlinedFnID && DeviceID)
        return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
                                           EmitTargetCallFallbackCB, Args&: KArgs,
                                           DeviceID, RTLoc, AllocaIP: TargetTaskAllocaIP);

      // We only need to do the outlining if `DeviceID` is set to avoid calling
      // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
      // generating the `else` branch of an `if` clause.
      //
      // When OutlinedFnID is set to nullptr, then it's not an offloading call.
      // In this case, we execute the host implementation directly.
      return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
    }());

    OMPBuilder.Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  auto &&EmitTargetCallElse =
      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
    // Assume no error was returned because EmitTargetCallFallbackCB doesn't
    // produce any.
    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      if (RequiresOuterTargetTask) {
        // Arguments that are intended to be directly forwarded to an
        // emitKernelLaunch call are passed as nullptr, since
        // OutlinedFnID=nullptr results in that call not being done.
        OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
        return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
                                         /*RTLoc=*/nullptr, AllocaIP,
                                         Dependencies, RTArgs: EmptyRTArgs, HasNoWait);
      }
      return EmitTargetCallFallbackCB(Builder.saveIP());
    }());

    Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  auto &&EmitTargetCallThen =
      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
    Info.HasNoWait = HasNoWait;
    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());

    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
    if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
            AllocaIP, CodeGenIP: Builder.saveIP(), Info, RTArgs, CombinedInfo&: MapInfo, CustomMapperCB,
            /*IsNonContiguous=*/true,
            /*ForEndCall=*/false))
      return Err;

    // Per dimension: runtime num_teams value if given, else the default.
    SmallVector<Value *, 3> NumTeamsC;
    for (auto [DefaultVal, RuntimeVal] :
         zip_equal(t: DefaultAttrs.MaxTeams, u: RuntimeAttrs.MaxTeams))
      NumTeamsC.push_back(Elt: RuntimeVal ? RuntimeVal
                                       : Builder.getInt32(C: DefaultVal));

    // Calculate number of threads: 0 if no clauses specified, otherwise it is
    // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
    auto InitMaxThreadsClause = [&Builder](Value *Clause) {
      if (Clause)
        Clause = Builder.CreateIntCast(V: Clause, DestTy: Builder.getInt32Ty(),
                                       /*isSigned=*/false);
      return Clause;
    };
    auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
      if (Clause)
        Result =
            Result ? Builder.CreateSelect(C: Builder.CreateICmpULT(LHS: Result, RHS: Clause),
                                          True: Result, False: Clause)
                   : Clause;
    };

    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
    // the NUM_THREADS clause is overridden by THREAD_LIMIT.
    SmallVector<Value *, 3> NumThreadsC;
    Value *MaxThreadsClause =
        RuntimeAttrs.TeamsThreadLimit.size() == 1
            ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
            : nullptr;

    for (auto [TeamsVal, TargetVal] : zip_equal(
             t: RuntimeAttrs.TeamsThreadLimit, u: RuntimeAttrs.TargetThreadLimit)) {
      Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
      Value *NumThreads = InitMaxThreadsClause(TargetVal);

      CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
      CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

      NumThreadsC.push_back(Elt: NumThreads ? NumThreads : Builder.getInt32(C: 0));
    }

    unsigned NumTargetItems = Info.NumberOfPtrs;
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                               LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);

    Value *TripCount = RuntimeAttrs.LoopTripCount
                           ? Builder.CreateIntCast(V: RuntimeAttrs.LoopTripCount,
                                                   DestTy: Builder.getInt64Ty(),
                                                   /*isSigned=*/false)
                           : Builder.getInt64(C: 0);

    // Request zero groupprivate bytes by default.
    if (!DynCGroupMem)
      DynCGroupMem = Builder.getInt32(C: 0);

    KArgs = OpenMPIRBuilder::TargetKernelArgs(
        NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
        HasNoWait, DynCGroupMemFallback);

    // Assume no error was returned because TaskBodyCB and
    // EmitTargetCallFallbackCB don't produce any.
    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      // The presence of certain clauses on the target directive require the
      // explicit generation of the target task.
      if (RequiresOuterTargetTask)
        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID: RuntimeAttrs.DeviceID,
                                         RTLoc, AllocaIP, Dependencies,
                                         RTArgs: KArgs.RTArgs, HasNoWait: Info.HasNoWait);

      return OMPBuilder.emitKernelLaunch(
          Loc: Builder, OutlinedFnID, EmitTargetCallFallbackCB, Args&: KArgs,
          DeviceID: RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
    }());

    Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  // If we don't have an ID for the target region, it means an offload entry
  // wasn't created. In this case we just run the host fallback directly and
  // ignore any potential 'if' clauses.
  if (!OutlinedFnID) {
    cantFail(Err: EmitTargetCallElse(AllocaIP, Builder.saveIP()));
    return;
  }

  // If there's no 'if' clause, only generate the kernel launch code path.
  if (!IfCond) {
    cantFail(Err: EmitTargetCallThen(AllocaIP, Builder.saveIP()));
    return;
  }

  cantFail(Err: OMPBuilder.emitIfClause(Cond: IfCond, ThenGen: EmitTargetCallThen,
                                    ElseGen: EmitTargetCallElse, AllocaIP));
}
9396
// Entry point for the `target` directive: outlines the region into its own
// function and, on the host, emits the offloading call sequence for it.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, TargetDataInfo &Info,
    TargetRegionEntryInfo &EntryInfo,
    const TargetKernelDefaultAttrs &DefaultAttrs,
    const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
    SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
    CustomMapperCallbackTy CustomMapperCB,
    const SmallVector<DependData> &Dependencies, bool HasNowait,
    Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {

  // Bail out (returning an empty insert point) when the location is invalid.
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  Function *OutlinedFn;
  Constant *OutlinedFnID = nullptr;
  // The target region is outlined into its own function. The LLVM IR for
  // the target region itself is generated using the callbacks CBFunc
  // and ArgAccessorFuncCB
  if (Error Err = emitTargetOutlinedFunction(
          OMPBuilder&: *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
          OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
    return Err;

  // If we are not on the target device, then we need to generate code
  // to make a remote call (offload) to the previously outlined function
  // that represents the target region. Do that now.
  if (!Config.isTargetDevice())
    emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
                   IfCond, OutlinedFn, OutlinedFnID, Args&: Inputs, GenMapInfoCB,
                   CustomMapperCB, Dependencies, HasNoWait: HasNowait, DynCGroupMem,
                   DynCGroupMemFallback);
  return Builder.saveIP();
}
9435
9436std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
9437 StringRef FirstSeparator,
9438 StringRef Separator) {
9439 SmallString<128> Buffer;
9440 llvm::raw_svector_ostream OS(Buffer);
9441 StringRef Sep = FirstSeparator;
9442 for (StringRef Part : Parts) {
9443 OS << Sep << Part;
9444 Sep = Separator;
9445 }
9446 return OS.str().str();
9447}
9448
9449std::string
9450OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
9451 return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
9452 Separator: Config.separator());
9453}
9454
// Returns the module-level global with the given name, creating it (zero
// initialized) on first request. Results are cached in InternalVars, so
// repeated requests with the same name must use the same type.
GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
    Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
  auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
  if (Elem.second) {
    // Cache hit: sanity-check that the caller asked for the same type.
    assert(Elem.second->getValueType() == Ty &&
           "OMP internal variable has different type than requested");
  } else {
    // TODO: investigate the appropriate linkage type used for the global
    // variable for possibly changing that to internal or private, or maybe
    // create different versions of the function for different OMP internal
    // variables.
    const DataLayout &DL = M.getDataLayout();
    // TODO: Investigate why AMDGPU expects AS 0 for globals even though the
    // default global AS is 1.
    // See double-target-call-with-declare-target.f90 and
    // declare-target-vars-in-target-region.f90 libomptarget
    // tests.
    unsigned AddressSpaceVal = AddressSpace ? *AddressSpace
                               : M.getTargetTriple().isAMDGPU()
                                   ? 0
                                   : DL.getDefaultGlobalsAddressSpace();
    // wasm32 has no common linkage; use internal linkage there instead.
    auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
                       ? GlobalValue::InternalLinkage
                       : GlobalValue::CommonLinkage;
    auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
                                  Constant::getNullValue(Ty), Elem.first(),
                                  /*InsertBefore=*/nullptr,
                                  GlobalValue::NotThreadLocal, AddressSpaceVal);
    // Align to the stricter of the type's ABI alignment and pointer
    // alignment for the chosen address space.
    const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
    const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpaceVal);
    GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
    Elem.second = GV;
  }

  return Elem.second;
}
9491
9492Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
9493 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
9494 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
9495 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
9496}
9497
/// Emit IR computing a size, in bytes, as an i64 using the classic
/// "sizeof via GEP" idiom: GEP index 1 off a null pointer gives the type's
/// allocation size as an address, which is then converted to an integer.
/// NOTE(review): with opaque pointers, BasePtr->getType() is 'ptr', so this
/// yields the size of a pointer rather than of any pointee — confirm that is
/// the intended semantics at the call sites.
Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
  LLVMContext &Ctx = Builder.getContext();
  Value *Null =
      Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
  Value *SizeGep =
      Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
  Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
  return SizePtrToInt;
}
9507
9508GlobalVariable *
9509OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
9510 std::string VarName) {
9511 llvm::Constant *MaptypesArrayInit =
9512 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
9513 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
9514 M, MaptypesArrayInit->getType(),
9515 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
9516 VarName);
9517 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
9518 return MaptypesArrayGlobal;
9519}
9520
/// Create the three stack arrays (base pointers, pointers, sizes) used to
/// pass map information to an OpenMP offloading runtime call, placing the
/// allocas at \p AllocaIP and leaving the builder at \p Loc. Results are
/// returned through \p MapperAllocas.
void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
                                          InsertPointTy AllocaIP,
                                          unsigned NumOperands,
                                          struct MapperAllocas &MapperAllocas) {
  // Conventional builder bail-out on an invalid location.
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
  auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
  // Allocas must be emitted at the dedicated alloca insertion point, not at
  // the current code-gen position.
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(
      Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
  AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
                                          Name: ".offload_ptrs");
  AllocaInst *ArgSizes = Builder.CreateAlloca(
      Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
  // Restore the builder to the requested code-gen location before returning.
  updateToLocation(Loc);
  MapperAllocas.ArgsBase = ArgsBase;
  MapperAllocas.Args = Args;
  MapperAllocas.ArgSizes = ArgSizes;
}
9542
/// Emit a call to the offloading runtime entry point \p MapperFunc, passing
/// decayed pointers to the arrays previously created by createMapperAllocas
/// along with the source location, device id, operand count, map types/names,
/// and a null mapper array.
void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
                                     Function *MapperFunc, Value *SrcLocInfo,
                                     Value *MaptypesArg, Value *MapnamesArg,
                                     struct MapperAllocas &MapperAllocas,
                                     int64_t DeviceID, unsigned NumOperands) {
  // Conventional builder bail-out on an invalid location.
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
  auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
  // Decay each [N x T] alloca to a pointer to its first element.
  Value *ArgsBaseGEP =
      Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
                                IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
  Value *ArgsGEP =
      Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
                                IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
  Value *ArgSizesGEP =
      Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
                                IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
  // No user-defined mapper array is passed through this entry point.
  Value *NullPtr =
      Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
  createRuntimeFunctionCall(Callee: MapperFunc, Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
                                       Builder.getInt32(C: NumOperands),
                                       ArgsBaseGEP, ArgsGEP, ArgSizesGEP,
                                       MaptypesArg, MapnamesArg, NullPtr});
}
9569
/// Populate \p RTArgs with the decayed-pointer argument values for a runtime
/// begin/end call, derived from the arrays recorded in \p Info. When
/// \p ForEndCall is set (only valid with separate begin/end calls), the
/// end-specific map-types array is used if one was generated.
void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
                                                   TargetDataRTArgs &RTArgs,
                                                   TargetDataInfo &Info,
                                                   bool ForEndCall) {
  assert((!ForEndCall || Info.separateBeginEndCalls()) &&
         "expected region end call to runtime only when end call is separate");
  // With opaque pointers all of these are the same type; the aliases only
  // document intent.
  auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
  auto VoidPtrTy = UnqualPtrTy;
  auto VoidPtrPtrTy = UnqualPtrTy;
  auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
  auto Int64PtrTy = UnqualPtrTy;

  // Nothing is mapped: pass null for every array argument.
  if (!Info.NumberOfPtrs) {
    RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
    RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
    RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    return;
  }

  // Decay each [N x T] array to a pointer to its first element.
  RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
      Ptr: Info.RTArgs.BasePointersArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
      /*Idx0=*/0,
      /*Idx1=*/0);
  RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
      /*Idx0=*/0, /*Idx1=*/0);
  // For the region-end call, use the end-specific map types if they differ
  // (e.g. with the 'present' modifier stripped).
  RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
      Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
                                                 : Info.RTArgs.MapTypesArray,
      /*Idx0=*/0,
      /*Idx1=*/0);

  // Only emit the mapper information arrays if debug information is
  // requested.
  if (!Info.EmitDebug)
    RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
  else
    RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
        /*Idx0=*/0,
        /*Idx1=*/0);
  // If there is no user-defined mapper, set the mapper array to nullptr to
  // avoid an unnecessary data privatization
  if (!Info.HasMapper)
    RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
  else
    RTArgs.MappersArray =
        Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
}
9627
9628void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
9629 InsertPointTy CodeGenIP,
9630 MapInfosTy &CombinedInfo,
9631 TargetDataInfo &Info) {
9632 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
9633 CombinedInfo.NonContigInfo;
9634
9635 // Build an array of struct descriptor_dim and then assign it to
9636 // offload_args.
9637 //
9638 // struct descriptor_dim {
9639 // uint64_t offset;
9640 // uint64_t count;
9641 // uint64_t stride
9642 // };
9643 Type *Int64Ty = Builder.getInt64Ty();
9644 StructType *DimTy = StructType::create(
9645 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
9646 Name: "struct.descriptor_dim");
9647
9648 enum { OffsetFD = 0, CountFD, StrideFD };
9649 // We need two index variable here since the size of "Dims" is the same as
9650 // the size of Components, however, the size of offset, count, and stride is
9651 // equal to the size of base declaration that is non-contiguous.
9652 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
9653 // Skip emitting ir if dimension size is 1 since it cannot be
9654 // non-contiguous.
9655 if (NonContigInfo.Dims[I] == 1)
9656 continue;
9657 Builder.restoreIP(IP: AllocaIP);
9658 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
9659 AllocaInst *DimsAddr =
9660 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
9661 Builder.restoreIP(IP: CodeGenIP);
9662 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
9663 unsigned RevIdx = EE - II - 1;
9664 Value *DimsLVal = Builder.CreateInBoundsGEP(
9665 Ty: ArrayTy, Ptr: DimsAddr, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
9666 // Offset
9667 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
9668 Builder.CreateAlignedStore(
9669 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
9670 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
9671 // Count
9672 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
9673 Builder.CreateAlignedStore(
9674 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
9675 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9676 // Stride
9677 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
9678 Builder.CreateAlignedStore(
9679 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
9680 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9681 }
9682 // args[I] = &dims
9683 Builder.restoreIP(IP: CodeGenIP);
9684 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
9685 V: DimsAddr, DestTy: Builder.getPtrTy());
9686 Value *P = Builder.CreateConstInBoundsGEP2_32(
9687 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
9688 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
9689 Builder.CreateAlignedStore(
9690 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
9691 ++L;
9692 }
9693}
9694
/// Emit the conditional array-section init (\p IsInit true) or delete prologue
/// /epilogue of a user-defined mapper: when the guard condition holds, call
/// __tgt_push_mapper_component once for the whole section with TO/FROM
/// stripped (allocation/deletion only) and IMPLICIT set; otherwise fall
/// through to \p ExitBB.
void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
    Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
    Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
    BasicBlock *ExitBB, bool IsInit) {
  StringRef Prefix = IsInit ? ".init" : ".del";

  // Evaluate if this is an array section.
  BasicBlock *BodyBB = BasicBlock::Create(
      Context&: M.getContext(), Name: createPlatformSpecificName(Parts: {"omp.array", Prefix}));
  Value *IsArray =
      Builder.CreateICmpSGT(LHS: Size, RHS: Builder.getInt64(C: 1), Name: "omp.arrayinit.isarray");
  Value *DeleteBit = Builder.CreateAnd(
      LHS: MapType,
      RHS: Builder.getInt64(
          C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
  Value *DeleteCond;
  Value *Cond;
  if (IsInit) {
    // Init runs for array sections or when base != begin, and only when the
    // DELETE bit is clear.
    // base != begin?
    Value *BaseIsBegin = Builder.CreateICmpNE(LHS: Base, RHS: Begin);
    Cond = Builder.CreateOr(LHS: IsArray, RHS: BaseIsBegin);
    DeleteCond = Builder.CreateIsNull(
        Arg: DeleteBit,
        Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
  } else {
    // Delete runs for array sections only, and only when the DELETE bit is
    // set.
    Cond = IsArray;
    DeleteCond = Builder.CreateIsNotNull(
        Arg: DeleteBit,
        Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
  }
  Cond = Builder.CreateAnd(LHS: Cond, RHS: DeleteCond);
  Builder.CreateCondBr(Cond, True: BodyBB, False: ExitBB);

  emitBlock(BB: BodyBB, CurFn: MapperFn);
  // Get the array size by multiplying element size and element number (i.e., \p
  // Size).
  Value *ArraySize = Builder.CreateNUWMul(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
  // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
  // memory allocation/deletion purpose only.
  Value *MapTypeArg = Builder.CreateAnd(
      LHS: MapType,
      RHS: Builder.getInt64(
          C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_TO |
              OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
  MapTypeArg = Builder.CreateOr(
      LHS: MapTypeArg,
      RHS: Builder.getInt64(
          C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));

  // Call the runtime API __tgt_push_mapper_component to fill up the runtime
  // data structure.
  Value *OffloadingArgs[] = {MapperHandle, Base, Begin,
                             ArraySize, MapTypeArg, MapName};
  createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
      Args: OffloadingArgs);
}
9755
/// Emit the body of a user-defined mapper function named \p FuncName with the
/// runtime-mandated signature (handle, base, begin, size-in-bytes, map type,
/// map name). The generated function performs an optional section init,
/// then loops over the elements of type \p ElemTy, pushing each component
/// (with map-type decay applied) via __tgt_push_mapper_component or a nested
/// custom mapper, and finally performs an optional section delete.
/// \p GenMapInfoCB supplies the per-element map information; \p CustomMapperCB
/// supplies nested mappers by component index.
Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
    function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
                                   llvm::Value *BeginArg)>
        GenMapInfoCB,
    Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
  // Signature fixed by the offloading runtime:
  // void (ptr handle, ptr base, ptr begin, i64 size, i64 type, ptr name).
  SmallVector<Type *> Params;
  Params.emplace_back(Args: Builder.getPtrTy());
  Params.emplace_back(Args: Builder.getPtrTy());
  Params.emplace_back(Args: Builder.getPtrTy());
  Params.emplace_back(Args: Builder.getInt64Ty());
  Params.emplace_back(Args: Builder.getInt64Ty());
  Params.emplace_back(Args: Builder.getPtrTy());

  auto *FnTy =
      FunctionType::get(Result: Builder.getVoidTy(), Params, /* IsVarArg */ isVarArg: false);

  SmallString<64> TyStr;
  raw_svector_ostream Out(TyStr);
  Function *MapperFn =
      Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
  MapperFn->addFnAttr(Kind: Attribute::NoInline);
  MapperFn->addFnAttr(Kind: Attribute::NoUnwind);
  MapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 4, Kind: Attribute::NoUndef);
  MapperFn->addParamAttr(ArgNo: 5, Kind: Attribute::NoUndef);

  // Start the mapper function code generation.
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: MapperFn);
  // Remember the caller's insertion point; it is restored before returning.
  auto SavedIP = Builder.saveIP();
  Builder.SetInsertPoint(EntryBB);

  Value *MapperHandle = MapperFn->getArg(i: 0);
  Value *BaseIn = MapperFn->getArg(i: 1);
  Value *BeginIn = MapperFn->getArg(i: 2);
  Value *Size = MapperFn->getArg(i: 3);
  Value *MapType = MapperFn->getArg(i: 4);
  Value *MapName = MapperFn->getArg(i: 5);

  // Compute the starting and end addresses of array elements.
  // Prepare common arguments for array initiation and deletion.
  // Convert the size in bytes into the number of array elements.
  TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(Ty: ElemTy);
  Size = Builder.CreateExactUDiv(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
  Value *PtrBegin = BeginIn;
  Value *PtrEnd = Builder.CreateGEP(Ty: ElemTy, Ptr: PtrBegin, IdxList: Size);

  // Emit array initiation if this is an array section and \p MapType indicates
  // that memory allocation is required.
  BasicBlock *HeadBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.head");
  emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
                             MapType, MapName, ElementSize, ExitBB: HeadBB,
                             /*IsInit=*/true);

  // Emit a for loop to iterate through SizeArg of elements and map all of them.

  // Emit the loop header block.
  emitBlock(BB: HeadBB, CurFn: MapperFn);
  BasicBlock *BodyBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.body");
  BasicBlock *DoneBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.done");
  // Evaluate whether the initial condition is satisfied.
  Value *IsEmpty =
      Builder.CreateICmpEQ(LHS: PtrBegin, RHS: PtrEnd, Name: "omp.arraymap.isempty");
  Builder.CreateCondBr(Cond: IsEmpty, True: DoneBB, False: BodyBB);

  // Emit the loop body block.
  emitBlock(BB: BodyBB, CurFn: MapperFn);
  // LastBB tracks the block that ultimately branches back to the loop header;
  // it is updated inside the per-component loop below.
  BasicBlock *LastBB = BodyBB;
  PHINode *PtrPHI =
      Builder.CreatePHI(Ty: PtrBegin->getType(), NumReservedValues: 2, Name: "omp.arraymap.ptrcurrent");
  PtrPHI->addIncoming(V: PtrBegin, BB: HeadBB);

  // Get map clause information. Fill up the arrays with all mapped variables.
  MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
  if (!Info)
    return Info.takeError();

  // Call the runtime API __tgt_mapper_num_components to get the number of
  // pre-existing components.
  Value *OffloadingArgs[] = {MapperHandle};
  Value *PreviousSize = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_mapper_num_components),
      Args: OffloadingArgs);
  // Shift into the MEMBER_OF bit-field position so member components are
  // attributed to the right parent entry.
  Value *ShiftedPreviousSize =
      Builder.CreateShl(LHS: PreviousSize, RHS: Builder.getInt64(C: getFlagMemberOffset()));

  // Fill up the runtime mapper handle for all components.
  for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
    Value *CurBaseArg = Info->BasePointers[I];
    Value *CurBeginArg = Info->Pointers[I];
    Value *CurSizeArg = Info->Sizes[I];
    Value *CurNameArg = Info->Names.size()
                            ? Info->Names[I]
                            : Constant::getNullValue(Ty: Builder.getPtrTy());

    // Extract the MEMBER_OF field from the map type.
    Value *OriMapType = Builder.getInt64(
        C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            Info->Types[I]));
    Value *MemberMapType =
        Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);

    // Combine the map type inherited from user-defined mapper with that
    // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
    // bits of the \a MapType, which is the input argument of the mapper
    // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
    // bits of MemberMapType.
    // [OpenMP 5.0], 1.2.6. map-type decay.
    //        | alloc |  to   | from  | tofrom | release | delete
    // ----------------------------------------------------------
    // alloc  | alloc | alloc | alloc | alloc  | release | delete
    // to     | alloc |  to   | alloc |   to   | release | delete
    // from   | alloc | alloc | from  |  from  | release | delete
    // tofrom | alloc |  to   | from  | tofrom | release | delete
    Value *LeftToFrom = Builder.CreateAnd(
        LHS: MapType,
        RHS: Builder.getInt64(
            C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO |
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    BasicBlock *AllocBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc");
    BasicBlock *AllocElseBB =
        BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc.else");
    BasicBlock *ToBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to");
    BasicBlock *ToElseBB =
        BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to.else");
    BasicBlock *FromBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.from");
    BasicBlock *EndBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.end");
    Value *IsAlloc = Builder.CreateIsNull(Arg: LeftToFrom);
    Builder.CreateCondBr(Cond: IsAlloc, True: AllocBB, False: AllocElseBB);
    // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
    emitBlock(BB: AllocBB, CurFn: MapperFn);
    Value *AllocMapType = Builder.CreateAnd(
        LHS: MemberMapType,
        RHS: Builder.getInt64(
            C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO |
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    Builder.CreateBr(Dest: EndBB);
    emitBlock(BB: AllocElseBB, CurFn: MapperFn);
    Value *IsTo = Builder.CreateICmpEQ(
        LHS: LeftToFrom,
        RHS: Builder.getInt64(
            C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO)));
    Builder.CreateCondBr(Cond: IsTo, True: ToBB, False: ToElseBB);
    // In case of to, clear OMP_MAP_FROM.
    emitBlock(BB: ToBB, CurFn: MapperFn);
    Value *ToMapType = Builder.CreateAnd(
        LHS: MemberMapType,
        RHS: Builder.getInt64(
            C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    Builder.CreateBr(Dest: EndBB);
    emitBlock(BB: ToElseBB, CurFn: MapperFn);
    Value *IsFrom = Builder.CreateICmpEQ(
        LHS: LeftToFrom,
        RHS: Builder.getInt64(
            C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
    Builder.CreateCondBr(Cond: IsFrom, True: FromBB, False: EndBB);
    // In case of from, clear OMP_MAP_TO.
    emitBlock(BB: FromBB, CurFn: MapperFn);
    Value *FromMapType = Builder.CreateAnd(
        LHS: MemberMapType,
        RHS: Builder.getInt64(
            C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                OpenMPOffloadMappingFlags::OMP_MAP_TO)));
    // In case of tofrom, do nothing.
    emitBlock(BB: EndBB, CurFn: MapperFn);
    LastBB = EndBB;
    // Merge the four decayed map types; the tofrom case arrives from ToElseBB
    // with MemberMapType unchanged.
    PHINode *CurMapType =
        Builder.CreatePHI(Ty: Builder.getInt64Ty(), NumReservedValues: 4, Name: "omp.maptype");
    CurMapType->addIncoming(V: AllocMapType, BB: AllocBB);
    CurMapType->addIncoming(V: ToMapType, BB: ToBB);
    CurMapType->addIncoming(V: FromMapType, BB: FromBB);
    CurMapType->addIncoming(V: MemberMapType, BB: ToElseBB);

    Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
                               CurSizeArg, CurMapType, CurNameArg};

    auto ChildMapperFn = CustomMapperCB(I);
    if (!ChildMapperFn)
      return ChildMapperFn.takeError();
    if (*ChildMapperFn) {
      // Call the corresponding mapper function.
      createRuntimeFunctionCall(Callee: *ChildMapperFn, Args: OffloadingArgs)
          ->setDoesNotThrow();
    } else {
      // Call the runtime API __tgt_push_mapper_component to fill up the runtime
      // data structure.
      createRuntimeFunctionCall(
          Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
          Args: OffloadingArgs);
    }
  }

  // Update the pointer to point to the next element that needs to be mapped,
  // and check whether we have mapped all elements.
  Value *PtrNext = Builder.CreateConstGEP1_32(Ty: ElemTy, Ptr: PtrPHI, /*Idx0=*/1,
                                              Name: "omp.arraymap.next");
  PtrPHI->addIncoming(V: PtrNext, BB: LastBB);
  Value *IsDone = Builder.CreateICmpEQ(LHS: PtrNext, RHS: PtrEnd, Name: "omp.arraymap.isdone");
  BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.exit");
  Builder.CreateCondBr(Cond: IsDone, True: ExitBB, False: BodyBB);

  emitBlock(BB: ExitBB, CurFn: MapperFn);
  // Emit array deletion if this is an array section and \p MapType indicates
  // that deletion is required.
  emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
                             MapType, MapName, ElementSize, ExitBB: DoneBB,
                             /*IsInit=*/false);

  // Emit the function exit block.
  emitBlock(BB: DoneBB, CurFn: MapperFn, /*IsFinished=*/true);

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: SavedIP);
  return MapperFn;
}
9978
/// Materialize the offloading argument arrays (base pointers, pointers,
/// sizes, map types/names, mappers) described by \p CombinedInfo, recording
/// the results in \p Info. Sizes that are compile-time constants become a
/// private global; runtime sizes are stored into a stack array. Allocas are
/// placed at \p AllocaIP, everything else at \p CodeGenIP. \p DeviceAddrCB is
/// invoked for entries that need device-pointer privatization; a trailing
/// non-contiguous descriptor pass runs when \p IsNonContiguous applies.
Error OpenMPIRBuilder::emitOffloadingArrays(
    InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
    TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
    bool IsNonContiguous,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB) {

  // Reset the array information.
  Info.clearArrayInfo();
  Info.NumberOfPtrs = CombinedInfo.BasePointers.size();

  if (Info.NumberOfPtrs == 0)
    return Error::success();

  Builder.restoreIP(IP: AllocaIP);
  // Detect if we have any capture size requiring runtime evaluation of the
  // size so that a constant array could be eventually used.
  ArrayType *PointerArrayType =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);

  Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");

  Info.RTArgs.PointersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
  AllocaInst *MappersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
  Info.RTArgs.MappersArray = MappersArray;

  // If we don't have any VLA types or other types that require runtime
  // evaluation, we can use a constant array for the map sizes, otherwise we
  // need to fill up the arrays as we do for the pointers.
  Type *Int64Ty = Builder.getInt64Ty();
  SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
                                     ConstantInt::get(Ty: Int64Ty, V: 0));
  SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
  for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
    bool IsNonContigEntry =
        IsNonContiguous &&
        (static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
             CombinedInfo.Types[I] &
             OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG) != 0);
    // For NON_CONTIG entries, ArgSizes stores the dimension count (number of
    // descriptor_dim records), not the byte size.
    if (IsNonContigEntry) {
      assert(I < CombinedInfo.NonContigInfo.Dims.size() &&
             "Index must be in-bounds for NON_CONTIG Dims array");
      const uint64_t DimCount = CombinedInfo.NonContigInfo.Dims[I];
      assert(DimCount > 0 && "NON_CONTIG DimCount must be > 0");
      ConstSizes[I] = ConstantInt::get(Ty: Int64Ty, V: DimCount);
      continue;
    }
    // ConstantExprs and GlobalValues are technically Constants but their
    // value is only known at runtime/link time, so treat them as runtime
    // sizes.
    if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
      if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
        ConstSizes[I] = CI;
        continue;
      }
    }
    RuntimeSizes.set(I);
  }

  if (RuntimeSizes.all()) {
    // Every size is dynamic: a plain stack array filled below suffices.
    ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
    Info.RTArgs.SizesArray = Builder.CreateAlloca(
        Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
    restoreIPandDebugLoc(Builder, IP: CodeGenIP);
  } else {
    auto *SizesArrayInit = ConstantArray::get(
        T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
    std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
    auto *SizesArrayGbl =
        new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
                           GlobalValue::PrivateLinkage, SizesArrayInit, Name);
    SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    if (!RuntimeSizes.any()) {
      // All sizes are constant: use the global directly.
      Info.RTArgs.SizesArray = SizesArrayGbl;
    } else {
      // Mixed case: memcpy the constant global into a stack buffer, then the
      // loop below overwrites the runtime-sized slots.
      unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
      Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
      ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
      AllocaInst *Buffer = Builder.CreateAlloca(
          Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
      Buffer->setAlignment(OffloadSizeAlign);
      restoreIPandDebugLoc(Builder, IP: CodeGenIP);
      Builder.CreateMemCpy(
          Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
          Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
          Size: Builder.getIntN(
              N: IndexSize,
              C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));

      Info.RTArgs.SizesArray = Buffer;
    }
    restoreIPandDebugLoc(Builder, IP: CodeGenIP);
  }

  // The map types are always constant so we don't need to generate code to
  // fill arrays. Instead, we create an array constant.
  SmallVector<uint64_t, 4> Mapping;
  for (auto mapFlag : CombinedInfo.Types)
    Mapping.push_back(
        Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            mapFlag));
  std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
  auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
  Info.RTArgs.MapTypesArray = MapTypesArrayGbl;

  // The information types are only built if provided.
  if (!CombinedInfo.Names.empty()) {
    auto *MapNamesArrayGbl = createOffloadMapnames(
        Names&: CombinedInfo.Names, VarName: createPlatformSpecificName(Parts: {"offload_mapnames"}));
    Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
    Info.EmitDebug = true;
  } else {
    Info.RTArgs.MapNamesArray =
        Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
    Info.EmitDebug = false;
  }

  // If there's a present map type modifier, it must not be applied to the end
  // of a region, so generate a separate map type array in that case.
  if (Info.separateBeginEndCalls()) {
    bool EndMapTypesDiffer = false;
    for (uint64_t &Type : Mapping) {
      if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                     OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
        Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
        EndMapTypesDiffer = true;
      }
    }
    if (EndMapTypesDiffer) {
      MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
      Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
    }
  }

  // Store every component's base pointer, pointer, (runtime) size and mapper
  // into its slot of the corresponding array.
  PointerType *PtrTy = Builder.getPtrTy();
  for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
    Value *BPVal = CombinedInfo.BasePointers[I];
    Value *BP = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
        Idx0: 0, Idx1: I);
    Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
                               Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));

    if (Info.requiresDevicePointerInfo()) {
      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
        // use_device_ptr: provide a fresh alloca to receive the translated
        // device pointer.
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(IP: AllocaIP);
        Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
        Builder.restoreIP(IP: CodeGenIP);
        if (DeviceAddrCB)
          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
        // use_device_addr: the array slot itself is updated in place.
        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
        if (DeviceAddrCB)
          DeviceAddrCB(I, BP);
      }
    }

    Value *PVal = CombinedInfo.Pointers[I];
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
        Idx1: I);
    // TODO: Check alignment correct.
    Builder.CreateAlignedStore(Val: PVal, Ptr: P,
                               Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));

    if (RuntimeSizes.test(Idx: I)) {
      Value *S = Builder.CreateConstInBoundsGEP2_32(
          Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
          /*Idx0=*/0,
          /*Idx1=*/I);
      Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
                                                         DestTy: Int64Ty,
                                                         /*isSigned=*/true),
                                 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
    }
    // Fill up the mapper array.
    unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
    Value *MFunc = ConstantPointerNull::get(T: PtrTy);

    auto CustomMFunc = CustomMapperCB(I);
    if (!CustomMFunc)
      return CustomMFunc.takeError();
    if (*CustomMFunc)
      MFunc = Builder.CreatePointerCast(V: *CustomMFunc, DestTy: PtrTy);

    Value *MAddr = Builder.CreateInBoundsGEP(
        Ty: PointerArrayType, Ptr: MappersArray,
        IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
    Builder.CreateAlignedStore(
        Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
  }

  if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
      Info.NumberOfPtrs == 0)
    return Error::success();
  emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
  return Error::success();
}
10181
10182void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
10183 BasicBlock *CurBB = Builder.GetInsertBlock();
10184
10185 if (!CurBB || CurBB->hasTerminator()) {
10186 // If there is no insert point or the previous block is already
10187 // terminated, don't touch it.
10188 } else {
10189 // Otherwise, create a fall-through branch.
10190 Builder.CreateBr(Dest: Target);
10191 }
10192
10193 Builder.ClearInsertionPoint();
10194}
10195
10196void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
10197 bool IsFinished) {
10198 BasicBlock *CurBB = Builder.GetInsertBlock();
10199
10200 // Fall out of the current block (if necessary).
10201 emitBranch(Target: BB);
10202
10203 if (IsFinished && BB->use_empty()) {
10204 BB->eraseFromParent();
10205 return;
10206 }
10207
10208 // Place the block after the current block, if possible, or else at
10209 // the end of the function.
10210 if (CurBB && CurBB->getParent())
10211 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
10212 else
10213 CurFn->insert(Position: CurFn->end(), BB);
10214 Builder.SetInsertPoint(BB);
10215}
10216
10217Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
10218 BodyGenCallbackTy ElseGen,
10219 InsertPointTy AllocaIP) {
10220 // If the condition constant folds and can be elided, try to avoid emitting
10221 // the condition and the dead arm of the if/else.
10222 if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
10223 auto CondConstant = CI->getSExtValue();
10224 if (CondConstant)
10225 return ThenGen(AllocaIP, Builder.saveIP());
10226
10227 return ElseGen(AllocaIP, Builder.saveIP());
10228 }
10229
10230 Function *CurFn = Builder.GetInsertBlock()->getParent();
10231
10232 // Otherwise, the condition did not fold, or we couldn't elide it. Just
10233 // emit the conditional branch.
10234 BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
10235 BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
10236 BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
10237 Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
10238 // Emit the 'then' code.
10239 emitBlock(BB: ThenBlock, CurFn);
10240 if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
10241 return Err;
10242 emitBranch(Target: ContBlock);
10243 // Emit the 'else' code if present.
10244 // There is no need to emit line number for unconditional branch.
10245 emitBlock(BB: ElseBlock, CurFn);
10246 if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
10247 return Err;
10248 // There is no need to emit line number for unconditional branch.
10249 emitBranch(Target: ContBlock);
10250 // Emit the continuation block for code after the if.
10251 emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
10252 return Error::success();
10253}
10254
10255bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
10256 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
10257 assert(!(AO == AtomicOrdering::NotAtomic ||
10258 AO == llvm::AtomicOrdering::Unordered) &&
10259 "Unexpected Atomic Ordering.");
10260
10261 bool Flush = false;
10262 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
10263
10264 switch (AK) {
10265 case Read:
10266 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
10267 AO == AtomicOrdering::SequentiallyConsistent) {
10268 FlushAO = AtomicOrdering::Acquire;
10269 Flush = true;
10270 }
10271 break;
10272 case Write:
10273 case Compare:
10274 case Update:
10275 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
10276 AO == AtomicOrdering::SequentiallyConsistent) {
10277 FlushAO = AtomicOrdering::Release;
10278 Flush = true;
10279 }
10280 break;
10281 case Capture:
10282 switch (AO) {
10283 case AtomicOrdering::Acquire:
10284 FlushAO = AtomicOrdering::Acquire;
10285 Flush = true;
10286 break;
10287 case AtomicOrdering::Release:
10288 FlushAO = AtomicOrdering::Release;
10289 Flush = true;
10290 break;
10291 case AtomicOrdering::AcquireRelease:
10292 case AtomicOrdering::SequentiallyConsistent:
10293 FlushAO = AtomicOrdering::AcquireRelease;
10294 Flush = true;
10295 break;
10296 default:
10297 // do nothing - leave silently.
10298 break;
10299 }
10300 }
10301
10302 if (Flush) {
10303 // Currently Flush RT call still doesn't take memory_ordering, so for when
10304 // that happens, this tries to do the resolution of which atomic ordering
10305 // to use with but issue the flush call
10306 // TODO: pass `FlushAO` after memory ordering support is added
10307 (void)FlushAO;
10308 emitFlush(Loc);
10309 }
10310
10311 // for AO == AtomicOrdering::Monotonic and all other case combinations
10312 // do nothing
10313 return Flush;
10314}
10315
/// Emit IR for `omp atomic read`: atomically load *X.Var and store the result
/// to *V.Var with a plain (non-atomic) store.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO, InsertPointTy AllocaIP) {
  // Invalid location: return the incoming insert point untouched.
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    // Integers can be loaded atomically as-is.
    LoadInst *XLD =
        Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
    XLD->setAtomic(Ordering: AO);
    XRead = cast<Value>(Val: XLD);
  } else if (XElemTy->isStructTy()) {
    // Struct types go through the __atomic_load libcall. The atomic load
    // below only serves to obtain module/alignment info and is erased once
    // the libcall has been emitted.
    // FIXME: Add checks to ensure __atomic_load is emitted iff the
    // target does not support `atomicrmw` of the size of the struct
    LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
    OldVal->setAtomic(Ordering: AO);
    const DataLayout &DL = OldVal->getModule()->getDataLayout();
    unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
    auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
    XRead = AtomicLoadRes.first;
    OldVal->eraseFromParent();
  } else {
    // We need to perform atomic op as integer
    // (load a same-width integer atomically, then cast to the element type).
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    LoadInst *XLoad =
        Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
    XLoad->setAtomic(Ordering: AO);
    if (XElemTy->isFloatingPointTy()) {
      XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
    } else {
      XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
    }
  }
  // The flush (if the ordering requires one) precedes the non-atomic store
  // into the capture variable.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
  Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
  return Builder.saveIP();
}
10367
/// Emit IR for `omp atomic write`: atomically store \p Expr into *X.Var.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO, InsertPointTy AllocaIP) {
  // Invalid location: return the incoming insert point untouched.
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    // Integers can be stored atomically as-is.
    StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
    XSt->setAtomic(Ordering: AO);
  } else if (XElemTy->isStructTy()) {
    // Struct types go through the __atomic_store libcall. The load below only
    // serves to obtain module/alignment info and is erased afterwards.
    LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
    const DataLayout &DL = OldVal->getModule()->getDataLayout();
    unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
    atomicInfo.EmitAtomicStoreLibcall(AO, Source: Expr);
    OldVal->eraseFromParent();
  } else {
    // We need to bitcast and perform atomic op as integers
    // (cast the value to a same-width integer and store that atomically).
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    Value *ExprCast =
        Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
    XSt->setAtomic(Ordering: AO);
  }

  // Emit a flush when the memory ordering requires one.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
  return Builder.saveIP();
}
10407
/// Emit IR for `omp atomic update`: apply \p UpdateOp (or the native RMW op
/// \p RMWOp) to *X.Var atomically. The heavy lifting is delegated to
/// emitAtomicUpdate(); this wrapper validates the location and emits the
/// trailing flush.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
  if (!updateToLocation(Loc))
    return Loc.IP;

  // These sanity checks are compiled in only for debug builds (LLVM_DEBUG),
  // so they add no cost to release builds.
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
      AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp, VolatileX: X.IsVolatile,
      IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
  if (!AtomicResult)
    return AtomicResult.takeError();
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
  return Builder.saveIP();
}
10438
10439// FIXME: Duplicating AtomicExpand
10440Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
10441 AtomicRMWInst::BinOp RMWOp) {
10442 switch (RMWOp) {
10443 case AtomicRMWInst::Add:
10444 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
10445 case AtomicRMWInst::Sub:
10446 return Builder.CreateSub(LHS: Src1, RHS: Src2);
10447 case AtomicRMWInst::And:
10448 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
10449 case AtomicRMWInst::Nand:
10450 return Builder.CreateNeg(V: Builder.CreateAnd(LHS: Src1, RHS: Src2));
10451 case AtomicRMWInst::Or:
10452 return Builder.CreateOr(LHS: Src1, RHS: Src2);
10453 case AtomicRMWInst::Xor:
10454 return Builder.CreateXor(LHS: Src1, RHS: Src2);
10455 case AtomicRMWInst::Xchg:
10456 case AtomicRMWInst::FAdd:
10457 case AtomicRMWInst::FSub:
10458 case AtomicRMWInst::BAD_BINOP:
10459 case AtomicRMWInst::Max:
10460 case AtomicRMWInst::Min:
10461 case AtomicRMWInst::UMax:
10462 case AtomicRMWInst::UMin:
10463 case AtomicRMWInst::FMax:
10464 case AtomicRMWInst::FMin:
10465 case AtomicRMWInst::FMaximum:
10466 case AtomicRMWInst::FMinimum:
10467 case AtomicRMWInst::FMaximumNum:
10468 case AtomicRMWInst::FMinimumNum:
10469 case AtomicRMWInst::UIncWrap:
10470 case AtomicRMWInst::UDecWrap:
10471 case AtomicRMWInst::USubCond:
10472 case AtomicRMWInst::USubSat:
10473 llvm_unreachable("Unsupported atomic update operation");
10474 }
10475 llvm_unreachable("Unsupported atomic update operation");
10476}
10477
/// Emit the update part of an atomic update/capture construct. Fast path:
/// a single `atomicrmw` when \p RMWOp maps onto one. Slow path: a
/// compare-exchange retry loop around the user-supplied \p UpdateOp (via a
/// libcall for struct types, a plain cmpxchg otherwise).
///
/// Returns {old value, updated value} so callers can implement both prefix
/// and postfix captures.
Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
    InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
    AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
  // or a complex datatype.
  // Decide whether a single atomicrmw instruction suffices. Sub is only
  // commutative-friendly when `x` is the LHS of the binop (x = x - expr).
  bool emitRMWOp = false;
  switch (RMWOp) {
  case AtomicRMWInst::Add:
  case AtomicRMWInst::And:
  case AtomicRMWInst::Nand:
  case AtomicRMWInst::Or:
  case AtomicRMWInst::Xor:
  case AtomicRMWInst::Xchg:
    emitRMWOp = XElemTy;
    break;
  case AtomicRMWInst::Sub:
    emitRMWOp = (IsXBinopExpr && XElemTy);
    break;
  default:
    emitRMWOp = false;
  }
  // NOTE(review): this dereferences XElemTy unconditionally, so the null
  // checks in the switch above are only meaningful if XElemTy can never be
  // null here — confirm.
  emitRMWOp &= XElemTy->isIntegerTy();

  std::pair<Value *, Value *> Res;
  if (emitRMWOp) {
    // Fast path: single atomicrmw.
    AtomicRMWInst *RMWInst =
        Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
    if (T.isAMDGPU()) {
      // AMDGPU-specific metadata controlling how the atomic is expanded.
      if (IsIgnoreDenormalMode)
        RMWInst->setMetadata(Kind: "amdgpu.ignore.denormal.mode",
                             Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
      if (!IsFineGrainedMemory)
        RMWInst->setMetadata(Kind: "amdgpu.no.fine.grained.memory",
                             Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
      if (!IsRemoteMemory)
        RMWInst->setMetadata(Kind: "amdgpu.no.remote.memory",
                             Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
    }
    Res.first = RMWInst;
    // not needed except in case of postfix captures. Generate anyway for
    // consistency with the else part. Will be removed with any DCE pass.
    // AtomicRMWInst::Xchg does not have a coressponding instruction.
    if (RMWOp == AtomicRMWInst::Xchg)
      Res.second = Res.first;
    else
      Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
  } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
             XElemTy->isStructTy()) {
    // Struct path: cmpxchg retry loop built on the __atomic_* libcalls.
    // The atomic load below only provides module/alignment info and is
    // erased once the libcall-based load has been emitted.
    LoadInst *OldVal =
        Builder.CreateLoad(Ty: XElemTy, Ptr: X, Name: X->getName() + ".atomic.load");
    OldVal->setAtomic(Ordering: AO);
    const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
    // NOTE(review): this takes the store size of the *pointer* type, unlike
    // createAtomicRead/Write which use getTypeStoreSize(XElemTy) — confirm
    // this is intentional.
    unsigned LoadSize =
        LoadDL.getTypeStoreSize(Ty: OldVal->getPointerOperand()->getType());

    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
    auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
    // Split the current block into CurBB -> ContBB (retry loop body) ->
    // ExitBB, inserting a temporary `unreachable` if CurBB had no terminator.
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminatorOrNull();
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
                                                BBName: X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    // Temporary for the value produced by UpdateOp, allocated in the entry.
    Builder.restoreIP(IP: AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    // PHI over the initially-loaded value and the value re-read after a
    // failed compare-exchange.
    llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
    PHI->addIncoming(V: AtomicLoadRes.first, BB: CurBB);
    Value *OldExprVal = PHI;
    Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
    if (!CBResult)
      return CBResult.takeError();
    Value *Upd = *CBResult;
    Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
    auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
        ExpectedVal: AtomicLoadRes.second, DesiredVal: NewAtomicAddr, Success: AO, Failure);
    LoadInst *PHILoad = Builder.CreateLoad(Ty: XElemTy, Ptr: Result.first);
    PHI->addIncoming(V: PHILoad, BB: Builder.GetInsertBlock());
    // Loop back to ContBB until the compare-exchange succeeds.
    Builder.CreateCondBr(Cond: Result.second, True: ExitBB, False: ContBB);
    OldVal->eraseFromParent();
    Res.first = OldExprVal;
    Res.second = Upd;

    // If ExitBB ends in the temporary `unreachable` we created above, drop it
    // and keep emitting into ExitBB.
    // NOTE(review): in the `else` branch ExitTI is necessarily null (the
    // dyn_cast failed), so SetInsertPoint(ExitTI) looks like a null
    // dereference; presumably that branch is unreachable in practice —
    // confirm.
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      Builder.SetInsertPoint(ExitTI);
    }
  } else {
    // Generic path: cmpxchg retry loop over a same-width integer view of the
    // element type.
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    LoadInst *OldVal =
        Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
    OldVal->setAtomic(Ordering: AO);
    // CurBB
    // | /---\
    // ContBB |
    // | \---/
    // ExitBB
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminatorOrNull();
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
                                                BBName: X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    // Temporary for the value produced by UpdateOp, allocated in the entry.
    Builder.restoreIP(IP: AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    // PHI over the initial load and the value returned by a failed cmpxchg.
    llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
    PHI->addIncoming(V: OldVal, BB: CurBB);
    bool IsIntTy = XElemTy->isIntegerTy();
    Value *OldExprVal = PHI;
    // Non-integer element types need a cast from the integer view before the
    // user callback sees the value.
    if (!IsIntTy) {
      if (XElemTy->isFloatingPointTy()) {
        OldExprVal = Builder.CreateBitCast(V: PHI, DestTy: XElemTy,
                                           Name: X->getName() + ".atomic.fltCast");
      } else {
        OldExprVal = Builder.CreateIntToPtr(V: PHI, DestTy: XElemTy,
                                            Name: X->getName() + ".atomic.ptrCast");
      }
    }

    Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
    if (!CBResult)
      return CBResult.takeError();
    Value *Upd = *CBResult;
    // Round-trip the updated value through memory to reinterpret it as the
    // integer type the cmpxchg operates on.
    Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
    LoadInst *DesiredVal = Builder.CreateLoad(Ty: IntCastTy, Ptr: NewAtomicAddr);
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
    AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
        Ptr: X, Cmp: PHI, New: DesiredVal, Align: llvm::MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
    Result->setVolatile(VolatileX);
    Value *PreviousVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
    Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
    PHI->addIncoming(V: PreviousVal, BB: Builder.GetInsertBlock());
    // Loop back to ContBB until the compare-exchange succeeds.
    Builder.CreateCondBr(Cond: SuccessFailureVal, True: ExitBB, False: ContBB);

    Res.first = OldExprVal;
    Res.second = Upd;

    // set Insertion point in exit block
    // NOTE(review): same null-ExitTI concern as in the struct path above.
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      Builder.SetInsertPoint(ExitTI);
    }
  }

  return Res;
}
10645
/// Emit IR for `omp atomic capture`: perform the atomic update of *X.Var and
/// store either the pre-update (postfix capture) or post-update (prefix
/// capture) value into *V.Var.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Sanity checks compiled in only for debug builds (LLVM_DEBUG).
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
  // 'x' is simply atomically rewritten with 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
  Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
      AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp: AtomicOp, UpdateOp, VolatileX: X.IsVolatile,
      IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
  if (!AtomicResult)
    return AtomicResult.takeError();
  // first = value before the update, second = value after the update.
  Value *CapturedVal =
      (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
  Builder.CreateStore(Val: CapturedVal, Ptr: V.Var, isVolatile: V.IsVolatile);

  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Capture);
  return Builder.saveIP();
}
10682
10683OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
10684 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
10685 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
10686 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
10687 bool IsFailOnly) {
10688
10689 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10690 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
10691 IsPostfixUpdate, IsFailOnly, Failure);
10692}
10693
/// Emit IR for `omp atomic compare`. For `==` comparisons a cmpxchg is used;
/// for `<`/`>` (MIN/MAX) an atomicrmw min/max variant is used. Optionally
/// captures the old/selected value into *V.Var and the comparison result
/// into *R.Var.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly, AtomicOrdering Failure) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // compare capture
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    // Equality comparison lowers to a cmpxchg; non-integer operands are
    // bitcast to a same-width integer first.
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      IntegerType *IntCastTy =
          IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
      Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
      Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: EBCast, New: DBCast, Align: MaybeAlign(),
                                           SuccessOrdering: AO, FailureOrdering: Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: E, New: D, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
    }

    if (V.Var) {
      // Capture requested: extract the old value from the cmpxchg result pair.
      Value *OldValue = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(V: OldValue, DestTy: X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        // Postfix: always capture the pre-operation value.
        Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // CurBB----
          //   |     |
          //   v     |
          // ContBB  |
          //   |     |
          //   v     |
          // ExitBB <-
          //
          // where ContBB only contains the store of old value to 'v'.
          // Split the current block so 'v' is only written on the failure
          // path. A temporary `unreachable` stands in for a missing
          // terminator and is cleaned up below.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminatorOrNull();
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(Val: OldValue, Ptr: V.Var);
          Builder.CreateBr(Dest: ExitBB);

          // NOTE(review): in the `else` branch ExitTI is necessarily null
          // (the dyn_cast failed), so SetInsertPoint(ExitTI) looks like a
          // null dereference; presumably unreachable in practice — confirm.
          if (UnreachableInst *ExitTI =
                  dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            Builder.SetInsertPoint(ExitTI);
          }
        } else {
          // Capture E on success, the old value on failure.
          Value *CapturedValue =
              Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
          Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
        }
      }
    }
    // The comparison result has to be stored.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      // Widen the cmpxchg success bit to R's integer type, honoring its
      // signedness.
      Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(V: SuccessFailureVal, DestTy: R.ElemTy)
                              : Builder.CreateZExt(V: SuccessFailureVal, DestTy: R.ElemTy);
      Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // Reverse the ordop as the OpenMP forms are different from LLVM forms.
    // Let's take max as example.
    // OpenMP form:
    // x = x > expr ? expr : x;
    // LLVM form:
    // *ptr = *ptr > val ? *ptr : val;
    // We need to transform to LLVM form.
    // x = x <= expr ? x : expr;
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      // 'x' on the LHS of the comparison: the sense is inverted.
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      // 'x' on the RHS: the OpenMP and LLVM senses agree.
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(Op: NewOp, Ptr: X.Var, Val: E, Align: MaybeAlign(), Ordering: AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        // Postfix: capture the value before the min/max was applied.
        CapturedValue = OldValue;
      } else {
        // Prefix: recompute the min/max result non-atomically to capture the
        // value that was stored.
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, LHS: OldValue, RHS: E);
        CapturedValue = Builder.CreateSelect(C: NonAtomicCmp, True: E, False: OldValue);
      }
      Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
    }
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Compare);

  return Builder.saveIP();
}
10870
10871OpenMPIRBuilder::InsertPointOrErrorTy
10872OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
10873 BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
10874 Value *NumTeamsUpper, Value *ThreadLimit,
10875 Value *IfExpr) {
10876 if (!updateToLocation(Loc))
10877 return InsertPointTy();
10878
10879 uint32_t SrcLocStrSize;
10880 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
10881 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
10882 Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
10883
10884 // Outer allocation basicblock is the entry block of the current function.
10885 BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
10886 if (&OuterAllocaBB == Builder.GetInsertBlock()) {
10887 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.entry");
10888 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
10889 }
10890
10891 // The current basic block is split into four basic blocks. After outlining,
10892 // they will be mapped as follows:
10893 // ```
10894 // def current_fn() {
10895 // current_basic_block:
10896 // br label %teams.exit
10897 // teams.exit:
10898 // ; instructions after teams
10899 // }
10900 //
10901 // def outlined_fn() {
10902 // teams.alloca:
10903 // br label %teams.body
10904 // teams.body:
10905 // ; instructions within teams body
10906 // }
10907 // ```
10908 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.exit");
10909 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.body");
10910 BasicBlock *AllocaBB =
10911 splitBB(Builder, /*CreateBranch=*/true, Name: "teams.alloca");
10912
10913 bool SubClausesPresent =
10914 (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
10915 // Push num_teams
10916 if (!Config.isTargetDevice() && SubClausesPresent) {
10917 assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
10918 "if lowerbound is non-null, then upperbound must also be non-null "
10919 "for bounds on num_teams");
10920
10921 if (NumTeamsUpper == nullptr)
10922 NumTeamsUpper = Builder.getInt32(C: 0);
10923
10924 if (NumTeamsLower == nullptr)
10925 NumTeamsLower = NumTeamsUpper;
10926
10927 if (IfExpr) {
10928 assert(IfExpr->getType()->isIntegerTy() &&
10929 "argument to if clause must be an integer value");
10930
10931 // upper = ifexpr ? upper : 1
10932 if (IfExpr->getType() != Int1)
10933 IfExpr = Builder.CreateICmpNE(LHS: IfExpr,
10934 RHS: ConstantInt::get(Ty: IfExpr->getType(), V: 0));
10935 NumTeamsUpper = Builder.CreateSelect(
10936 C: IfExpr, True: NumTeamsUpper, False: Builder.getInt32(C: 1), Name: "numTeamsUpper");
10937
10938 // lower = ifexpr ? lower : 1
10939 NumTeamsLower = Builder.CreateSelect(
10940 C: IfExpr, True: NumTeamsLower, False: Builder.getInt32(C: 1), Name: "numTeamsLower");
10941 }
10942
10943 if (ThreadLimit == nullptr)
10944 ThreadLimit = Builder.getInt32(C: 0);
10945
10946 // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
10947 // truncate or sign extend the passed values to match the int32 parameters.
10948 Value *NumTeamsLowerInt32 =
10949 Builder.CreateSExtOrTrunc(V: NumTeamsLower, DestTy: Builder.getInt32Ty());
10950 Value *NumTeamsUpperInt32 =
10951 Builder.CreateSExtOrTrunc(V: NumTeamsUpper, DestTy: Builder.getInt32Ty());
10952 Value *ThreadLimitInt32 =
10953 Builder.CreateSExtOrTrunc(V: ThreadLimit, DestTy: Builder.getInt32Ty());
10954
10955 Value *ThreadNum = getOrCreateThreadID(Ident);
10956
10957 createRuntimeFunctionCall(
10958 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_teams_51),
10959 Args: {Ident, ThreadNum, NumTeamsLowerInt32, NumTeamsUpperInt32,
10960 ThreadLimitInt32});
10961 }
10962 // Generate the body of teams.
10963 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
10964 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
10965 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
10966 return Err;
10967
10968 OutlineInfo OI;
10969 OI.EntryBB = AllocaBB;
10970 OI.ExitBB = ExitBB;
10971 OI.OuterAllocaBB = &OuterAllocaBB;
10972
10973 // Insert fake values for global tid and bound tid.
10974 SmallVector<Instruction *, 8> ToBeDeleted;
10975 InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
10976 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
10977 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "gid", AsPtr: true));
10978 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
10979 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "tid", AsPtr: true));
10980
10981 auto HostPostOutlineCB = [this, Ident,
10982 ToBeDeleted](Function &OutlinedFn) mutable {
10983 // The stale call instruction will be replaced with a new call instruction
10984 // for runtime call with the outlined function.
10985
10986 assert(OutlinedFn.hasOneUse() &&
10987 "there must be a single user for the outlined function");
10988 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
10989 ToBeDeleted.push_back(Elt: StaleCI);
10990
10991 assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
10992 "Outlined function must have two or three arguments only");
10993
10994 bool HasShared = OutlinedFn.arg_size() == 3;
10995
10996 OutlinedFn.getArg(i: 0)->setName("global.tid.ptr");
10997 OutlinedFn.getArg(i: 1)->setName("bound.tid.ptr");
10998 if (HasShared)
10999 OutlinedFn.getArg(i: 2)->setName("data");
11000
11001 // Call to the runtime function for teams in the current function.
11002 assert(StaleCI && "Error while outlining - no CallInst user found for the "
11003 "outlined function.");
11004 Builder.SetInsertPoint(StaleCI);
11005 SmallVector<Value *> Args = {
11006 Ident, Builder.getInt32(C: StaleCI->arg_size() - 2), &OutlinedFn};
11007 if (HasShared)
11008 Args.push_back(Elt: StaleCI->getArgOperand(i: 2));
11009 createRuntimeFunctionCall(
11010 Callee: getOrCreateRuntimeFunctionPtr(
11011 FnID: omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
11012 Args);
11013
11014 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
11015 I->eraseFromParent();
11016 };
11017
11018 if (!Config.isTargetDevice())
11019 OI.PostOutlineCB = HostPostOutlineCB;
11020
11021 addOutlineInfo(OI: std::move(OI));
11022
11023 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
11024
11025 return Builder.saveIP();
11026}
11027
/// Set up the block structure for a `distribute` construct and invoke
/// \p BodyGenCB to fill in its body. For device compilations the region is
/// additionally recorded for later outlining. Returns the insert point
/// after the construct, or the error produced by the body callback.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
                                  InsertPointTy OuterAllocaIP,
                                  BodyGenCallbackTy BodyGenCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();

  // If we are currently inserting into the outer alloca block, split off an
  // entry block first so the region's code does not mix with the allocas.
  if (OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB =
        splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.entry");
    Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
  }
  // Split the current block three times so the region gets dedicated
  // alloca, body and exit blocks.
  BasicBlock *ExitBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.exit");
  BasicBlock *BodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.alloca");

  // Generate the body of distribute clause
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
    return Err;

  // When using target we use different runtime functions which require a
  // callback.
  if (Config.isTargetDevice()) {
    OutlineInfo OI;
    OI.OuterAllocaBB = OuterAllocaIP.getBlock();
    OI.EntryBB = AllocaBB;
    OI.ExitBB = ExitBB;

    addOutlineInfo(OI: std::move(OI));
  }
  Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());

  return Builder.saveIP();
}
11069
11070GlobalVariable *
11071OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
11072 std::string VarName) {
11073 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
11074 T: llvm::ArrayType::get(ElementType: llvm::PointerType::getUnqual(C&: M.getContext()),
11075 NumElements: Names.size()),
11076 V: Names);
11077 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
11078 M, MapNamesArrayInit->getType(),
11079 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
11080 VarName);
11081 return MapNamesArrayGlobal;
11082}
11083
11084// Create all simple and struct types exposed by the runtime and remember
11085// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
  // Data pointers live in the configured default target address space;
  // function pointers use the module's program address space.
  unsigned DefaultTargetAS = Config.getDefaultTargetAS();
  unsigned ProgramAS = M.getDataLayout().getProgramAddressSpace();
  // Each OMP_* macro below is expanded once per entry in OMPKinds.def,
  // initializing the corresponding member (and its pointer-type member
  // where one exists).
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
  VarName##PtrTy = PointerType::get(Ctx, DefaultTargetAS);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
  VarName##Ptr = PointerType::get(Ctx, ProgramAS);
  // Struct types are looked up by name first so a pre-existing identified
  // struct in the context is reused instead of creating a duplicate.
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
  T = StructType::getTypeByName(Ctx, StructName); \
  if (!T) \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed); \
  VarName = T; \
  VarName##Ptr = PointerType::get(Ctx, DefaultTargetAS);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
11106
11107void OpenMPIRBuilder::OutlineInfo::collectBlocks(
11108 SmallPtrSetImpl<BasicBlock *> &BlockSet,
11109 SmallVectorImpl<BasicBlock *> &BlockVector) {
11110 SmallVector<BasicBlock *, 32> Worklist;
11111 BlockSet.insert(Ptr: EntryBB);
11112 BlockSet.insert(Ptr: ExitBB);
11113
11114 Worklist.push_back(Elt: EntryBB);
11115 while (!Worklist.empty()) {
11116 BasicBlock *BB = Worklist.pop_back_val();
11117 BlockVector.push_back(Elt: BB);
11118 for (BasicBlock *SuccBB : successors(BB))
11119 if (BlockSet.insert(Ptr: SuccBB).second)
11120 Worklist.push_back(Elt: SuccBB);
11121 }
11122}
11123
11124void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
11125 uint64_t Size, int32_t Flags,
11126 GlobalValue::LinkageTypes,
11127 StringRef Name) {
11128 if (!Config.isGPU()) {
11129 llvm::offloading::emitOffloadingEntry(
11130 M, Kind: object::OffloadKind::OFK_OpenMP, Addr: ID,
11131 Name: Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
11132 return;
11133 }
11134 // TODO: Add support for global variables on the device after declare target
11135 // support.
11136 Function *Fn = dyn_cast<Function>(Val: Addr);
11137 if (!Fn)
11138 return;
11139
11140 // Add a function attribute for the kernel.
11141 Fn->addFnAttr(Kind: "kernel");
11142 if (T.isAMDGCN())
11143 Fn->addFnAttr(Kind: "uniform-work-group-size");
11144 Fn->addFnAttr(Kind: Attribute::MustProgress);
11145}
11146
// We only generate metadata for functions that contain target regions.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  // Entries are collected here indexed by their creation order so they can
  // be emitted deterministically below.
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(C: ConstantInt::get(Ty: Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(Context&: C, Str: V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        // identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y: EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
      };

  OffloadInfoManager.actOnTargetRegionEntriesInfo(Action: TargetRegionMetadataEmitter);

  // Create function that emits metadata for each device global variable entry.
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y&: varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
      };

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      Action: DeviceGlobalVarMetadataEmitter);

  // Now emit the actual offload entries, in creation order.
  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                Val: E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(Name: FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(ID: CE->getID(), Addr: CE->getAddress(),
                         /*Size=*/0, Flags: CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   Val: E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declaret target link address is set.");
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable:
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_INDIRECT_ERROR, E.second);
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      if (auto *GV = dyn_cast<GlobalValue>(Val: CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            (Flags !=
                 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect &&
             Flags != OffloadEntriesInfoManager::
                          OMPTargetGlobalVarEntryIndirectVTable))
          continue;

      // Indirect globals need to use a special name that doesn't match the name
      // of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
          Flags ==
              OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
        createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
                           Flags, CE->getLinkage(), Name: CE->getVarName());
      else
        createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
                           Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }

  // Emit requires directive globals to a special entry so the runtime can
  // register them when the device image is loaded.
  // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
  // entries should be redesigned to better suit this use-case.
  if (Config.hasRequiresFlags() && !Config.isTargetDevice())
    offloading::emitOffloadingEntry(
        M, Kind: object::OffloadKind::OFK_OpenMP,
        Addr: Constant::getNullValue(Ty: PointerType::getUnqual(C&: M.getContext())),
        Name: ".requires", /*Size=*/0,
        Flags: OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
        Data: Config.getRequiresFlags());
}
11323
11324void TargetRegionEntryInfo::getTargetRegionEntryFnName(
11325 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
11326 unsigned FileID, unsigned Line, unsigned Count) {
11327 raw_svector_ostream OS(Name);
11328 OS << KernelNamePrefix << llvm::format(Fmt: "%x", Vals: DeviceID)
11329 << llvm::format(Fmt: "_%x_", Vals: FileID) << ParentName << "_l" << Line;
11330 if (Count)
11331 OS << "_" << Count;
11332}
11333
11334void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
11335 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
11336 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
11337 TargetRegionEntryInfo::getTargetRegionEntryFnName(
11338 Name, ParentName: EntryInfo.ParentName, DeviceID: EntryInfo.DeviceID, FileID: EntryInfo.FileID,
11339 Line: EntryInfo.Line, Count: NewCount);
11340}
11341
/// Derive a unique target-region identity (device ID, file ID, line) for
/// the location reported by \p CallBack, using \p VFS to stat the file.
TargetRegionEntryInfo
OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
                                          vfs::FileSystem &VFS,
                                          StringRef ParentName) {
  // Fallback device/inode pair used when the file cannot be stat'ed.
  sys::fs::UniqueID ID(0xdeadf17e, 0);
  auto FileIDInfo = CallBack();
  uint64_t FileID = 0;
  if (ErrorOr<vfs::Status> Status = VFS.status(Path: std::get<0>(t&: FileIDInfo))) {
    ID = Status->getUniqueID();
    FileID = Status->getUniqueID().getFile();
  } else {
    // If the inode ID could not be determined, create a hash value of
    // the current file name and use that as an ID.
    FileID = hash_value(arg: std::get<0>(t&: FileIDInfo));
  }

  return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
                               std::get<1>(t&: FileIDInfo));
}
11361
11362unsigned OpenMPIRBuilder::getFlagMemberOffset() {
11363 unsigned Offset = 0;
11364 for (uint64_t Remain =
11365 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11366 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
11367 !(Remain & 1); Remain = Remain >> 1)
11368 Offset++;
11369 return Offset;
11370}
11371
11372omp::OpenMPOffloadMappingFlags
11373OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
11374 // Rotate by getFlagMemberOffset() bits.
11375 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
11376 << getFlagMemberOffset());
11377}
11378
/// Overwrite the MEMBER_OF field of \p Flags with \p MemberOfFlag, unless
/// the entry is one that must not carry a MEMBER_OF value.
void OpenMPIRBuilder::setCorrectMemberOfFlag(
    omp::OpenMPOffloadMappingFlags &Flags,
    omp::OpenMPOffloadMappingFlags MemberOfFlag) {
  // If the entry is PTR_AND_OBJ but has not been marked with the special
  // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
  // marked as MEMBER_OF.
  if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
      static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
          omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
    return;

  // Entries with ATTACH are not members-of anything. They are handled
  // separately by the runtime after other maps have been handled.
  if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH))
    return;

  // Reset the placeholder value to prepare the flag for the assignment of the
  // proper MEMBER_OF value.
  Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
  Flags |= MemberOfFlag;
}
11403
/// Return (creating on first use) the "_decl_tgt_ref_ptr" indirection
/// pointer for a declare-target variable, or nullptr when no reference
/// pointer is needed for the given clause/configuration.
Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
  // TODO: convert this to utilise the IRBuilder Config rather than
  // a passed down argument.
  if (OpenMPSIMD)
    return nullptr;

  // A reference pointer is only used for `link` clauses, or for `to`/`enter`
  // clauses when unified shared memory is required.
  if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
      ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
        CaptureClause ==
            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
       Config.hasRequiresUnifiedSharedMemory())) {
    // Build "<name>[_<fileid:hex>]_decl_tgt_ref_ptr"; the file ID is appended
    // for symbols that are not externally visible.
    SmallString<64> PtrName;
    {
      raw_svector_ostream OS(PtrName);
      OS << MangledName;
      if (!IsExternallyVisible)
        OS << format(Fmt: "_%x", Vals: EntryInfo.FileID);
      OS << "_decl_tgt_ref_ptr";
    }

    Value *Ptr = M.getNamedValue(Name: PtrName);

    if (!Ptr) {
      GlobalValue *GlobalValue = M.getNamedValue(Name: MangledName);
      Ptr = getOrCreateInternalVariable(Ty: LlvmPtrTy, Name: PtrName);

      auto *GV = cast<GlobalVariable>(Val: Ptr);
      GV->setLinkage(GlobalValue::WeakAnyLinkage);

      // On the host, initialize the reference pointer to the variable itself
      // (or to a frontend-provided initializer).
      if (!Config.isTargetDevice()) {
        if (GlobalInitializer)
          GV->setInitializer(GlobalInitializer());
        else
          GV->setInitializer(GlobalValue);
      }

      // Newly created pointers must also be registered as offload entries.
      registerTargetGlobalVariable(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          GlobalInitializer, VariableLinkage, LlvmPtrTy, Addr: cast<Constant>(Val: Ptr));
    }

    return cast<Constant>(Val: Ptr);
  }

  return nullptr;
}
11459
/// Register a declare-target global variable with the offload entry
/// manager, computing its entry kind, registered name, size and linkage,
/// and creating host/device reference indirection where required.
void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
    Constant *Addr) {
  // Only `device_type(any)` entries are registered, and only when compiling
  // for a device or when offload target triples are present.
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  // `to`/`enter` without unified shared memory: register the variable itself.
  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    GlobalValue *LlvmVal = M.getNamedValue(Name: VarName);

    // Declarations have no storage here, so their registered size is zero.
    if (!IsDeclaration)
      VarSize = divideCeil(
          Numerator: M.getDataLayout().getTypeSizeInBits(Ty: LlvmVal->getValueType()), Denominator: 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName(Parts: {VarName, "ref"});

      // Emit an internal constant "<name>_ref" pointing at the variable,
      // keeping it (and thus the variable) alive on the device.
      if (!M.getNamedValue(Name: RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Ty: Addr->getType(), Name: RefName);
        auto *GvAddrRef = cast<GlobalVariable>(Val: AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(x: GvAddrRef);
      }
    }
  } else {
    // `link` clauses (and `to`/`enter` under unified shared memory) are
    // registered through an indirection pointer instead of the variable.
    if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    if (Config.isTargetDevice()) {
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}
11538
11539/// Loads all the offload entries information from the host IR
11540/// metadata.
void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
  // If we are in target mode, load the metadata from the host IR. This code has
  // to match the metadata creation in createOffloadEntriesAndInfoMetadata().

  NamedMDNode *MD = M.getNamedMetadata(Name: ompOffloadInfoName);
  if (!MD)
    return;

  for (MDNode *MN : MD->operands()) {
    // Helpers to decode the operands of one metadata node.
    auto &&GetMDInt = [MN](unsigned Idx) {
      auto *V = cast<ConstantAsMetadata>(Val: MN->getOperand(I: Idx));
      return cast<ConstantInt>(Val: V->getValue())->getZExtValue();
    };

    auto &&GetMDString = [MN](unsigned Idx) {
      auto *V = cast<MDString>(Val: MN->getOperand(I: Idx));
      return V->getString();
    };

    // Operand 0 encodes the entry kind (target region vs. global variable);
    // the remaining operand layout mirrors the emitter above.
    switch (GetMDInt(0)) {
    default:
      llvm_unreachable("Unexpected metadata!");
      break;
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoTargetRegion: {
      TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
                                      /*DeviceID=*/GetMDInt(1),
                                      /*FileID=*/GetMDInt(2),
                                      /*Line=*/GetMDInt(4),
                                      /*Count=*/GetMDInt(5));
      OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
                                                         /*Order=*/GetMDInt(6));
      break;
    }
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoDeviceGlobalVar:
      OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
          /*MangledName=*/Name: GetMDString(1),
          Flags: static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              /*Flags=*/GetMDInt(2)),
          /*Order=*/GetMDInt(3));
      break;
    }
  }
}
11586
/// Load offload entry metadata from the host IR bitcode file at
/// \p HostFilePath into this builder's offload info manager. Fatal error if
/// the file cannot be opened or parsed.
void OpenMPIRBuilder::loadOffloadInfoMetadata(vfs::FileSystem &VFS,
                                              StringRef HostFilePath) {
  // An empty path means there is no host IR to mirror; nothing to do.
  if (HostFilePath.empty())
    return;

  auto Buf = VFS.getBufferForFile(Name: HostFilePath);
  if (std::error_code Err = Buf.getError()) {
    report_fatal_error(reason: ("error opening host file from host file path inside of "
                        "OpenMPIRBuilder: " +
                        Err.message())
                           .c_str());
  }

  // Parse the host bitcode into a throwaway module in a local context; only
  // its named metadata is consumed.
  LLVMContext Ctx;
  auto M = expectedToErrorOrAndEmitErrors(
      Ctx, Val: parseBitcodeFile(Buffer: Buf.get()->getMemBufferRef(), Context&: Ctx));
  if (std::error_code Err = M.getError()) {
    report_fatal_error(
        reason: ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
            .c_str());
  }

  loadOffloadInfoMetadata(M&: *M.get());
}
11611
/// Emit a canonical counted loop with trip count \p TripCount whose body is
/// populated by \p BodyGen. Returns an insert point in the continuation
/// block after the loop, or an error from the body callback / malformed
/// body termination.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createIteratorLoop(
    LocationDescription Loc, llvm::Value *TripCount, IteratorBodyGenTy BodyGen,
    llvm::StringRef Name) {
  Builder.restoreIP(IP: Loc.IP);

  BasicBlock *CurBB = Builder.GetInsertBlock();
  assert(CurBB &&
         "expected a valid insertion block for creating an iterator loop");
  Function *F = CurBB->getParent();

  // Split in front of the current block's terminator (if any) so it ends up
  // in the continuation block rather than inside the loop.
  InsertPointTy SplitIP = Builder.saveIP();
  if (SplitIP.getPoint() == CurBB->end())
    if (Instruction *Terminator = CurBB->getTerminatorOrNull())
      SplitIP = InsertPointTy(CurBB, Terminator->getIterator());

  BasicBlock *ContBB =
      splitBB(IP: SplitIP, /*CreateBranch=*/false,
              DL: Builder.getCurrentDebugLocation(), Name: "omp.it.cont");

  CanonicalLoopInfo *CLI =
      createLoopSkeleton(DL: Builder.getCurrentDebugLocation(), TripCount, F,
                         /*PreInsertBefore=*/ContBB,
                         /*PostInsertBefore=*/ContBB, Name);

  // Enter loop from original block.
  redirectTo(Source: CurBB, Target: CLI->getPreheader(), DL: Builder.getCurrentDebugLocation());

  // Remove the unconditional branch inserted by createLoopSkeleton in the body
  // so the body callback starts from an unterminated block.
  if (Instruction *T = CLI->getBody()->getTerminatorOrNull())
    T->eraseFromParent();

  InsertPointTy BodyIP = CLI->getBodyIP();
  if (llvm::Error Err = BodyGen(BodyIP, CLI->getIndVar()))
    return Err;

  // Body must either fallthrough to the latch or branch directly to it.
  if (Instruction *BodyTerminator = CLI->getBody()->getTerminatorOrNull()) {
    auto *BodyBr = dyn_cast<UncondBrInst>(Val: BodyTerminator);
    if (!BodyBr || BodyBr->getSuccessor() != CLI->getLatch()) {
      return make_error<StringError>(
          Args: "iterator bodygen must terminate the canonical body with an "
          "unconditional branch to the loop latch",
          Args: inconvertibleErrorCode());
    }
  } else {
    // Ensure we end the loop body by jumping to the latch.
    Builder.SetInsertPoint(CLI->getBody());
    Builder.CreateBr(Dest: CLI->getLatch());
  }

  // Link After -> ContBB
  Builder.SetInsertPoint(TheBB: CLI->getAfter(), IP: CLI->getAfter()->begin());
  if (!CLI->getAfter()->hasTerminator())
    Builder.CreateBr(Dest: ContBB);

  return InsertPointTy{ContBB, ContBB->begin()};
}
11669
11670/// Mangle the parameter part of the vector function name according to
11671/// their OpenMP classification. The mangling function is defined in
11672/// section 4.5 of the AAVFABI(2021Q1).
11673static std::string mangleVectorParameters(
11674 ArrayRef<llvm::OpenMPIRBuilder::DeclareSimdAttrTy> ParamAttrs) {
11675 SmallString<256> Buffer;
11676 llvm::raw_svector_ostream Out(Buffer);
11677 for (const auto &ParamAttr : ParamAttrs) {
11678 switch (ParamAttr.Kind) {
11679 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Linear:
11680 Out << 'l';
11681 break;
11682 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearRef:
11683 Out << 'R';
11684 break;
11685 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearUVal:
11686 Out << 'U';
11687 break;
11688 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearVal:
11689 Out << 'L';
11690 break;
11691 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Uniform:
11692 Out << 'u';
11693 break;
11694 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Vector:
11695 Out << 'v';
11696 break;
11697 }
11698 if (ParamAttr.HasVarStride)
11699 Out << "s" << ParamAttr.StrideOrArg;
11700 else if (ParamAttr.Kind ==
11701 llvm::OpenMPIRBuilder::DeclareSimdKindTy::Linear ||
11702 ParamAttr.Kind ==
11703 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearRef ||
11704 ParamAttr.Kind ==
11705 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearUVal ||
11706 ParamAttr.Kind ==
11707 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearVal) {
11708 // Don't print the step value if it is not present or if it is
11709 // equal to 1.
11710 if (ParamAttr.StrideOrArg < 0)
11711 Out << 'n' << -ParamAttr.StrideOrArg;
11712 else if (ParamAttr.StrideOrArg != 1)
11713 Out << ParamAttr.StrideOrArg;
11714 }
11715
11716 if (!!ParamAttr.Alignment)
11717 Out << 'a' << ParamAttr.Alignment;
11718 }
11719
11720 return std::string(Out.str());
11721}
11722
11723void OpenMPIRBuilder::emitX86DeclareSimdFunction(
11724 llvm::Function *Fn, unsigned NumElts, const llvm::APSInt &VLENVal,
11725 llvm::ArrayRef<DeclareSimdAttrTy> ParamAttrs, DeclareSimdBranch Branch) {
11726 struct ISADataTy {
11727 char ISA;
11728 unsigned VecRegSize;
11729 };
11730 ISADataTy ISAData[] = {
11731 {.ISA: 'b', .VecRegSize: 128}, // SSE
11732 {.ISA: 'c', .VecRegSize: 256}, // AVX
11733 {.ISA: 'd', .VecRegSize: 256}, // AVX2
11734 {.ISA: 'e', .VecRegSize: 512}, // AVX512
11735 };
11736 llvm::SmallVector<char, 2> Masked;
11737 switch (Branch) {
11738 case DeclareSimdBranch::Undefined:
11739 Masked.push_back(Elt: 'N');
11740 Masked.push_back(Elt: 'M');
11741 break;
11742 case DeclareSimdBranch::Notinbranch:
11743 Masked.push_back(Elt: 'N');
11744 break;
11745 case DeclareSimdBranch::Inbranch:
11746 Masked.push_back(Elt: 'M');
11747 break;
11748 }
11749 for (char Mask : Masked) {
11750 for (const ISADataTy &Data : ISAData) {
11751 llvm::SmallString<256> Buffer;
11752 llvm::raw_svector_ostream Out(Buffer);
11753 Out << "_ZGV" << Data.ISA << Mask;
11754 if (!VLENVal) {
11755 assert(NumElts && "Non-zero simdlen/cdtsize expected");
11756 Out << llvm::APSInt::getUnsigned(X: Data.VecRegSize / NumElts);
11757 } else {
11758 Out << VLENVal;
11759 }
11760 Out << mangleVectorParameters(ParamAttrs);
11761 Out << '_' << Fn->getName();
11762 Fn->addFnAttr(Kind: Out.str());
11763 }
11764 }
11765}
11766
11767// Function used to add the attribute. The parameter `VLEN` is templated to
11768// allow the use of `x` when targeting scalable functions for SVE.
11769template <typename T>
11770static void addAArch64VectorName(T VLEN, StringRef LMask, StringRef Prefix,
11771 char ISA, StringRef ParSeq,
11772 StringRef MangledName, bool OutputBecomesInput,
11773 llvm::Function *Fn) {
11774 SmallString<256> Buffer;
11775 llvm::raw_svector_ostream Out(Buffer);
11776 Out << Prefix << ISA << LMask << VLEN;
11777 if (OutputBecomesInput)
11778 Out << 'v';
11779 Out << ParSeq << '_' << MangledName;
11780 Fn->addFnAttr(Kind: Out.str());
11781}
11782
11783// Helper function to generate the Advanced SIMD names depending on the value
11784// of the NDS when simdlen is not present.
11785static void addAArch64AdvSIMDNDSNames(unsigned NDS, StringRef Mask,
11786 StringRef Prefix, char ISA,
11787 StringRef ParSeq, StringRef MangledName,
11788 bool OutputBecomesInput,
11789 llvm::Function *Fn) {
11790 switch (NDS) {
11791 case 8:
11792 addAArch64VectorName(VLEN: 8, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11793 OutputBecomesInput, Fn);
11794 addAArch64VectorName(VLEN: 16, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11795 OutputBecomesInput, Fn);
11796 break;
11797 case 16:
11798 addAArch64VectorName(VLEN: 4, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11799 OutputBecomesInput, Fn);
11800 addAArch64VectorName(VLEN: 8, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11801 OutputBecomesInput, Fn);
11802 break;
11803 case 32:
11804 addAArch64VectorName(VLEN: 2, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11805 OutputBecomesInput, Fn);
11806 addAArch64VectorName(VLEN: 4, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11807 OutputBecomesInput, Fn);
11808 break;
11809 case 64:
11810 case 128:
11811 addAArch64VectorName(VLEN: 2, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
11812 OutputBecomesInput, Fn);
11813 break;
11814 default:
11815 llvm_unreachable("Scalar type is too wide.");
11816 }
11817}
11818
11819/// Emit vector function attributes for AArch64, as defined in the AAVFABI.
11820void OpenMPIRBuilder::emitAArch64DeclareSimdFunction(
11821 llvm::Function *Fn, unsigned UserVLEN,
11822 llvm::ArrayRef<DeclareSimdAttrTy> ParamAttrs, DeclareSimdBranch Branch,
11823 char ISA, unsigned NarrowestDataSize, bool OutputBecomesInput) {
11824 assert((ISA == 'n' || ISA == 's') && "Expected ISA either 's' or 'n'.");
11825
11826 // Sort out parameter sequence.
11827 const std::string ParSeq = mangleVectorParameters(ParamAttrs);
11828 StringRef Prefix = "_ZGV";
11829 StringRef MangledName = Fn->getName();
11830
11831 // Generate simdlen from user input (if any).
11832 if (UserVLEN) {
11833 if (ISA == 's') {
11834 // SVE generates only a masked function.
11835 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
11836 OutputBecomesInput, Fn);
11837 return;
11838 }
11839
11840 switch (Branch) {
11841 case DeclareSimdBranch::Undefined:
11842 addAArch64VectorName(VLEN: UserVLEN, LMask: "N", Prefix, ISA, ParSeq, MangledName,
11843 OutputBecomesInput, Fn);
11844 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
11845 OutputBecomesInput, Fn);
11846 break;
11847 case DeclareSimdBranch::Inbranch:
11848 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
11849 OutputBecomesInput, Fn);
11850 break;
11851 case DeclareSimdBranch::Notinbranch:
11852 addAArch64VectorName(VLEN: UserVLEN, LMask: "N", Prefix, ISA, ParSeq, MangledName,
11853 OutputBecomesInput, Fn);
11854 break;
11855 }
11856 return;
11857 }
11858
11859 if (ISA == 's') {
11860 // SVE, section 3.4.1, item 1.
11861 addAArch64VectorName(VLEN: "x", LMask: "M", Prefix, ISA, ParSeq, MangledName,
11862 OutputBecomesInput, Fn);
11863 return;
11864 }
11865
11866 switch (Branch) {
11867 case DeclareSimdBranch::Undefined:
11868 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "N", Prefix, ISA, ParSeq,
11869 MangledName, OutputBecomesInput, Fn);
11870 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "M", Prefix, ISA, ParSeq,
11871 MangledName, OutputBecomesInput, Fn);
11872 break;
11873 case DeclareSimdBranch::Inbranch:
11874 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "M", Prefix, ISA, ParSeq,
11875 MangledName, OutputBecomesInput, Fn);
11876 break;
11877 case DeclareSimdBranch::Notinbranch:
11878 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "N", Prefix, ISA, ParSeq,
11879 MangledName, OutputBecomesInput, Fn);
11880 break;
11881 }
11882}
11883
11884//===----------------------------------------------------------------------===//
11885// OffloadEntriesInfoManager
11886//===----------------------------------------------------------------------===//
11887
11888bool OffloadEntriesInfoManager::empty() const {
11889 return OffloadEntriesTargetRegion.empty() &&
11890 OffloadEntriesDeviceGlobalVar.empty();
11891}
11892
11893unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
11894 const TargetRegionEntryInfo &EntryInfo) const {
11895 auto It = OffloadEntriesTargetRegionCount.find(
11896 x: getTargetRegionEntryCountKey(EntryInfo));
11897 if (It == OffloadEntriesTargetRegionCount.end())
11898 return 0;
11899 return It->second;
11900}
11901
11902void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
11903 const TargetRegionEntryInfo &EntryInfo) {
11904 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
11905 EntryInfo.Count + 1;
11906}
11907
11908/// Initialize target region entry.
11909void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
11910 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
11911 OffloadEntriesTargetRegion[EntryInfo] =
11912 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
11913 OMPTargetRegionEntryTargetRegion);
11914 ++OffloadingEntriesNum;
11915}
11916
11917void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
11918 TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
11919 OMPTargetRegionEntryKind Flags) {
11920 assert(EntryInfo.Count == 0 && "expected default EntryInfo");
11921
11922 // Update the EntryInfo with the next available count for this location.
11923 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11924
11925 // If we are emitting code for a target, the entry is already initialized,
11926 // only has to be registered.
11927 if (OMPBuilder->Config.isTargetDevice()) {
11928 // This could happen if the device compilation is invoked standalone.
11929 if (!hasTargetRegionEntryInfo(EntryInfo)) {
11930 return;
11931 }
11932 auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
11933 Entry.setAddress(Addr);
11934 Entry.setID(ID);
11935 Entry.setFlags(Flags);
11936 } else {
11937 if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
11938 hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
11939 return;
11940 assert(!hasTargetRegionEntryInfo(EntryInfo) &&
11941 "Target region entry already registered!");
11942 OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
11943 OffloadEntriesTargetRegion[EntryInfo] = Entry;
11944 ++OffloadingEntriesNum;
11945 }
11946 incrementTargetRegionEntryInfoCount(EntryInfo);
11947}
11948
11949bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
11950 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
11951
11952 // Update the EntryInfo with the next available count for this location.
11953 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11954
11955 auto It = OffloadEntriesTargetRegion.find(x: EntryInfo);
11956 if (It == OffloadEntriesTargetRegion.end()) {
11957 return false;
11958 }
11959 // Fail if this entry is already registered.
11960 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
11961 return false;
11962 return true;
11963}
11964
11965void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
11966 const OffloadTargetRegionEntryInfoActTy &Action) {
11967 // Scan all target region entries and perform the provided action.
11968 for (const auto &It : OffloadEntriesTargetRegion) {
11969 Action(It.first, It.second);
11970 }
11971}
11972
11973void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
11974 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
11975 OffloadEntriesDeviceGlobalVar.try_emplace(Key: Name, Args&: Order, Args&: Flags);
11976 ++OffloadingEntriesNum;
11977}
11978
11979void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
11980 StringRef VarName, Constant *Addr, int64_t VarSize,
11981 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
11982 if (OMPBuilder->Config.isTargetDevice()) {
11983 // This could happen if the device compilation is invoked standalone.
11984 if (!hasDeviceGlobalVarEntryInfo(VarName))
11985 return;
11986 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11987 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
11988 if (Entry.getVarSize() == 0) {
11989 Entry.setVarSize(VarSize);
11990 Entry.setLinkage(Linkage);
11991 }
11992 return;
11993 }
11994 Entry.setVarSize(VarSize);
11995 Entry.setLinkage(Linkage);
11996 Entry.setAddress(Addr);
11997 } else {
11998 if (hasDeviceGlobalVarEntryInfo(VarName)) {
11999 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
12000 assert(Entry.isValid() && Entry.getFlags() == Flags &&
12001 "Entry not initialized!");
12002 if (Entry.getVarSize() == 0) {
12003 Entry.setVarSize(VarSize);
12004 Entry.setLinkage(Linkage);
12005 }
12006 return;
12007 }
12008 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
12009 Flags ==
12010 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
12011 OffloadEntriesDeviceGlobalVar.try_emplace(Key: VarName, Args&: OffloadingEntriesNum,
12012 Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage,
12013 Args: VarName.str());
12014 else
12015 OffloadEntriesDeviceGlobalVar.try_emplace(
12016 Key: VarName, Args&: OffloadingEntriesNum, Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage, Args: "");
12017 ++OffloadingEntriesNum;
12018 }
12019}
12020
void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
    const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
  // Scan all device global variable entries and perform the provided action.
  for (const auto &E : OffloadEntriesDeviceGlobalVar)
    Action(E.getKey(), E.getValue());
}
12027
12028//===----------------------------------------------------------------------===//
12029// CanonicalLoopInfo
12030//===----------------------------------------------------------------------===//
12031
12032void CanonicalLoopInfo::collectControlBlocks(
12033 SmallVectorImpl<BasicBlock *> &BBs) {
12034 // We only count those BBs as control block for which we do not need to
12035 // reverse the CFG, i.e. not the loop body which can contain arbitrary control
12036 // flow. For consistency, this also means we do not add the Body block, which
12037 // is just the entry to the body code.
12038 BBs.reserve(N: BBs.size() + 6);
12039 BBs.append(IL: {getPreheader(), Header, Cond, Latch, Exit, getAfter()});
12040}
12041
12042BasicBlock *CanonicalLoopInfo::getPreheader() const {
12043 assert(isValid() && "Requires a valid canonical loop");
12044 for (BasicBlock *Pred : predecessors(BB: Header)) {
12045 if (Pred != Latch)
12046 return Pred;
12047 }
12048 llvm_unreachable("Missing preheader");
12049}
12050
12051void CanonicalLoopInfo::setTripCount(Value *TripCount) {
12052 assert(isValid() && "Requires a valid canonical loop");
12053
12054 Instruction *CmpI = &getCond()->front();
12055 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
12056 CmpI->setOperand(i: 1, Val: TripCount);
12057
12058#ifndef NDEBUG
12059 assertOK();
12060#endif
12061}
12062
12063void CanonicalLoopInfo::mapIndVar(
12064 llvm::function_ref<Value *(Instruction *)> Updater) {
12065 assert(isValid() && "Requires a valid canonical loop");
12066
12067 Instruction *OldIV = getIndVar();
12068
12069 // Record all uses excluding those introduced by the updater. Uses by the
12070 // CanonicalLoopInfo itself to keep track of the number of iterations are
12071 // excluded.
12072 SmallVector<Use *> ReplacableUses;
12073 for (Use &U : OldIV->uses()) {
12074 auto *User = dyn_cast<Instruction>(Val: U.getUser());
12075 if (!User)
12076 continue;
12077 if (User->getParent() == getCond())
12078 continue;
12079 if (User->getParent() == getLatch())
12080 continue;
12081 ReplacableUses.push_back(Elt: &U);
12082 }
12083
12084 // Run the updater that may introduce new uses
12085 Value *NewIV = Updater(OldIV);
12086
12087 // Replace the old uses with the value returned by the updater.
12088 for (Use *U : ReplacableUses)
12089 U->set(NewIV);
12090
12091#ifndef NDEBUG
12092 assertOK();
12093#endif
12094}
12095
12096void CanonicalLoopInfo::assertOK() const {
12097#ifndef NDEBUG
12098 // No constraints if this object currently does not describe a loop.
12099 if (!isValid())
12100 return;
12101
12102 BasicBlock *Preheader = getPreheader();
12103 BasicBlock *Body = getBody();
12104 BasicBlock *After = getAfter();
12105
12106 // Verify standard control-flow we use for OpenMP loops.
12107 assert(Preheader);
12108 assert(isa<UncondBrInst>(Preheader->getTerminator()) &&
12109 "Preheader must terminate with unconditional branch");
12110 assert(Preheader->getSingleSuccessor() == Header &&
12111 "Preheader must jump to header");
12112
12113 assert(Header);
12114 assert(isa<UncondBrInst>(Header->getTerminator()) &&
12115 "Header must terminate with unconditional branch");
12116 assert(Header->getSingleSuccessor() == Cond &&
12117 "Header must jump to exiting block");
12118
12119 assert(Cond);
12120 assert(Cond->getSinglePredecessor() == Header &&
12121 "Exiting block only reachable from header");
12122
12123 assert(isa<CondBrInst>(Cond->getTerminator()) &&
12124 "Exiting block must terminate with conditional branch");
12125 assert(cast<CondBrInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
12126 "Exiting block's first successor jump to the body");
12127 assert(cast<CondBrInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
12128 "Exiting block's second successor must exit the loop");
12129
12130 assert(Body);
12131 assert(Body->getSinglePredecessor() == Cond &&
12132 "Body only reachable from exiting block");
12133 assert(!isa<PHINode>(Body->front()));
12134
12135 assert(Latch);
12136 assert(isa<UncondBrInst>(Latch->getTerminator()) &&
12137 "Latch must terminate with unconditional branch");
12138 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
12139 // TODO: To support simple redirecting of the end of the body code that has
12140 // multiple; introduce another auxiliary basic block like preheader and after.
12141 assert(Latch->getSinglePredecessor() != nullptr);
12142 assert(!isa<PHINode>(Latch->front()));
12143
12144 assert(Exit);
12145 assert(isa<UncondBrInst>(Exit->getTerminator()) &&
12146 "Exit block must terminate with unconditional branch");
12147 assert(Exit->getSingleSuccessor() == After &&
12148 "Exit block must jump to after block");
12149
12150 assert(After);
12151 assert(After->getSinglePredecessor() == Exit &&
12152 "After block only reachable from exit block");
12153 assert(After->empty() || !isa<PHINode>(After->front()));
12154
12155 Instruction *IndVar = getIndVar();
12156 assert(IndVar && "Canonical induction variable not found?");
12157 assert(isa<IntegerType>(IndVar->getType()) &&
12158 "Induction variable must be an integer");
12159 assert(cast<PHINode>(IndVar)->getParent() == Header &&
12160 "Induction variable must be a PHI in the loop header");
12161 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
12162 assert(
12163 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
12164 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
12165
12166 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
12167 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
12168 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
12169 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
12170 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
12171 ->isOne());
12172
12173 Value *TripCount = getTripCount();
12174 assert(TripCount && "Loop trip count not found?");
12175 assert(IndVar->getType() == TripCount->getType() &&
12176 "Trip count and induction variable must have the same type");
12177
12178 auto *CmpI = cast<CmpInst>(&Cond->front());
12179 assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
12180 "Exit condition must be a signed less-than comparison");
12181 assert(CmpI->getOperand(0) == IndVar &&
12182 "Exit condition must compare the induction variable");
12183 assert(CmpI->getOperand(1) == TripCount &&
12184 "Exit condition must compare with the trip count");
12185#endif
12186}
12187
12188void CanonicalLoopInfo::invalidate() {
12189 Header = nullptr;
12190 Cond = nullptr;
12191 Latch = nullptr;
12192 Exit = nullptr;
12193}
12194