1//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9///
10/// This file implements the OpenMPIRBuilder class, which is used as a
11/// convenient way to create LLVM instructions for OpenMP directives.
12///
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/SmallSet.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/ADT/StringRef.h"
20#include "llvm/Analysis/AssumptionCache.h"
21#include "llvm/Analysis/CodeMetrics.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/Analysis/OptimizationRemarkEmitter.h"
24#include "llvm/Analysis/PostDominators.h"
25#include "llvm/Analysis/ScalarEvolution.h"
26#include "llvm/Analysis/TargetLibraryInfo.h"
27#include "llvm/Bitcode/BitcodeReader.h"
28#include "llvm/Frontend/Offloading/Utility.h"
29#include "llvm/Frontend/OpenMP/OMPGridValues.h"
30#include "llvm/IR/Attributes.h"
31#include "llvm/IR/BasicBlock.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/CallingConv.h"
34#include "llvm/IR/Constant.h"
35#include "llvm/IR/Constants.h"
36#include "llvm/IR/DIBuilder.h"
37#include "llvm/IR/DebugInfoMetadata.h"
38#include "llvm/IR/DerivedTypes.h"
39#include "llvm/IR/Function.h"
40#include "llvm/IR/GlobalVariable.h"
41#include "llvm/IR/IRBuilder.h"
42#include "llvm/IR/InstIterator.h"
43#include "llvm/IR/IntrinsicInst.h"
44#include "llvm/IR/LLVMContext.h"
45#include "llvm/IR/MDBuilder.h"
46#include "llvm/IR/Metadata.h"
47#include "llvm/IR/PassInstrumentation.h"
48#include "llvm/IR/PassManager.h"
49#include "llvm/IR/ReplaceConstant.h"
50#include "llvm/IR/Value.h"
51#include "llvm/MC/TargetRegistry.h"
52#include "llvm/Support/CommandLine.h"
53#include "llvm/Support/Error.h"
54#include "llvm/Support/ErrorHandling.h"
55#include "llvm/Support/FileSystem.h"
56#include "llvm/Support/VirtualFileSystem.h"
57#include "llvm/Target/TargetMachine.h"
58#include "llvm/Target/TargetOptions.h"
59#include "llvm/Transforms/Utils/BasicBlockUtils.h"
60#include "llvm/Transforms/Utils/Cloning.h"
61#include "llvm/Transforms/Utils/CodeExtractor.h"
62#include "llvm/Transforms/Utils/LoopPeel.h"
63#include "llvm/Transforms/Utils/UnrollLoop.h"
64
65#include <cstdint>
66#include <optional>
67
68#define DEBUG_TYPE "openmp-ir-builder"
69
70using namespace llvm;
71using namespace omp;
72
/// Command-line knob: when set, runtime-call declarations get optimistic
/// "as-if" attributes attached (see addAttributes).
static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(Val: false));

/// Multiplier applied to the unroll threshold to compensate for
/// simplifications that later passes will still perform.
static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(Val: 1.5));
84
85#ifndef NDEBUG
86/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
87/// at position IP1 may change the meaning of IP2 or vice-versa. This is because
88/// an InsertPoint stores the instruction before something is inserted. For
89/// instance, if both point to the same instruction, two IRBuilders alternating
90/// creating instruction will cause the instructions to be interleaved.
91static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
92 IRBuilder<>::InsertPoint IP2) {
93 if (!IP1.isSet() || !IP2.isSet())
94 return false;
95 return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
96}
97
/// Debug-build sanity check: return true iff \p SchedType, with its
/// monotonicity bits masked off, is one of the ordered/unordered (and
/// optionally nomerge) base-algorithm combinations the runtime accepts, and
/// at most one monotonicity modifier is set.
static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  // NOTE(review): "Orderd" spelling below matches the enumerator's
  // declaration elsewhere — confirm before "fixing".
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
  case OMPScheduleType::OrderedDistributeChunked:
  case OMPScheduleType::OrderedDistribute:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
157#endif
158
159/// This is wrapper over IRBuilderBase::restoreIP that also restores the current
160/// debug location to the last instruction in the specified basic block if the
161/// insert point points to the end of the block.
162static void restoreIPandDebugLoc(llvm::IRBuilderBase &Builder,
163 llvm::IRBuilderBase::InsertPoint IP) {
164 Builder.restoreIP(IP);
165 llvm::BasicBlock *BB = Builder.GetInsertBlock();
166 llvm::BasicBlock::iterator I = Builder.GetInsertPoint();
167 if (!BB->empty() && I == BB->end())
168 Builder.SetCurrentDebugLocation(BB->back().getStableDebugLoc());
169}
170
171static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
172 if (T.isAMDGPU()) {
173 StringRef Features =
174 Kernel->getFnAttribute(Kind: "target-features").getValueAsString();
175 if (Features.count(Str: "+wavefrontsize64"))
176 return omp::getAMDGPUGridValues<64>();
177 return omp::getAMDGPUGridValues<32>();
178 }
179 if (T.isNVPTX())
180 return omp::NVPTXGridValues;
181 if (T.isSPIRV())
182 return omp::SPIRVGridValues;
183 llvm_unreachable("No grid value available for this architecture!");
184}
185
186/// Determine which scheduling algorithm to use, determined from schedule clause
187/// arguments.
/// Determine which scheduling algorithm to use, determined from schedule
/// clause arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasDistScheduleChunks) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  case OMP_SCHEDULE_Distribute:
    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
                                 : OMPScheduleType::BaseDistribute;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
213
214/// Adds ordering modifier flags to schedule type.
215static OMPScheduleType
216getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
217 bool HasOrderedClause) {
218 assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
219 OMPScheduleType::None &&
220 "Must not have ordering nor monotonicity flags already set");
221
222 OMPScheduleType OrderingModifier = HasOrderedClause
223 ? OMPScheduleType::ModifierOrdered
224 : OMPScheduleType::ModifierUnordered;
225 OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
226
227 // Unsupported combinations
228 if (OrderingScheduleType ==
229 (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
230 return OMPScheduleType::OrderedGuidedChunked;
231 else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
232 OMPScheduleType::ModifierOrdered))
233 return OMPScheduleType::OrderedRuntime;
234
235 return OrderingScheduleType;
236}
237
238/// Adds monotonicity modifier flags to schedule type.
239static OMPScheduleType
240getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
241 bool HasSimdModifier, bool HasMonotonic,
242 bool HasNonmonotonic, bool HasOrderedClause) {
243 assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
244 OMPScheduleType::None &&
245 "Must not have monotonicity flags already set");
246 assert((!HasMonotonic || !HasNonmonotonic) &&
247 "Monotonic and Nonmonotonic are contradicting each other");
248
249 if (HasMonotonic) {
250 return ScheduleType | OMPScheduleType::ModifierMonotonic;
251 } else if (HasNonmonotonic) {
252 return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
253 } else {
254 // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
255 // If the static schedule kind is specified or if the ordered clause is
256 // specified, and if the nonmonotonic modifier is not specified, the
257 // effect is as if the monotonic modifier is specified. Otherwise, unless
258 // the monotonic modifier is specified, the effect is as if the
259 // nonmonotonic modifier is specified.
260 OMPScheduleType BaseScheduleType =
261 ScheduleType & ~OMPScheduleType::ModifierMask;
262 if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
263 (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
264 HasOrderedClause) {
265 // The monotonic is used by default in openmp runtime library, so no need
266 // to set it.
267 return ScheduleType;
268 } else {
269 return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
270 }
271 }
272}
273
274/// Determine the schedule type using schedule and ordering clause arguments.
275static OMPScheduleType
276computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
277 bool HasSimdModifier, bool HasMonotonicModifier,
278 bool HasNonmonotonicModifier, bool HasOrderedClause,
279 bool HasDistScheduleChunks) {
280 OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
281 ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
282 OMPScheduleType OrderedSchedule =
283 getOpenMPOrderingScheduleType(BaseScheduleType: BaseSchedule, HasOrderedClause);
284 OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
285 ScheduleType: OrderedSchedule, HasSimdModifier, HasMonotonic: HasMonotonicModifier,
286 HasNonmonotonic: HasNonmonotonicModifier, HasOrderedClause);
287
288 assert(isValidWorkshareLoopScheduleType(Result));
289 return Result;
290}
291
292/// Make \p Source branch to \p Target.
293///
294/// Handles two situations:
295/// * \p Source already has an unconditional branch.
296/// * \p Source is a degenerate block (no terminator because the BB is
297/// the current head of the IR construction).
298static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
299 if (Instruction *Term = Source->getTerminator()) {
300 auto *Br = cast<BranchInst>(Val: Term);
301 assert(!Br->isConditional() &&
302 "BB's terminator must be an unconditional branch (or degenerate)");
303 BasicBlock *Succ = Br->getSuccessor(i: 0);
304 Succ->removePredecessor(Pred: Source, /*KeepOneInputPHIs=*/true);
305 Br->setSuccessor(idx: 0, NewSucc: Target);
306 return;
307 }
308
309 auto *NewBr = BranchInst::Create(IfTrue: Target, InsertBefore: Source);
310 NewBr->setDebugLoc(DL);
311}
312
/// Move all instructions at and after \p IP out of IP's block into \p New,
/// optionally re-terminating the (now shorter) old block with a branch to
/// \p New.
void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch, DebugLoc DL) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  // If the `Old` block is empty then there are no instructions to move. But in
  // the new debug scheme, it could have trailing debug records which will be
  // moved to `New` in `spliceDebugInfoEmptyBlock`. We dont want that for 2
  // reasons:
  // 1. If `New` is also empty, `BasicBlock::splice` crashes.
  // 2. Even if `New` is not empty, the rationale to move those records to `New`
  // (in `spliceDebugInfoEmptyBlock`) does not apply here. That function
  // assumes that `Old` is optimized out and is going away. This is not the case
  // here. The `Old` block is still being used e.g. a branch instruction is
  // added to it later in this function.
  // So we call `BasicBlock::splice` only when `Old` is not empty.
  if (!Old->empty())
    New->splice(ToIt: New->begin(), FromBB: Old, FromBeginIt: IP.getPoint(), FromEndIt: Old->end());

  if (CreateBranch) {
    auto *NewBr = BranchInst::Create(IfTrue: New, InsertBefore: Old);
    NewBr->setDebugLoc(DL);
  }
}
339
340void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
341 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
342 BasicBlock *Old = Builder.GetInsertBlock();
343
344 spliceBB(IP: Builder.saveIP(), New, CreateBranch, DL: DebugLoc);
345 if (CreateBranch)
346 Builder.SetInsertPoint(Old->getTerminator());
347 else
348 Builder.SetInsertPoint(Old);
349
350 // SetInsertPoint also updates the Builder's debug location, but we want to
351 // keep the one the Builder was configured to use.
352 Builder.SetCurrentDebugLocation(DebugLoc);
353}
354
355BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
356 DebugLoc DL, llvm::Twine Name) {
357 BasicBlock *Old = IP.getBlock();
358 BasicBlock *New = BasicBlock::Create(
359 Context&: Old->getContext(), Name: Name.isTriviallyEmpty() ? Old->getName() : Name,
360 Parent: Old->getParent(), InsertBefore: Old->getNextNode());
361 spliceBB(IP, New, CreateBranch, DL);
362 New->replaceSuccessorsPhiUsesWith(Old, New);
363 return New;
364}
365
366BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
367 llvm::Twine Name) {
368 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
369 BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
370 if (CreateBranch)
371 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
372 else
373 Builder.SetInsertPoint(Builder.GetInsertBlock());
374 // SetInsertPoint also updates the Builder's debug location, but we want to
375 // keep the one the Builder was configured to use.
376 Builder.SetCurrentDebugLocation(DebugLoc);
377 return New;
378}
379
380BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
381 llvm::Twine Name) {
382 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
383 BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
384 if (CreateBranch)
385 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
386 else
387 Builder.SetInsertPoint(Builder.GetInsertBlock());
388 // SetInsertPoint also updates the Builder's debug location, but we want to
389 // keep the one the Builder was configured to use.
390 Builder.SetCurrentDebugLocation(DebugLoc);
391 return New;
392}
393
394BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
395 llvm::Twine Suffix) {
396 BasicBlock *Old = Builder.GetInsertBlock();
397 return splitBB(Builder, CreateBranch, Name: Old->getName() + Suffix);
398}
399
400// This function creates a fake integer value and a fake use for the integer
401// value. It returns the fake value created. This is useful in modeling the
402// extra arguments to the outlined functions.
403Value *createFakeIntVal(IRBuilderBase &Builder,
404 OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
405 llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
406 OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
407 const Twine &Name = "", bool AsPtr = true,
408 bool Is64Bit = false) {
409 Builder.restoreIP(IP: OuterAllocaIP);
410 IntegerType *IntTy = Is64Bit ? Builder.getInt64Ty() : Builder.getInt32Ty();
411 Instruction *FakeVal;
412 AllocaInst *FakeValAddr =
413 Builder.CreateAlloca(Ty: IntTy, ArraySize: nullptr, Name: Name + ".addr");
414 ToBeDeleted.push_back(Elt: FakeValAddr);
415
416 if (AsPtr) {
417 FakeVal = FakeValAddr;
418 } else {
419 FakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeValAddr, Name: Name + ".val");
420 ToBeDeleted.push_back(Elt: FakeVal);
421 }
422
423 // Generate a fake use of this value
424 Builder.restoreIP(IP: InnerAllocaIP);
425 Instruction *UseFakeVal;
426 if (AsPtr) {
427 UseFakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeVal, Name: Name + ".use");
428 } else {
429 UseFakeVal = cast<BinaryOperator>(Val: Builder.CreateAdd(
430 LHS: FakeVal, RHS: Is64Bit ? Builder.getInt64(C: 10) : Builder.getInt32(C: 10)));
431 }
432 ToBeDeleted.push_back(Elt: UseFakeVal);
433 return FakeVal;
434}
435
436//===----------------------------------------------------------------------===//
437// OpenMPIRBuilderConfig
438//===----------------------------------------------------------------------===//
439
namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
/// OMP_REQ_UNDEFINED (no clause recorded yet) is distinct from OMP_REQ_NONE
/// (a requires directive is known to be absent).
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace
460
/// Default config: no requires flags registered yet (OMP_REQ_UNDEFINED, as
/// opposed to OMP_REQ_NONE which records an explicitly absent directive).
OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}
463
464OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
465 bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
466 bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
467 bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
468 : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
469 OpenMPOffloadMandatory(OpenMPOffloadMandatory),
470 RequiresFlags(OMP_REQ_UNDEFINED) {
471 if (HasRequiresReverseOffload)
472 RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
473 if (HasRequiresUnifiedAddress)
474 RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
475 if (HasRequiresUnifiedSharedMemory)
476 RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
477 if (HasRequiresDynamicAllocators)
478 RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
479}
480
481bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
482 return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
483}
484
485bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
486 return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
487}
488
489bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
490 return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
491}
492
493bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
494 return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
495}
496
497int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
498 return hasRequiresFlags() ? RequiresFlags
499 : static_cast<int64_t>(OMP_REQ_NONE);
500}
501
502void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
503 if (Value)
504 RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
505 else
506 RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
507}
508
509void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
510 if (Value)
511 RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
512 else
513 RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
514}
515
516void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
517 if (Value)
518 RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
519 else
520 RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
521}
522
523void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
524 if (Value)
525 RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
526 else
527 RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
528}
529
530//===----------------------------------------------------------------------===//
531// OpenMPIRBuilder
532//===----------------------------------------------------------------------===//
533
/// Populate \p ArgsVector with the argument list built from \p KernelArgs in
/// the order the target kernel runtime entry point expects.
void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(C: KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(C&: Builder.getContext());
  // Team/thread counts are passed as [3 x i32]; unspecified dimensions stay 0.
  constexpr size_t MaxDim = 3;
  Value *ZeroArray = Constant::getNullValue(Ty: ArrayType::get(ElementType: Int32Ty, NumElements: MaxDim));

  // Flags word: bit 0 carries the nowait flag, the dynamic cgroup-memory
  // fallback kind is shifted into bits 2 and up, and both are OR'ed together.
  Value *HasNoWaitFlag = Builder.getInt64(C: KernelArgs.HasNoWait);

  Value *DynCGroupMemFallbackFlag =
      Builder.getInt64(C: static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
  DynCGroupMemFallbackFlag = Builder.CreateShl(LHS: DynCGroupMemFallbackFlag, RHS: 2);
  Value *Flags = Builder.CreateOr(LHS: HasNoWaitFlag, RHS: DynCGroupMemFallbackFlag);

  assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());

  // Seed dimension 0, then insert whatever further dimensions were provided
  // (capped at MaxDim).
  Value *NumTeams3D =
      Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumTeams[0], Idxs: {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumThreads[0], Idxs: {0});
  for (unsigned I :
       seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumTeams.size(), b: MaxDim)))
    NumTeams3D =
        Builder.CreateInsertValue(Agg: NumTeams3D, Val: KernelArgs.NumTeams[I], Idxs: {I});
  for (unsigned I :
       seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumThreads.size(), b: MaxDim)))
    NumThreads3D =
        Builder.CreateInsertValue(Agg: NumThreads3D, Val: KernelArgs.NumThreads[I], Idxs: {I});

  // Final argument order is fixed by the runtime's kernel-args layout.
  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGroupMem};
}
579
/// Attach the attribute sets recorded in OMPKinds.def for runtime function
/// \p FnID to the declaration \p Fn, merging with any attributes it already
/// carries.
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Args: Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Kind: Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Kind: Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      // Translate the generic extension request into whatever the target
      // actually needs for i32 params/returns.
      // NOTE(review): `T` is presumably a target-triple member of this
      // builder — confirm against the class definition.
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, Signed: HasSignExt))
          FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, Signed: HasSignExt))
        FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
    } else {
      FnAS = FnAS.addAttributes(C&: Ctx, AS);
    }
  };

  // Materialize the named attribute sets used by OMP_RTL_ATTRS below.
#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}
629
/// Look up the declaration of runtime function \p FnID in \p M, creating it
/// (with the prototype, calling convention, callback metadata, and attributes
/// recorded for it) if it does not exist yet.
FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }
    Fn->setCallingConv(Config.getRuntimeCC());
    // Add information if the runtime function takes a callback function
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(KindID: LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            KindID: LLVMContext::MD_callback,
            MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
                                 CalleeArgNo: 2, Arguments: {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, Fn&: *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
686
/// Return this scope's finalization block, lazily creating it and populating
/// it via the finalization callback (FiniCB) on first request. The builder's
/// insert point is preserved across the creation.
Expected<BasicBlock *>
OpenMPIRBuilder::FinalizationInfo::getFiniBB(IRBuilderBase &Builder) {
  if (!FiniBB) {
    Function *ParentFunc = Builder.GetInsertBlock()->getParent();
    IRBuilderBase::InsertPointGuard Guard(Builder);
    FiniBB = BasicBlock::Create(Context&: Builder.getContext(), Name: ".fini", Parent: ParentFunc);
    Builder.SetInsertPoint(FiniBB);
    // FiniCB adds the branch to the exit stub.
    // NOTE(review): if FiniCB fails, FiniBB remains set but only partially
    // populated — confirm callers treat the returned error as fatal.
    if (Error Err = FiniCB(Builder.saveIP()))
      return Err;
  }
  return FiniBB;
}
700
/// Merge this scope's finalization block into \p OtherFiniBB; afterwards
/// FiniBB points at \p OtherFiniBB and the old block (if any) is gone.
Error OpenMPIRBuilder::FinalizationInfo::mergeFiniBB(IRBuilderBase &Builder,
                                                     BasicBlock *OtherFiniBB) {
  // Simple case: FiniBB does not exist yet: re-use OtherFiniBB.
  if (!FiniBB) {
    FiniBB = OtherFiniBB;

    // Emit this scope's finalization code at the top of the adopted block.
    Builder.SetInsertPoint(FiniBB->getFirstNonPHIIt());
    if (Error Err = FiniCB(Builder.saveIP()))
      return Err;

    return Error::success();
  }

  // Move instructions from FiniBB to the start of OtherFiniBB, leaving
  // FiniBB's terminator (when present) behind so it is erased with the block.
  auto EndIt = FiniBB->end();
  if (FiniBB->size() >= 1)
    if (auto Prev = std::prev(x: EndIt); Prev->isTerminator())
      EndIt = Prev;
  OtherFiniBB->splice(ToIt: OtherFiniBB->getFirstNonPHIIt(), FromBB: FiniBB, FromBeginIt: FiniBB->begin(),
                      FromEndIt: EndIt);

  // Redirect every user (e.g. branches) of the old block, then drop it.
  FiniBB->replaceAllUsesWith(V: OtherFiniBB);
  FiniBB->eraseFromParent();
  FiniBB = OtherFiniBB;
  return Error::success();
}
727
728Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
729 FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
730 auto *Fn = dyn_cast<llvm::Function>(Val: RTLFn.getCallee());
731 assert(Fn && "Failed to create OpenMP runtime function pointer");
732 return Fn;
733}
734
735CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
736 ArrayRef<Value *> Args,
737 StringRef Name) {
738 CallInst *Call = Builder.CreateCall(Callee, Args, Name);
739 Call->setCallingConv(Config.getRuntimeCC());
740 return Call;
741}
742
/// Prepare the builder for use with module M by delegating to initializeTypes.
void OpenMPIRBuilder::initialize() { initializeTypes(M); }
744
745static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
746 Function *Function) {
747 BasicBlock &EntryBlock = Function->getEntryBlock();
748 BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();
749
750 // Loop over blocks looking for constant allocas, skipping the entry block
751 // as any allocas there are already in the desired location.
752 for (auto Block = std::next(x: Function->begin(), n: 1); Block != Function->end();
753 Block++) {
754 for (auto Inst = Block->getReverseIterator()->begin();
755 Inst != Block->getReverseIterator()->end();) {
756 if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Val&: Inst)) {
757 Inst++;
758 if (!isa<ConstantData>(Val: AllocaInst->getArraySize()))
759 continue;
760 AllocaInst->moveBeforePreserving(MovePos: MoveLocInst);
761 } else {
762 Inst++;
763 }
764 }
765 }
766}
767
/// Move eligible (non-array) allocas from \p Block to just before the entry
/// block's terminator, collecting them first so the iteration over \p Block
/// is not disturbed by the moves.
static void hoistNonEntryAllocasToEntryBlock(llvm::BasicBlock &Block) {
  llvm::SmallVector<llvm::Instruction *> AllocasToMove;

  auto ShouldHoistAlloca = [](const llvm::AllocaInst &AllocaInst) {
    // TODO: For now, we support simple static allocations, we might need to
    // move non-static ones as well. However, this will need further analysis to
    // move the length arguments as well.
    return !AllocaInst.isArrayAllocation();
  };

  for (llvm::Instruction &Inst : Block)
    if (auto *AllocaInst = llvm::dyn_cast<llvm::AllocaInst>(Val: &Inst))
      if (ShouldHoistAlloca(*AllocaInst))
        AllocasToMove.push_back(Elt: AllocaInst);

  // Insert right before the entry block's terminator so existing entry-block
  // allocas keep their relative order.
  auto InsertPoint =
      Block.getParent()->getEntryBlock().getTerminator()->getIterator();

  for (llvm::Instruction *AllocaInst : AllocasToMove)
    AllocaInst->moveBefore(InsertPos: InsertPoint);
}
789
790static void hoistNonEntryAllocasToEntryBlock(llvm::Function *Func) {
791 PostDominatorTree PostDomTree(*Func);
792 for (llvm::BasicBlock &BB : *Func)
793 if (PostDomTree.properlyDominates(A: &BB, B: &Func->getEntryBlock()))
794 hoistNonEntryAllocasToEntryBlock(Block&: BB);
795}
796
void OpenMPIRBuilder::finalize(Function *Fn) {
  // Outline all collected regions (for \p Fn only if given, otherwise all),
  // then raise deferred constant allocas, emit offload entries/metadata, and
  // mark the builder finalized.
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not finalized yet; may happen with nested
    // function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(Elt: OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate
    // struct for aggregate params in the device default alloca address space.
    // OpenMP runtime requires that the params of the extracted functions are
    // passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    // Values the caller wants passed as separate arguments rather than packed
    // into the aggregate parameter.
    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(Arg: V);

    Function *OutlinedFn =
        Extractor.extractCodeRegion(CEAC, Inputs&: OI.Inputs, Outputs&: OI.Outputs);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute(Kind: "target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(Attr: TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute(Kind: "target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(Attr: TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(where: OuterFn->getIterator(), New: OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away, we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined region
      // and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      // Walk in reverse so each instruction is inserted before the previously
      // moved ones, preserving the original relative order at the head of the
      // region's entry block.
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator()) {
          // Absorb any debug value that terminator may have
          if (OI.EntryBB->getTerminator())
            OI.EntryBB->getTerminator()->adoptDbgRecords(
                BB: &ArtificialEntry, It: I.getIterator(), InsertAtHead: false);
          continue;
        }

        I.moveBeforePreserving(BB&: *OI.EntryBB, I: OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(MovePos: &ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->hasNUses(1));

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);

    if (OI.FixUpNonEntryAllocas)
      hoistNonEntryAllocasToEntryBlock(Func: OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embeds user written code into
  // the target region which may inject allocas which need to
  // be moved to the entry block of our target or risk malformed
  // optimisations by later passes, this is only relevant for
  // the device pass which appears to be a little more delicate
  // when it comes to optimisations (however, we do not block on
  // that here, it's up to the inserter to the list to do so).
  // This notably has to occur after the OutlinedInfo candidates
  // have been extracted so we have an end product that will not
  // be implicitly adversely affected by any raises unless
  // intentionally appended to the list.
  // NOTE: This only does so for ConstantData, it could be extended
  // to ConstantExpr's with further effort, however, they should
  // largely be folded when they get here. Extending it to runtime
  // defined/read+writeable allocation sizes would be non-trivial
  // (need to factor in movement of any stores to variables the
  // allocation size depends on, as well as the usual loads,
  // otherwise it'll yield the wrong result after movement) and
  // likely be more suitable as an LLVM optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, Function: F);

  // Diagnostic sink for offload-entry emission failures.
  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFunction&: ErrorReportFn);

  // Keep the NVPTX data-transfer scratch global alive through LTO/GC.
  if (Config.EmitLLVMUsedMetaInfo.value_or(u: false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable(Name: "__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed(Name: "llvm.compiler.used", List: LLVMCompilerUsed);
  }

  IsFinalized = true;
}
951
// Returns true once finalize() has completed for this builder.
bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }
953
OpenMPIRBuilder::~OpenMPIRBuilder() {
  // Every outline request must have been processed by finalize() before the
  // builder is destroyed; otherwise outlined regions would be silently lost.
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}
957
958GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
959 IntegerType *I32Ty = Type::getInt32Ty(C&: M.getContext());
960 auto *GV =
961 new GlobalVariable(M, I32Ty,
962 /* isConstant = */ true, GlobalValue::WeakODRLinkage,
963 ConstantInt::get(Ty: I32Ty, V: Value), Name);
964 GV->setVisibility(GlobalValue::HiddenVisibility);
965
966 return GV;
967}
968
969void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
970 if (List.empty())
971 return;
972
973 // Convert List to what ConstantArray needs.
974 SmallVector<Constant *, 8> UsedArray;
975 UsedArray.resize(N: List.size());
976 for (unsigned I = 0, E = List.size(); I != E; ++I)
977 UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
978 C: cast<Constant>(Val: &*List[I]), Ty: Builder.getPtrTy());
979
980 if (UsedArray.empty())
981 return;
982 ArrayType *ATy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: UsedArray.size());
983
984 auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
985 ConstantArray::get(T: ATy, V: UsedArray), Name);
986
987 GV->setSection("llvm.metadata");
988}
989
990GlobalVariable *
991OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
992 OMPTgtExecModeFlags Mode) {
993 auto *Int8Ty = Builder.getInt8Ty();
994 auto *GVMode = new GlobalVariable(
995 M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
996 ConstantInt::get(Ty: Int8Ty, V: Mode), Twine(KernelName, "_exec_mode"));
997 GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
998 return GVMode;
999}
1000
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Return a cached (or newly created) ident_t global for the given source
  // location string and flags, cast to the expected ident pointer type.
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  // Cache key folds both flag words into a single integer alongside the
  // location string.
  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Ty: Int32, V: uint32_t(LocFlags)),
                             ConstantInt::get(Ty: Int32, V: Reserve2Flags),
                             ConstantInt::get(Ty: Int32, V: SrcLocStrSize), SrcLocStr};

    // The location string may live in a non-default globals address space on
    // some targets; cast it to the address space of the ident_t field.
    size_t SrcLocStrArgIdx = 4;
    if (OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx)
            ->getPointerAddressSpace() !=
        IdentData[SrcLocStrArgIdx]->getType()->getPointerAddressSpace())
      IdentData[SrcLocStrArgIdx] = ConstantExpr::getAddrSpaceCast(
          C: SrcLocStr, Ty: OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx));
    Constant *Initializer =
        ConstantStruct::get(T: OpenMPIRBuilder::Ident, V: IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(C: Ident, Ty: IdentPtr);
}
1047
Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  // Return a cached (or newly created) global string for \p LocStr and report
  // its length through \p SrcLocStrSize.
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(Context&: M.getContext(), Initializer: LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(C: &GV, Ty: Int8Ptr);

    SrcLocStr = Builder.CreateGlobalString(
        Str: LocStr, /*Name=*/"", AddressSpace: M.getDataLayout().getDefaultGlobalsAddressSpace(),
        M: &M);
  }
  return SrcLocStr;
}
1069
1070Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
1071 StringRef FileName,
1072 unsigned Line, unsigned Column,
1073 uint32_t &SrcLocStrSize) {
1074 SmallString<128> Buffer;
1075 Buffer.push_back(Elt: ';');
1076 Buffer.append(RHS: FileName);
1077 Buffer.push_back(Elt: ';');
1078 Buffer.append(RHS: FunctionName);
1079 Buffer.push_back(Elt: ';');
1080 Buffer.append(RHS: std::to_string(val: Line));
1081 Buffer.push_back(Elt: ';');
1082 Buffer.append(RHS: std::to_string(val: Column));
1083 Buffer.push_back(Elt: ';');
1084 Buffer.push_back(Elt: ';');
1085 return getOrCreateSrcLocStr(LocStr: Buffer.str(), SrcLocStrSize);
1086}
1087
1088Constant *
1089OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
1090 StringRef UnknownLoc = ";unknown;unknown;0;0;;";
1091 return getOrCreateSrcLocStr(LocStr: UnknownLoc, SrcLocStrSize);
1092}
1093
1094Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
1095 uint32_t &SrcLocStrSize,
1096 Function *F) {
1097 DILocation *DIL = DL.get();
1098 if (!DIL)
1099 return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1100 StringRef FileName = M.getName();
1101 if (DIFile *DIF = DIL->getFile())
1102 if (std::optional<StringRef> Source = DIF->getSource())
1103 FileName = *Source;
1104 StringRef Function = DIL->getScope()->getSubprogram()->getName();
1105 if (Function.empty() && F)
1106 Function = F->getName();
1107 return getOrCreateSrcLocStr(FunctionName: Function, FileName, Line: DIL->getLine(),
1108 Column: DIL->getColumn(), SrcLocStrSize);
1109}
1110
1111Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
1112 uint32_t &SrcLocStrSize) {
1113 return getOrCreateSrcLocStr(DL: Loc.DL, SrcLocStrSize,
1114 F: Loc.IP.getBlock()->getParent());
1115}
1116
1117Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
1118 return createRuntimeFunctionCall(
1119 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num), Args: Ident,
1120 Name: "omp_global_thread_num");
1121}
1122
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  // Emit a barrier at \p Loc. \p Kind selects the ident flag describing which
  // construct the barrier is implicit to. Returns the updated insert point, or
  // an error from the cancellation-check emission.
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  // __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // Note: the ident passed to the barrier carries the barrier flags, while
  // the thread-id lookup uses a plain ident without them.
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: BarrierLocFlags),
      getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(DK: OMPD_parallel);

  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: UseCancelBarrier
                                         ? OMPRTL___kmpc_cancel_barrier
                                         : OMPRTL___kmpc_barrier),
      Args);

  // A non-zero result from the cancel barrier means the region was cancelled;
  // branch to the finalization path in that case.
  if (UseCancelBarrier && CheckCancelFlag)
    if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective: OMPD_parallel))
      return Err;

  return Builder.saveIP();
}
1175
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  // Emit an "omp cancel" for \p CanceledDirective, optionally guarded by
  // \p IfCondition. Returns the updated insert point or an error from the
  // cancellation-check emission.
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition) {
    SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: UI, ThenTerm: &ThenTI, ElseTerm: &ElseTI);

    // Even if the if condition evaluates to false, this should count as a
    // cancellation point
    Builder.SetInsertPoint(ElseTI);
    auto ElseIP = Builder.saveIP();

    InsertPointOrErrorTy IPOrErr = createCancellationPoint(
        Loc: LocationDescription{ElseIP, Loc.DL}, CanceledDirective);
    if (!IPOrErr)
      return IPOrErr;
  }

  // The actual cancel call is emitted on the "then" path (or unconditionally
  // when there is no if-clause).
  Builder.SetInsertPoint(ThenTI);

  // Map the directive to the runtime's cancel-kind constant via OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancel), Args);

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1231
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
                                         omp::Directive CanceledDirective) {
  // Emit an "omp cancellation point" for \p CanceledDirective: queries the
  // runtime for a pending cancellation and branches to finalization if one
  // is observed.
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();
  Builder.SetInsertPoint(UI);

  // Map the directive to the runtime's cancel-kind constant via OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancellationpoint), Args);

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1270
1271OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1272 const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
1273 Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
1274 Value *HostPtr, ArrayRef<Value *> KernelArgs) {
1275 if (!updateToLocation(Loc))
1276 return Loc.IP;
1277
1278 Builder.restoreIP(IP: AllocaIP);
1279 auto *KernelArgsPtr =
1280 Builder.CreateAlloca(Ty: OpenMPIRBuilder::KernelArgs, ArraySize: nullptr, Name: "kernel_args");
1281 updateToLocation(Loc);
1282
1283 for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
1284 llvm::Value *Arg =
1285 Builder.CreateStructGEP(Ty: OpenMPIRBuilder::KernelArgs, Ptr: KernelArgsPtr, Idx: I);
1286 Builder.CreateAlignedStore(
1287 Val: KernelArgs[I], Ptr: Arg,
1288 Align: M.getDataLayout().getPrefTypeAlign(Ty: KernelArgs[I]->getType()));
1289 }
1290
1291 SmallVector<Value *> OffloadingArgs{Ident, DeviceID, NumTeams,
1292 NumThreads, HostPtr, KernelArgsPtr};
1293
1294 Return = createRuntimeFunctionCall(
1295 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_target_kernel),
1296 Args: OffloadingArgs);
1297
1298 return Builder.saveIP();
1299}
1300
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Value *OutlinedFnID,
    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
  // Emit the offloading launch for a target region: call the runtime kernel
  // entry point and, if the offload fails (non-zero return), fall back to the
  // host version produced by \p EmitTargetCallFallbackCB.

  if (!updateToLocation(Loc))
    return Loc.IP;

  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(KernelArgs&: Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(IP: emitTargetKernel(
      Loc: Builder, AllocaIP, Return, Ident: RTLoc, DeviceID, NumTeams: Args.NumTeams.front(),
      NumThreads: Args.NumThreads.front(), HostPtr: OutlinedFnID, KernelArgs: ArgsVector));

  // Branch on the runtime result: zero means the kernel ran on the device,
  // non-zero routes through the host fallback.
  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Arg: Return);
  Builder.CreateCondBr(Cond: Failed, True: OffloadFailedBlock, False: OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(BB: OffloadFailedBlock, CurFn);
  InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  emitBranch(Target: OffloadContBlock);
  emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}
1363
Error OpenMPIRBuilder::emitCancelationCheckImpl(
    Value *CancelFlag, omp::Directive CanceledDirective) {
  // Emit the shared cancellation control flow: branch on \p CancelFlag to
  // either a continuation block (no cancellation) or a cancellation block
  // that jumps to the finalization block of the innermost finalization info.
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        Context&: BB->getContext(), Name: BB->getName() + ".cont", Parent: BB->getParent());
  } else {
    // Split at the current insert point; the split-off terminator is removed
    // so we can emit our own conditional branch below.
    NonCancellationBlock = SplitBlock(Old: BB, SplitPt: &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      Context&: BB->getContext(), Name: BB->getName() + ".cncl", Parent: BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(Arg: CancelFlag);
  Builder.CreateCondBr(Cond: Cmp, True: NonCancellationBlock, False: CancellationBlock,
                       /* TODO weight */ BranchWeights: nullptr, Unpredictable: nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  auto &FI = FinalizationStack.back();
  Expected<BasicBlock *> FiniBBOrErr = FI.getFiniBB(Builder);
  if (!FiniBBOrErr)
    return FiniBBOrErr.takeError();
  Builder.SetInsertPoint(CancellationBlock);
  Builder.CreateBr(Dest: *FiniBBOrErr);

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(TheBB: NonCancellationBlock, IP: NonCancellationBlock->begin());
  return Error::success();
}
1403
1404// Callback used to create OpenMP runtime calls to support
1405// omp parallel clause for the device.
1406// We need to use this callback to replace call to the OutlinedFn in OuterFn
1407// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_60)
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_60)
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // The extractor left exactly one call to the outlined function; it is
  // replaced below by the runtime launch.
  CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);

  // Add alloca for kernel args
  OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(TheBB: OuterAllocaBB, IP: OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add address space cast if array for storing arguments is not allocated
  // in address space 0
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(V: ArgsAlloca, DestTy: PtrTy);
  Builder.restoreIP(IP: CurrentIP);

  // Store captured vars which are used by kmpc_parallel_60
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    // Captured values start after the tid and bound-tid arguments of the call.
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars), Ptr: Args, Idx0: 0, Idx1: Idx);
    Builder.CreateStore(Val: V, Ptr: StoreAddress);
  }

  // The if-clause condition is normalized to an i32; absent a condition the
  // region runs unconditionally (1).
  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32)
                  : Builder.getInt32(C: 1);

  // Build kmpc_parallel_60 call
  Value *Parallel60CallArgs[] = {
      /* identifier*/ Ident,
      /* global thread num*/ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(C: -1),
      /* Proc bind */ Builder.getInt32(C: -1),
      /* outlined function */ &OutlinedFn,
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function*/ Args,
      /* number of arguments */ Builder.getInt64(C: NumCapturedVars),
      /* strict for number of threads */ Builder.getInt32(C: 0)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_parallel_60);

  OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: Parallel60CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_60 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
                      Ptr: PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  // Clean up instructions the caller marked as obsolete after the rewrite.
  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1491
1492// Callback used to create OpenMP runtime calls to support
1493// omp parallel clause for the host.
1494// We need to use this callback to replace call to the OutlinedFn in OuterFn
1495// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Replace the direct call to \p OutlinedFn left by the extractor with a
  // __kmpc_fork_call (or __kmpc_fork_call_if when an if-clause is present).
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call);
  }
  if (auto *F = dyn_cast<Function>(Val: RTLFn.getCallee())) {
    if (!F->hasMetadata(KindID: LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      // - The callback callee is argument number 2 (microtask).
      // - The first two arguments of the callback callee are unknown (-1).
      // - All variadic arguments to the __kmpc_fork_call are passed to the
      //   callback callee.
      F->addMetadata(KindID: LLVMContext::MD_callback,
                     MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
                                       CalleeArgNo: 2, Arguments: {-1, -1},
                                       /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
  OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
  OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // The extractor left exactly one call to the outlined function; it is
  // replaced below by the fork call.
  CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
  Value *ForkCallArgs[] = {Ident, Builder.getInt32(C: NumCapturedVars),
                           &OutlinedFn};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(in_start: std::begin(arr&: ForkCallArgs), in_end: std::end(arr&: ForkCallArgs));
  if (IfCondition) {
    Value *Cond = Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32);
    RealArgs.push_back(Elt: Cond);
  }
  // Forward the captured values from the original call (skipping tids).
  RealArgs.append(in_start: CI->arg_begin() + /* tid & bound tid */ 2, in_end: CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
    RealArgs.push_back(Elt: NullPtrValue);
  }

  OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
                      Ptr: PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  // Clean up instructions the caller marked as obsolete after the rewrite.
  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1576
/// Create an OpenMP `parallel` region at the location described by \p Loc.
///
/// The region is modeled as a chain of freshly split blocks
/// (entry/region/pre-finalize/exit); \p BodyGenCB fills in the region body,
/// captured values are rewired through \p PrivCB, and \p FiniCB provides the
/// finalization code. Outlining itself and the emission of the runtime call
/// (host vs. target-device variant) are deferred to a PostOutlineCB that is
/// registered via addOutlineInfo() and runs once the enclosing function is
/// finalized.
///
/// \param Loc            Source location / insertion point of the construct.
/// \param OuterAllocaIP  Insertion point for allocas that must live in the
///                       caller, outside the to-be-outlined region.
/// \param BodyGenCB      Callback generating the parallel region body.
/// \param PrivCB         Callback privatizing one captured value; it must set
///                       a replacement value.
/// \param FiniCB         Callback emitting finalization code for the region.
/// \param IfCondition    Optional `if` clause condition; may be null.
/// \param NumThreads     Optional `num_threads` clause value; may be null.
/// \param ProcBind       `proc_bind` clause kind.
/// \param IsCancellable  Whether the region participates in cancellation.
/// \returns The insertion point after the parallel construct, or an error
///          propagated from one of the callbacks.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);
  // If we generate code for the target device, we need to allocate
  // struct for aggregate params in the device default alloca address space.
  // OpenMP runtime requires that the params of the extracted functions are
  // passed as zero address space pointers. This flag ensures that extracted
  // function arguments are declared in zero address space
  bool ArgsInZeroAddressSpace = Config.isTargetDevice();

  // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
  // only if we compile for host side.
  if (NumThreads && !Config.isTargetDevice()) {
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
  Builder.restoreIP(IP: NewOuter);
  AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
  AllocaInst *ZeroAddrAlloca =
      Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
  Instruction *TIDAddr = TIDAddrAlloca;
  Instruction *ZeroAddr = ZeroAddrAlloca;
  if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
    // Add additional casts to enforce pointers in zero address space
    TIDAddr = new AddrSpaceCastInst(
        TIDAddrAlloca, PointerType ::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
    TIDAddr->insertAfter(InsertPos: TIDAddrAlloca->getIterator());
    ToBeDeleted.push_back(Elt: TIDAddr);
    ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
                                     PointerType ::get(C&: M.getContext(), AddressSpace: 0),
                                     "zero.addr.ascast");
    ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca->getIterator());
    ToBeDeleted.push_back(Elt: ZeroAddr);
  }

  // We only need TIDAddr and ZeroAddr for modeling purposes to get the
  // associated arguments in the outlined function, so we delete them later.
  ToBeDeleted.push_back(Elt: TIDAddrAlloca);
  ToBeDeleted.push_back(Elt: ZeroAddrAlloca);

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  BasicBlock *EntryBB = UI->getParent();
  BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
  BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
  BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");

  // Wrap the user finalization callback so it always sees an insertion point
  // whose block already branches to the region exit.
  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  FinalizationStack.push_back(Elt: {FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");

  // Add some fake uses for OpenMP provided arguments.
  ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
  ToBeDeleted.push_back(Elt: ZeroAddrUse);

  // EntryBB
  //   |
  //   V
  // PRegionEntryBB         <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB          <- BodyGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB          <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB          <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
    return Err;

  LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");

  // The actual fork call is emitted after outlining; ToBeDeleted is moved
  // into the callback so the modeling instructions survive until then.
  OutlineInfo OI;
  if (Config.isTargetDevice()) {
    // Generate OpenMP target specific runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                             ThreadID, ToBeDeleted: ToBeDeletedVec);
    };
    OI.FixUpNonEntryAllocas = true;
  } else {
    // Generate OpenMP host runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
                           PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
    };
    OI.FixUpNonEntryAllocas = true;
  }

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

  // Find inputs to, outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);

  Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands,
                              /*CollectGlobalInputs=*/true);

  // Globals of ident_t type are runtime bookkeeping, not user captures; do
  // not treat them as region inputs.
  Inputs.remove_if(P: [&](Value *I) {
    if (auto *GV = dyn_cast_if_present<GlobalVariable>(Val: I))
      return GV->getValueType() == OpenMPIRBuilder::Ident;

    return false;
  });

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);

  // Rewire one captured value V: thread-id queries map onto the local TID,
  // everything else goes through the user privatization callback. Non-pointer
  // inputs are spilled to an alloca first so the fork call only ever passes
  // pointers.
  auto PrivHelper = [&](Value &V) -> Error {
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(Elt: &V);
      return Error::success();
    }

    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
        if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
          Uses.insert(X: &U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(IP: OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr, Name: V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the entry
      // block of the to-be-outlined region.
      Builder.SetInsertPoint(TheBB: InsertBB,
                             IP: InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(Val: &V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(IP: InnerAllocaIP);
      Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(Val: &V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      ReplacementValue = PrivTID;
    } else {
      InsertPointOrErrorTy AfterIP =
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
      // The privatization callback may have grown the inner alloca block;
      // re-anchor the alloca insertion point at its (current) terminator.
      InnerAllocaIP = {
          InnerAllocaIP.getBlock(),
          InnerAllocaIP.getBlock()->getTerminator()->getIterator()};

      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      if (ReplacementValue == &V)
        return Error::success();
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);

    return Error::success();
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of zero address so
  // that they are available in the generated body and so that the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    if (Error Err = PrivHelper(*Input))
      return Err;
  }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  Expected<BasicBlock *> FiniBBOrErr = FiniInfo.getFiniBB(Builder);
  if (!FiniBBOrErr)
    return FiniBBOrErr.takeError();
  {
    IRBuilderBase::InsertPointGuard Guard(Builder);
    Builder.restoreIP(IP: PreFiniIP);
    Builder.CreateBr(Dest: *FiniBBOrErr);
    // There's currently a branch to omp.par.exit. Delete it. We will get there
    // via the fini block
    if (Instruction *Term = Builder.GetInsertBlock()->getTerminator())
      Term->eraseFromParent();
  }

  // Register the outlined info.
  addOutlineInfo(OI: std::move(OI));

  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
1910
1911void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1912 // Build call void __kmpc_flush(ident_t *loc)
1913 uint32_t SrcLocStrSize;
1914 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1915 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1916
1917 createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush),
1918 Args);
1919}
1920
1921void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1922 if (!updateToLocation(Loc))
1923 return;
1924 emitFlush(Loc);
1925}
1926
1927void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1928 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1929 // global_tid);
1930 uint32_t SrcLocStrSize;
1931 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1932 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1933 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1934
1935 // Ignore return result until untied tasks are supported.
1936 createRuntimeFunctionCall(
1937 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait), Args);
1938}
1939
1940void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1941 if (!updateToLocation(Loc))
1942 return;
1943 emitTaskwaitImpl(Loc);
1944}
1945
1946void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1947 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1948 uint32_t SrcLocStrSize;
1949 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1950 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1951 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1952 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1953
1954 createRuntimeFunctionCall(
1955 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield), Args);
1956}
1957
1958void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1959 if (!updateToLocation(Loc))
1960 return;
1961 emitTaskyieldImpl(Loc);
1962}
1963
1964// Processes the dependencies in Dependencies and does the following
1965// - Allocates space on the stack of an array of DependInfo objects
1966// - Populates each DependInfo object with relevant information of
1967// the corresponding dependence.
1968// - All code is inserted in the entry block of the current function.
1969static Value *emitTaskDependencies(
1970 OpenMPIRBuilder &OMPBuilder,
1971 const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1972 // Early return if we have no dependencies to process
1973 if (Dependencies.empty())
1974 return nullptr;
1975
1976 // Given a vector of DependData objects, in this function we create an
1977 // array on the stack that holds kmp_dep_info objects corresponding
1978 // to each dependency. This is then passed to the OpenMP runtime.
1979 // For example, if there are 'n' dependencies then the following psedo
1980 // code is generated. Assume the first dependence is on a variable 'a'
1981 //
1982 // \code{c}
1983 // DepArray = alloc(n x sizeof(kmp_depend_info);
1984 // idx = 0;
1985 // DepArray[idx].base_addr = ptrtoint(&a);
1986 // DepArray[idx].len = 8;
1987 // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
1988 // ++idx;
1989 // DepArray[idx].base_addr = ...;
1990 // \endcode
1991
1992 IRBuilderBase &Builder = OMPBuilder.Builder;
1993 Type *DependInfo = OMPBuilder.DependInfo;
1994 Module &M = OMPBuilder.M;
1995
1996 Value *DepArray = nullptr;
1997 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1998 Builder.SetInsertPoint(
1999 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
2000
2001 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
2002 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
2003
2004 Builder.restoreIP(IP: OldIP);
2005
2006 for (const auto &[DepIdx, Dep] : enumerate(First: Dependencies)) {
2007 Value *Base =
2008 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
2009 // Store the pointer to the variable
2010 Value *Addr = Builder.CreateStructGEP(
2011 Ty: DependInfo, Ptr: Base,
2012 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
2013 Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: Builder.getInt64Ty());
2014 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
2015 // Store the size of the variable
2016 Value *Size = Builder.CreateStructGEP(
2017 Ty: DependInfo, Ptr: Base, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
2018 Builder.CreateStore(
2019 Val: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
2020 Ptr: Size);
2021 // Store the dependency kind
2022 Value *Flags = Builder.CreateStructGEP(
2023 Ty: DependInfo, Ptr: Base,
2024 Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
2025 Builder.CreateStore(
2026 Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
2027 V: static_cast<unsigned int>(Dep.DepKind)),
2028 Ptr: Flags);
2029 }
2030 return DepArray;
2031}
2032
2033/// Create the task duplication function passed to kmpc_taskloop.
2034Expected<Value *> OpenMPIRBuilder::createTaskDuplicationFunction(
2035 Type *PrivatesTy, int32_t PrivatesIndex, TaskDupCallbackTy DupCB) {
2036 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2037 if (!DupCB)
2038 return Constant::getNullValue(
2039 Ty: PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace));
2040
2041 // From OpenMP Runtime p_task_dup_t:
2042 // Routine optionally generated by the compiler for setting the lastprivate
2043 // flag and calling needed constructors for private/firstprivate objects (used
2044 // to form taskloop tasks from pattern task) Parameters: dest task, src task,
2045 // lastprivate flag.
2046 // typedef void (*p_task_dup_t)(kmp_task_t *, kmp_task_t *, kmp_int32);
2047
2048 auto *VoidPtrTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2049
2050 FunctionType *DupFuncTy = FunctionType::get(
2051 Result: Builder.getVoidTy(), Params: {VoidPtrTy, VoidPtrTy, Builder.getInt32Ty()},
2052 /*isVarArg=*/false);
2053
2054 Function *DupFunction = Function::Create(Ty: DupFuncTy, Linkage: Function::InternalLinkage,
2055 N: "omp_taskloop_dup", M);
2056 Value *DestTaskArg = DupFunction->getArg(i: 0);
2057 Value *SrcTaskArg = DupFunction->getArg(i: 1);
2058 Value *LastprivateFlagArg = DupFunction->getArg(i: 2);
2059 DestTaskArg->setName("dest_task");
2060 SrcTaskArg->setName("src_task");
2061 LastprivateFlagArg->setName("lastprivate_flag");
2062
2063 IRBuilderBase::InsertPointGuard Guard(Builder);
2064 Builder.SetInsertPoint(
2065 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: DupFunction));
2066
2067 auto GetTaskContextPtrFromArg = [&](Value *Arg) -> Value * {
2068 Type *TaskWithPrivatesTy =
2069 StructType::get(Context&: Builder.getContext(), Elements: {Task, PrivatesTy});
2070 Value *TaskPrivates = Builder.CreateGEP(
2071 Ty: TaskWithPrivatesTy, Ptr: Arg, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2072 Value *ContextPtr = Builder.CreateGEP(
2073 Ty: PrivatesTy, Ptr: TaskPrivates,
2074 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: PrivatesIndex)});
2075 return ContextPtr;
2076 };
2077
2078 Value *DestTaskContextPtr = GetTaskContextPtrFromArg(DestTaskArg);
2079 Value *SrcTaskContextPtr = GetTaskContextPtrFromArg(SrcTaskArg);
2080
2081 DestTaskContextPtr->setName("destPtr");
2082 SrcTaskContextPtr->setName("srcPtr");
2083
2084 InsertPointTy AllocaIP(&DupFunction->getEntryBlock(),
2085 DupFunction->getEntryBlock().begin());
2086 InsertPointTy CodeGenIP = Builder.saveIP();
2087 Expected<IRBuilderBase::InsertPoint> AfterIPOrError =
2088 DupCB(AllocaIP, CodeGenIP, DestTaskContextPtr, SrcTaskContextPtr);
2089 if (!AfterIPOrError)
2090 return AfterIPOrError.takeError();
2091 Builder.restoreIP(IP: *AfterIPOrError);
2092
2093 Builder.CreateRetVoid();
2094
2095 return DupFunction;
2096}
2097
2098OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
2099 const LocationDescription &Loc, InsertPointTy AllocaIP,
2100 BodyGenCallbackTy BodyGenCB,
2101 llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> LoopInfo,
2102 Value *LBVal, Value *UBVal, Value *StepVal, bool Untied, Value *IfCond,
2103 Value *GrainSize, bool NoGroup, int Sched, Value *Final, bool Mergeable,
2104 Value *Priority, uint64_t NumOfCollapseLoops, TaskDupCallbackTy DupCB,
2105 Value *TaskContextStructPtrVal) {
2106
2107 if (!updateToLocation(Loc))
2108 return InsertPointTy();
2109
2110 uint32_t SrcLocStrSize;
2111 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2112 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2113
2114 BasicBlock *TaskloopExitBB =
2115 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.exit");
2116 BasicBlock *TaskloopBodyBB =
2117 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.body");
2118 BasicBlock *TaskloopAllocaBB =
2119 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.alloca");
2120
2121 InsertPointTy TaskloopAllocaIP =
2122 InsertPointTy(TaskloopAllocaBB, TaskloopAllocaBB->begin());
2123 InsertPointTy TaskloopBodyIP =
2124 InsertPointTy(TaskloopBodyBB, TaskloopBodyBB->begin());
2125
2126 if (Error Err = BodyGenCB(TaskloopAllocaIP, TaskloopBodyIP))
2127 return Err;
2128
2129 llvm::Expected<llvm::CanonicalLoopInfo *> result = LoopInfo();
2130 if (!result) {
2131 return result.takeError();
2132 }
2133
2134 llvm::CanonicalLoopInfo *CLI = result.get();
2135 OutlineInfo OI;
2136 OI.EntryBB = TaskloopAllocaBB;
2137 OI.OuterAllocaBB = AllocaIP.getBlock();
2138 OI.ExitBB = TaskloopExitBB;
2139
2140 // Add the thread ID argument.
2141 SmallVector<Instruction *> ToBeDeleted;
2142 // dummy instruction to be used as a fake argument
2143 OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2144 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskloopAllocaIP, Name: "global.tid", AsPtr: false));
2145 Value *FakeLB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2146 InnerAllocaIP: TaskloopAllocaIP, Name: "lb", AsPtr: false, Is64Bit: true);
2147 Value *FakeUB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2148 InnerAllocaIP: TaskloopAllocaIP, Name: "ub", AsPtr: false, Is64Bit: true);
2149 Value *FakeStep = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2150 InnerAllocaIP: TaskloopAllocaIP, Name: "step", AsPtr: false, Is64Bit: true);
2151 // For Taskloop, we want to force the bounds being the first 3 inputs in the
2152 // aggregate struct
2153 OI.Inputs.insert(X: FakeLB);
2154 OI.Inputs.insert(X: FakeUB);
2155 OI.Inputs.insert(X: FakeStep);
2156 if (TaskContextStructPtrVal)
2157 OI.Inputs.insert(X: TaskContextStructPtrVal);
2158 assert(((TaskContextStructPtrVal && DupCB) ||
2159 (!TaskContextStructPtrVal && !DupCB)) &&
2160 "Task context struct ptr and duplication callback must be both set "
2161 "or both null");
2162
2163 // It isn't safe to run the duplication bodygen callback inside the post
2164 // outlining callback so this has to be run now before we know the real task
2165 // shareds structure type.
2166 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2167 Type *PointerTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2168 Type *FakeSharedsTy = StructType::get(
2169 Context&: Builder.getContext(),
2170 Elements: {FakeLB->getType(), FakeUB->getType(), FakeStep->getType(), PointerTy});
2171 Expected<Value *> TaskDupFnOrErr = createTaskDuplicationFunction(
2172 PrivatesTy: FakeSharedsTy,
2173 /*PrivatesIndex: the pointer after the three indices above*/ PrivatesIndex: 3, DupCB);
2174 if (!TaskDupFnOrErr) {
2175 return TaskDupFnOrErr.takeError();
2176 }
2177 Value *TaskDupFn = *TaskDupFnOrErr;
2178
2179 OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
2180 TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
2181 IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
2182 FakeStep, Final, Mergeable, Priority,
2183 NumOfCollapseLoops](Function &OutlinedFn) mutable {
2184 // Replace the Stale CI by appropriate RTL function call.
2185 assert(OutlinedFn.hasOneUse() &&
2186 "there must be a single user for the outlined function");
2187 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2188
2189 /* Create the casting for the Bounds Values that can be used when outlining
2190 * to replace the uses of the fakes with real values */
2191 BasicBlock *CodeReplBB = StaleCI->getParent();
2192 IRBuilderBase::InsertPoint CurrentIp = Builder.saveIP();
2193 Builder.SetInsertPoint(CodeReplBB->getFirstInsertionPt());
2194 Value *CastedLBVal =
2195 Builder.CreateIntCast(V: LBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "lb64");
2196 Value *CastedUBVal =
2197 Builder.CreateIntCast(V: UBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "ub64");
2198 Value *CastedStepVal =
2199 Builder.CreateIntCast(V: StepVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "step64");
2200 Builder.restoreIP(IP: CurrentIp);
2201
2202 Builder.SetInsertPoint(StaleCI);
2203
2204 // Gather the arguments for emitting the runtime call for
2205 // @__kmpc_omp_task_alloc
2206 Function *TaskAllocFn =
2207 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2208
2209 Value *ThreadID = getOrCreateThreadID(Ident);
2210
2211 if (!NoGroup) {
2212 // Emit runtime call for @__kmpc_taskgroup
2213 Function *TaskgroupFn =
2214 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2215 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2216 }
2217
2218 // `flags` Argument Configuration
2219 // Task is tied if (Flags & 1) == 1.
2220 // Task is untied if (Flags & 1) == 0.
2221 // Task is final if (Flags & 2) == 2.
2222 // Task is not final if (Flags & 2) == 0.
2223 // Task is mergeable if (Flags & 4) == 4.
2224 // Task is not mergeable if (Flags & 4) == 0.
2225 // Task is priority if (Flags & 32) == 32.
2226 // Task is not priority if (Flags & 32) == 0.
2227 Value *Flags = Builder.getInt32(C: Untied ? 0 : 1);
2228 if (Final)
2229 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 2), RHS: Flags);
2230 if (Mergeable)
2231 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2232 if (Priority)
2233 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2234
2235 Value *TaskSize = Builder.getInt64(
2236 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2237
2238 AllocaInst *ArgStructAlloca =
2239 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2240 assert(ArgStructAlloca &&
2241 "Unable to find the alloca instruction corresponding to arguments "
2242 "for extracted function");
2243 StructType *ArgStructType =
2244 dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
2245 assert(ArgStructType && "Unable to find struct type corresponding to "
2246 "arguments for extracted function");
2247 Value *SharedsSize =
2248 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
2249
2250 // Emit the @__kmpc_omp_task_alloc runtime call
2251 // The runtime call returns a pointer to an area where the task captured
2252 // variables must be copied before the task is run (TaskData)
2253 CallInst *TaskData = Builder.CreateCall(
2254 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2255 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2256 /*task_func=*/&OutlinedFn});
2257
2258 Value *Shareds = StaleCI->getArgOperand(i: 1);
2259 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2260 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2261 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2262 Size: SharedsSize);
2263 // Get the pointer to loop lb, ub, step from task ptr
2264 // and set up the lowerbound,upperbound and step values
2265 llvm::Value *Lb = Builder.CreateGEP(
2266 Ty: ArgStructType, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
2267
2268 llvm::Value *Ub = Builder.CreateGEP(
2269 Ty: ArgStructType, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2270
2271 llvm::Value *Step = Builder.CreateGEP(
2272 Ty: ArgStructType, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2)});
2273 llvm::Value *Loadstep = Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: Step);
2274
2275 // set up the arguments for emitting kmpc_taskloop runtime call
2276 // setting values for ifval, nogroup, sched, grainsize, task_dup
2277 Value *IfCondVal =
2278 IfCond ? Builder.CreateIntCast(V: IfCond, DestTy: Builder.getInt32Ty(), isSigned: true)
2279 : Builder.getInt32(C: 1);
2280 // As __kmpc_taskgroup is called manually in OMPIRBuilder, NoGroupVal should
2281 // always be 1 when calling __kmpc_taskloop to ensure it is not called again
2282 Value *NoGroupVal = Builder.getInt32(C: 1);
2283 Value *SchedVal = Builder.getInt32(C: Sched);
2284 Value *GrainSizeVal =
2285 GrainSize ? Builder.CreateIntCast(V: GrainSize, DestTy: Builder.getInt64Ty(), isSigned: true)
2286 : Builder.getInt64(C: 0);
2287 Value *TaskDup = TaskDupFn;
2288
2289 Value *Args[] = {Ident, ThreadID, TaskData, IfCondVal, Lb, Ub,
2290 Loadstep, NoGroupVal, SchedVal, GrainSizeVal, TaskDup};
2291
2292 // taskloop runtime call
2293 Function *TaskloopFn =
2294 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskloop);
2295 Builder.CreateCall(Callee: TaskloopFn, Args);
2296
2297 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup if
2298 // nogroup is not defined
2299 if (!NoGroup) {
2300 Function *EndTaskgroupFn =
2301 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2302 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2303 }
2304
2305 StaleCI->eraseFromParent();
2306
2307 Builder.SetInsertPoint(TheBB: TaskloopAllocaBB, IP: TaskloopAllocaBB->begin());
2308
2309 LoadInst *SharedsOutlined =
2310 Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2311 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2312 New: SharedsOutlined,
2313 ShouldReplace: [SharedsOutlined](Use &U) { return U.getUser() != SharedsOutlined; });
2314
2315 Value *IV = CLI->getIndVar();
2316 Type *IVTy = IV->getType();
2317 Constant *One = ConstantInt::get(Ty: Builder.getInt64Ty(), V: 1);
2318
2319 // When outlining, CodeExtractor will create GEP's to the LowerBound and
2320 // UpperBound. These GEP's can be reused for loading the tasks respective
2321 // bounds.
2322 Value *TaskLB = nullptr;
2323 Value *TaskUB = nullptr;
2324 Value *LoadTaskLB = nullptr;
2325 Value *LoadTaskUB = nullptr;
2326 for (Instruction &I : *TaskloopAllocaBB) {
2327 if (I.getOpcode() == Instruction::GetElementPtr) {
2328 GetElementPtrInst &Gep = cast<GetElementPtrInst>(Val&: I);
2329 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: Gep.getOperand(i_nocapture: 2))) {
2330 switch (CI->getZExtValue()) {
2331 case 0:
2332 TaskLB = &I;
2333 break;
2334 case 1:
2335 TaskUB = &I;
2336 break;
2337 }
2338 }
2339 } else if (I.getOpcode() == Instruction::Load) {
2340 LoadInst &Load = cast<LoadInst>(Val&: I);
2341 if (Load.getPointerOperand() == TaskLB) {
2342 assert(TaskLB != nullptr && "Expected value for TaskLB");
2343 LoadTaskLB = &I;
2344 } else if (Load.getPointerOperand() == TaskUB) {
2345 assert(TaskUB != nullptr && "Expected value for TaskUB");
2346 LoadTaskUB = &I;
2347 }
2348 }
2349 }
2350
2351 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2352
2353 assert(LoadTaskLB != nullptr && "Expected value for LoadTaskLB");
2354 assert(LoadTaskUB != nullptr && "Expected value for LoadTaskUB");
2355 Value *TripCountMinusOne =
2356 Builder.CreateSDiv(LHS: Builder.CreateSub(LHS: LoadTaskUB, RHS: LoadTaskLB), RHS: FakeStep);
2357 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One, Name: "trip_cnt");
2358 Value *CastedTripCount = Builder.CreateIntCast(V: TripCount, DestTy: IVTy, isSigned: true);
2359 Value *CastedTaskLB = Builder.CreateIntCast(V: LoadTaskLB, DestTy: IVTy, isSigned: true);
2360 // set the trip count in the CLI
2361 CLI->setTripCount(CastedTripCount);
2362
2363 Builder.SetInsertPoint(TheBB: CLI->getBody(),
2364 IP: CLI->getBody()->getFirstInsertionPt());
2365
2366 if (NumOfCollapseLoops > 1) {
2367 llvm::SmallVector<User *> UsersToReplace;
2368 // When using the collapse clause, the bounds of the loop have to be
2369 // adjusted to properly represent the iterator of the outer loop.
2370 Value *IVPlusTaskLB = Builder.CreateAdd(
2371 LHS: CLI->getIndVar(),
2372 RHS: Builder.CreateSub(LHS: CastedTaskLB, RHS: ConstantInt::get(Ty: IVTy, V: 1)));
2373 // To ensure every Use is correctly captured, we first want to record
2374 // which users to replace the value in, and then replace the value.
2375 for (auto IVUse = CLI->getIndVar()->uses().begin();
2376 IVUse != CLI->getIndVar()->uses().end(); IVUse++) {
2377 User *IVUser = IVUse->getUser();
2378 if (auto *Op = dyn_cast<BinaryOperator>(Val: IVUser)) {
2379 if (Op->getOpcode() == Instruction::URem ||
2380 Op->getOpcode() == Instruction::UDiv) {
2381 UsersToReplace.push_back(Elt: IVUser);
2382 }
2383 }
2384 }
2385 for (User *User : UsersToReplace) {
2386 User->replaceUsesOfWith(From: CLI->getIndVar(), To: IVPlusTaskLB);
2387 }
2388 } else {
2389 // The canonical loop is generated with a fixed lower bound. We need to
2390 // update the index calculation code to use the task's lower bound. The
2391 // generated code looks like this:
2392 // %omp_loop.iv = phi ...
2393 // ...
2394 // %tmp = mul [type] %omp_loop.iv, step
2395 // %user_index = add [type] tmp, lb
2396 // OpenMPIRBuilder constructs canonical loops to have exactly three uses
2397 // of the normalised induction variable:
2398 // 1. This one: converting the normalised IV to the user IV
2399 // 2. The increment (add)
2400 // 3. The comparison against the trip count (icmp)
2401 // (1) is the only use that is a mul followed by an add so this cannot
2402 // match other IR.
2403 assert(CLI->getIndVar()->getNumUses() == 3 &&
2404 "Canonical loop should have exactly three uses of the ind var");
2405 for (User *IVUser : CLI->getIndVar()->users()) {
2406 if (auto *Mul = dyn_cast<BinaryOperator>(Val: IVUser)) {
2407 if (Mul->getOpcode() == Instruction::Mul) {
2408 for (User *MulUser : Mul->users()) {
2409 if (auto *Add = dyn_cast<BinaryOperator>(Val: MulUser)) {
2410 if (Add->getOpcode() == Instruction::Add) {
2411 Add->setOperand(i_nocapture: 1, Val_nocapture: CastedTaskLB);
2412 }
2413 }
2414 }
2415 }
2416 }
2417 }
2418 }
2419
2420 FakeLB->replaceAllUsesWith(V: CastedLBVal);
2421 FakeUB->replaceAllUsesWith(V: CastedUBVal);
2422 FakeStep->replaceAllUsesWith(V: CastedStepVal);
2423 for (Instruction *I : llvm::reverse(C&: ToBeDeleted)) {
2424 I->eraseFromParent();
2425 }
2426 };
2427
2428 addOutlineInfo(OI: std::move(OI));
2429 Builder.SetInsertPoint(TheBB: TaskloopExitBB, IP: TaskloopExitBB->begin());
2430 return Builder.saveIP();
2431}
2432
// Emit an OpenMP `task` construct. The task body is generated by \p BodyGenCB
// into blocks that are later outlined into a separate function (via the
// OutlineInfo machinery); the PostOutlineCB below then replaces the stale
// call to the outlined function with the libomp task protocol:
// __kmpc_omp_task_alloc + (optionally) __kmpc_omp_task_with_deps /
// __kmpc_omp_task / the if0 variants. Supports the `final`, `if`,
// `mergeable`, `priority`, `detach` (EventHandle) and `depend` clauses.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
    SmallVector<DependData> Dependencies, bool Mergeable, Value *EventHandle,
    Value *Priority) {

  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %task.exit
  //   task.exit:
  //     ; instructions after task
  // }
  // def outlined_fn() {
  //   task.alloca:
  //     br label %task.body
  //   task.body:
  //     ret void
  // }
  // ```
  BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
  BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
  BasicBlock *TaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");

  InsertPointTy TaskAllocaIP =
      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
  if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
    return Err;

  // Describe the region [task.alloca, task.exit) for the outliner.
  OutlineInfo OI;
  OI.EntryBB = TaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();
  OI.ExitBB = TaskExitBB;

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));

  // NOTE: the lambda is `mutable` so the by-value captures (Dependencies,
  // ToBeDeleted) can be consumed when the callback eventually runs.
  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
                      Mergeable, Priority, EventHandle, TaskAllocaBB,
                      ToBeDeleted](Function &OutlinedFn) mutable {
    // Replace the Stale CI by appropriate RTL function call.
    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());

    // HasShareds is true if any variables are captured in the outlined region,
    // false otherwise.
    bool HasShareds = StaleCI->arg_size() > 1;
    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call for
    // @__kmpc_omp_task_alloc
    Function *TaskAllocFn =
        getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // Task is mergeable iff (Flags & 4) == 4.
    // Task is not mergeable iff (Flags & 4) == 0.
    // Task is priority iff (Flags & 32) == 32.
    // Task is not priority iff (Flags & 32) == 0.
    // TODO: Handle the other flags.
    Value *Flags = Builder.getInt32(C: Tied);
    if (Final) {
      // `final` may be a runtime condition, so the flag bit is selected at
      // runtime rather than folded into the constant.
      Value *FinalFlag =
          Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
      Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
    }

    if (Mergeable)
      Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
    if (Priority)
      Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // Tasksize refers to the size in bytes of kmp_task_t data structure
    // including private vars accessed in task.
    // TODO: add kmp_task_t_with_privates (privates)
    Value *TaskSize = Builder.getInt64(
        C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(C: 0);
    if (HasShareds) {
      // The outliner aggregates captured values into a struct passed as the
      // second argument of the stale call; its store size is the shareds size.
      AllocaInst *ArgStructAlloca =
          dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      StructType *ArgStructType =
          dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
      assert(ArgStructType && "Unable to find struct type corresponding to "
                              "arguments for extracted function");
      SharedsSize =
          Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
    }
    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = createRuntimeFunctionCall(
        Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
                      /*task_func=*/&OutlinedFn});

    // Emit detach clause initialization.
    // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
    // task_descriptor);
    if (EventHandle) {
      Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
          FnID: OMPRTL___kmpc_task_allow_completion_event);
      llvm::Value *EventVal =
          createRuntimeFunctionCall(Callee: TaskDetachFn, Args: {Ident, ThreadID, TaskData});
      llvm::Value *EventHandleAddr =
          Builder.CreatePointerBitCastOrAddrSpaceCast(V: EventHandle,
                                                      DestTy: Builder.getPtrTy(AddrSpace: 0));
      // The event handle is stored as an integer, per the omp_event_handle_t
      // ABI used here.
      EventVal = Builder.CreatePtrToInt(V: EventVal, DestTy: Builder.getInt64Ty());
      Builder.CreateStore(Val: EventVal, Ptr: EventHandleAddr);
    }
    // Copy the arguments for outlined function
    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(i: 1);
      Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
      // TaskData's first field points at the shareds area allocated by the
      // runtime; memcpy the captured values into it.
      Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
      Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
                           Size: SharedsSize);
    }

    if (Priority) {
      //
      // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
      // we populate the priority information into the "kmp_task_t" here
      //
      // The struct "kmp_task_t" definition is available in kmp.h
      // kmp_task_t = { shareds, routine, part_id, data1, data2 }
      // data2 is used for priority
      //
      Type *Int32Ty = Builder.getInt32Ty();
      Constant *Zero = ConstantInt::get(Ty: Int32Ty, V: 0);
      // kmp_task_t* => { ptr }
      Type *TaskPtr = StructType::get(elt1: VoidPtr);
      Value *TaskGEP =
          Builder.CreateInBoundsGEP(Ty: TaskPtr, Ptr: TaskData, IdxList: {Zero, Zero});
      // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
      Type *TaskStructType = StructType::get(
          elt1: VoidPtr, elts: VoidPtr, elts: Builder.getInt32Ty(), elts: VoidPtr, elts: VoidPtr);
      Value *PriorityData = Builder.CreateInBoundsGEP(
          Ty: TaskStructType, Ptr: TaskGEP, IdxList: {Zero, ConstantInt::get(Ty: Int32Ty, V: 4)});
      // kmp_cmplrdata_t => { ptr, ptr }
      Type *CmplrStructType = StructType::get(elt1: VoidPtr, elts: VoidPtr);
      Value *CmplrData = Builder.CreateInBoundsGEP(Ty: CmplrStructType,
                                                   Ptr: PriorityData, IdxList: {Zero, Zero});
      Builder.CreateStore(Val: Priority, Ptr: CmplrData);
    }

    Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);

    // In the presence of the `if` clause, the following IR is generated:
    //    ...
    //    %data = call @__kmpc_omp_task_alloc(...)
    //    br i1 %if_condition, label %then, label %else
    //  then:
    //    call @__kmpc_omp_task(...)
    //    br label %exit
    //  else:
    //    ;; Wait for resolution of dependencies, if any, before
    //    ;; beginning the task
    //    call @__kmpc_omp_wait_deps(...)
    //    call @__kmpc_omp_task_begin_if0(...)
    //    call @outlined_fn(...)
    //    call @__kmpc_omp_task_complete_if0(...)
    //    br label %exit
    //  exit:
    //    ...
    if (IfCondition) {
      // `SplitBlockAndInsertIfThenElse` requires the block to have a
      // terminator.
      splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
      Instruction *IfTerminator =
          Builder.GetInsertPoint()->getParent()->getTerminator();
      Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
      Builder.SetInsertPoint(IfTerminator);
      SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
                                    ElseTerm: &ElseTI);
      Builder.SetInsertPoint(ElseTI);

      if (Dependencies.size()) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
        createRuntimeFunctionCall(
            Callee: TaskWaitFn,
            Args: {Ident, ThreadID, Builder.getInt32(C: Dependencies.size()), DepArray,
             ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
             ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
      }
      // The `if(0)` path executes the task body immediately on the current
      // thread, bracketed by begin_if0/complete_if0.
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
      createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
      CallInst *CI = nullptr;
      if (HasShareds)
        CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
      else
        CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
      Builder.SetInsertPoint(ThenTI);
    }

    if (Dependencies.size()) {
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
      createRuntimeFunctionCall(
          Callee: TaskFn,
          Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
           DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
           ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});

    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
      createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
    }

    StaleCI->eraseFromParent();

    Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
    if (HasShareds) {
      // Inside the outlined function, rewrite uses of the aggregate argument
      // to go through the shareds pointer loaded from the task descriptor
      // (skipping the load itself to avoid a self-reference).
      LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
      OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
          New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
    }

    // Erase in reverse creation order so users are removed before the
    // instructions they depend on.
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };

  addOutlineInfo(OI: std::move(OI));
  Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());

  return Builder.saveIP();
}
2696
2697OpenMPIRBuilder::InsertPointOrErrorTy
2698OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2699 InsertPointTy AllocaIP,
2700 BodyGenCallbackTy BodyGenCB) {
2701 if (!updateToLocation(Loc))
2702 return InsertPointTy();
2703
2704 uint32_t SrcLocStrSize;
2705 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2706 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2707 Value *ThreadID = getOrCreateThreadID(Ident);
2708
2709 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2710 Function *TaskgroupFn =
2711 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2712 createRuntimeFunctionCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2713
2714 BasicBlock *TaskgroupExitBB = splitBB(Builder, CreateBranch: true, Name: "taskgroup.exit");
2715 if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2716 return Err;
2717
2718 Builder.SetInsertPoint(TaskgroupExitBB);
2719 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2720 Function *EndTaskgroupFn =
2721 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2722 createRuntimeFunctionCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2723
2724 return Builder.saveIP();
2725}
2726
// Emit an OpenMP `sections` construct. Each section becomes one case of a
// switch on the induction variable of a canonical loop running from 0 to the
// number of sections; the loop is then workshared statically so each thread
// executes a disjoint subset of the cases.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
  assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");

  if (!updateToLocation(Loc))
    return Loc.IP;

  FinalizationStack.push_back(Elt: {FiniCB, OMPD_sections, IsCancellable});

  // Each section is emitted as a switch case
  // Each finalization callback is handled from clang.EmitOMPSectionDirective()
  // -> OMP.createSection() which generates the IR for each section
  // Iterate through all sections and emit a switch construct:
  // switch (IV) {
  //   case 0:
  //     <SectionStmt[0]>;
  //     break;
  // ...
  //   case <NumSection> - 1:
  //     <SectionStmt[<NumSection> - 1]>;
  //     break;
  // }
  // ...
  // section_loop.after:
  // <FiniCB>;
  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    // Split without a branch: the switch created below becomes the block's
    // terminator, with the split-off continuation as its default target.
    BasicBlock *Continue =
        splitBBWithSuffix(Builder, /*CreateBranch=*/false, Suffix: ".sections.after");
    Function *CurFn = Continue->getParent();
    SwitchInst *SwitchStmt = Builder.CreateSwitch(V: IndVar, Dest: Continue);

    unsigned CaseNumber = 0;
    for (auto SectionCB : SectionCBs) {
      // Each case block unconditionally branches to the continuation; the
      // section body is generated before that branch.
      BasicBlock *CaseBB = BasicBlock::Create(
          Context&: M.getContext(), Name: "omp_section_loop.body.case", Parent: CurFn, InsertBefore: Continue);
      SwitchStmt->addCase(OnVal: Builder.getInt32(C: CaseNumber), Dest: CaseBB);
      Builder.SetInsertPoint(CaseBB);
      BranchInst *CaseEndBr = Builder.CreateBr(Dest: Continue);
      if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
                                                  CaseEndBr->getIterator()}))
        return Err;
      CaseNumber++;
    }
    // No explicit terminator handling is needed here: the switch terminates
    // the body block and every case block already branches to Continue.
    return Error::success();
  };
  // Loop body ends here
  // LowerBound, UpperBound, and Stride for createCanonicalLoop
  Type *I32Ty = Type::getInt32Ty(C&: M.getContext());
  Value *LB = ConstantInt::get(Ty: I32Ty, V: 0);
  Value *UB = ConstantInt::get(Ty: I32Ty, V: SectionCBs.size());
  Value *ST = ConstantInt::get(Ty: I32Ty, V: 1);
  Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
      Loc, BodyGenCB: LoopBodyGenCB, Start: LB, Stop: UB, Step: ST, IsSigned: true, InclusiveStop: false, ComputeIP: AllocaIP, Name: "section_loop");
  if (!LoopInfo)
    return LoopInfo.takeError();

  // Distribute the loop iterations (i.e. the sections) across the threads;
  // the barrier at the end is skipped when `nowait` was requested.
  InsertPointOrErrorTy WsloopIP =
      applyStaticWorkshareLoop(DL: Loc.DL, CLI: *LoopInfo, AllocaIP,
                               LoopType: WorksharingLoopType::ForStaticLoop, NeedsBarrier: !IsNowait);
  if (!WsloopIP)
    return WsloopIP.takeError();
  InsertPointTy AfterIP = *WsloopIP;

  BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
  assert(LoopFini && "Bad structure of static workshare loop finalization");

  // Apply the finalization callback in LoopAfterBB
  auto FiniInfo = FinalizationStack.pop_back_val();
  assert(FiniInfo.DK == OMPD_sections &&
         "Unexpected finalization stack state!");
  if (Error Err = FiniInfo.mergeFiniBB(Builder, OtherFiniBB: LoopFini))
    return Err;

  return AfterIP;
}
2807
// Emit a single OpenMP `section` as an inlined region. The finalization
// callback is wrapped so that cancellation blocks (which have no terminator)
// are rerouted to the sections exit block before \p FiniCB runs.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // If the insertion point is not at the block's end, the block already has
    // a terminator and FiniCB can run as-is.
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at cancelation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from cancelation
    // to exit block.
    // NOTE(review): this walk assumes the fixed CFG shape produced by the
    // cancellation check (case block <- single pred <- condition block whose
    // terminator's successor 1 is the exit) — confirm if that shape changes.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    auto *CaseBB = Loc.IP.getBlock();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
    Instruction *I = Builder.CreateBr(Dest: ExitBB);
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  Directive OMPD = Directive::OMPD_sections;
  // Since we are using Finalization Callback here, HasFinalize
  // and IsCancellable have to be true
  return EmitOMPInlinedRegion(OMPD, EntryCall: nullptr, ExitCall: nullptr, BodyGenCB, FiniCB: FiniCBWrapper,
                              /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true,
                              /*IsCancellable*/ true);
}
2842
2843static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2844 BasicBlock::iterator IT(I);
2845 IT++;
2846 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2847}
2848
2849Value *OpenMPIRBuilder::getGPUThreadID() {
2850 return createRuntimeFunctionCall(
2851 Callee: getOrCreateRuntimeFunction(M,
2852 FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block),
2853 Args: {});
2854}
2855
2856Value *OpenMPIRBuilder::getGPUWarpSize() {
2857 return createRuntimeFunctionCall(
2858 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___kmpc_get_warp_size), Args: {});
2859}
2860
2861Value *OpenMPIRBuilder::getNVPTXWarpID() {
2862 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2863 return Builder.CreateAShr(LHS: getGPUThreadID(), RHS: LaneIDBits, Name: "nvptx_warp_id");
2864}
2865
2866Value *OpenMPIRBuilder::getNVPTXLaneID() {
2867 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
2868 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2869 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
2870 return Builder.CreateAnd(LHS: getGPUThreadID(), RHS: Builder.getInt32(C: LaneIDMask),
2871 Name: "nvptx_lane_id");
2872}
2873
// Convert \p From to \p ToType for the reduction/shuffle helpers.
// Strategy, in order: identical types are returned unchanged; equal store
// sizes use a bitcast; integer-to-integer uses a (signed) int cast; anything
// else is spilled to a temporary alloca and reloaded as the target type,
// i.e. a store/load based reinterpretation.
Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
                                        Type *ToType) {
  Type *FromType = From->getType();
  uint64_t FromSize = M.getDataLayout().getTypeStoreSize(Ty: FromType);
  uint64_t ToSize = M.getDataLayout().getTypeStoreSize(Ty: ToType);
  assert(FromSize > 0 && "From size must be greater than zero");
  assert(ToSize > 0 && "To size must be greater than zero");
  if (FromType == ToType)
    return From;
  if (FromSize == ToSize)
    return Builder.CreateBitCast(V: From, DestTy: ToType);
  if (ToType->isIntegerTy() && FromType->isIntegerTy())
    return Builder.CreateIntCast(V: From, DestTy: ToType, /*isSigned*/ true);
  // Emit the alloca at AllocaIP (typically the entry block) while keeping
  // the store/load at the current insert point.
  InsertPointTy SaveIP = Builder.saveIP();
  Builder.restoreIP(IP: AllocaIP);
  Value *CastItem = Builder.CreateAlloca(Ty: ToType);
  Builder.restoreIP(IP: SaveIP);

  // Store through an addrspace(0) view of the alloca, then reload with the
  // destination type.
  Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: CastItem, DestTy: Builder.getPtrTy(AddrSpace: 0));
  Builder.CreateStore(Val: From, Ptr: ValCastItem);
  return Builder.CreateLoad(Ty: ToType, Ptr: CastItem);
}
2897
2898Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2899 Value *Element,
2900 Type *ElementType,
2901 Value *Offset) {
2902 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElementType);
2903 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2904
2905 // Cast all types to 32- or 64-bit values before calling shuffle routines.
2906 Type *CastTy = Builder.getIntNTy(N: Size <= 4 ? 32 : 64);
2907 Value *ElemCast = castValueToType(AllocaIP, From: Element, ToType: CastTy);
2908 Value *WarpSize =
2909 Builder.CreateIntCast(V: getGPUWarpSize(), DestTy: Builder.getInt16Ty(), isSigned: true);
2910 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2911 FnID: Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2912 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2913 Value *WarpSizeCast =
2914 Builder.CreateIntCast(V: WarpSize, DestTy: Builder.getInt16Ty(), /*isSigned=*/true);
2915 Value *ShuffleCall =
2916 createRuntimeFunctionCall(Callee: ShuffleFunc, Args: {ElemCast, Offset, WarpSizeCast});
2917 return castValueToType(AllocaIP, From: ShuffleCall, ToType: CastTy);
2918}
2919
// Shuffle an element of arbitrary store size across warp lanes by breaking
// it into the largest integer chunks the shuffle runtime supports (8, then
// 4, 2, 1 bytes), shuffling each chunk and storing the result to \p DstAddr.
// NOTE(review): the \p ReductionArrayTy and \p IsByRefElem parameters are
// not used in this body — presumably kept for interface symmetry with the
// other reduction helpers; confirm against callers.
void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
                                      Value *DstAddr, Type *ElemType,
                                      Value *Offset, Type *ReductionArrayTy,
                                      bool IsByRefElem) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElemType);
  // Create the loop over the big sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *ElemPtr = DstAddr;
  Value *Ptr = SrcAddr;
  // Walk chunk sizes from 8 bytes down to 1; a chunk size is skipped when the
  // remaining Size is smaller than it.
  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
    if (Size < IntSize)
      continue;
    Type *IntType = Builder.getIntNTy(N: IntSize * 8);
    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: Ptr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: Ptr->getName() + ".ascast");
    // One-past-the-end of the source element, used as the loop bound below.
    Value *SrcAddrGEP =
        Builder.CreateGEP(Ty: ElemType, Ptr: SrcAddr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: ElemPtr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: ElemPtr->getName() + ".ascast");

    Function *CurFunc = Builder.GetInsertBlock()->getParent();
    if ((Size / IntSize) > 1) {
      // Multiple chunks of this size: emit a loop (pre_cond/then/exit) that
      // advances a pair of PHI-carried src/dst pointers one chunk at a time.
      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: SrcAddrGEP, DestTy: Builder.getPtrTy());
      BasicBlock *PreCondBB =
          BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.pre_cond");
      BasicBlock *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.then");
      BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.exit");
      BasicBlock *CurrentBB = Builder.GetInsertBlock();
      emitBlock(BB: PreCondBB, CurFn: CurFunc);
      PHINode *PhiSrc =
          Builder.CreatePHI(Ty: Ptr->getType(), /*NumReservedValues=*/2);
      PhiSrc->addIncoming(V: Ptr, BB: CurrentBB);
      PHINode *PhiDest =
          Builder.CreatePHI(Ty: ElemPtr->getType(), /*NumReservedValues=*/2);
      PhiDest->addIncoming(V: ElemPtr, BB: CurrentBB);
      Ptr = PhiSrc;
      ElemPtr = PhiDest;
      // Continue while at least one full chunk remains before PtrEnd.
      Value *PtrDiff = Builder.CreatePtrDiff(
          ElemTy: Builder.getInt8Ty(), LHS: PtrEnd,
          RHS: Builder.CreatePointerBitCastOrAddrSpaceCast(V: Ptr, DestTy: Builder.getPtrTy()));
      Builder.CreateCondBr(
          Cond: Builder.CreateICmpSGT(LHS: PtrDiff, RHS: Builder.getInt64(C: IntSize - 1)), True: ThenBB,
          False: ExitBB);
      emitBlock(BB: ThenBB, CurFn: CurFunc);
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP,
          Element: Builder.CreateAlignedLoad(
              Ty: IntType, Ptr, Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType)),
          ElementType: IntType, Offset);
      Builder.CreateAlignedStore(Val: Res, Ptr: ElemPtr,
                                 Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType));
      // Advance both pointers by one chunk and feed them back into the PHIs.
      Value *LocalPtr =
          Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      Value *LocalElemPtr =
          Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      PhiSrc->addIncoming(V: LocalPtr, BB: ThenBB);
      PhiDest->addIncoming(V: LocalElemPtr, BB: ThenBB);
      emitBranch(Target: PreCondBB);
      emitBlock(BB: ExitBB, CurFn: CurFunc);
    } else {
      // Exactly one chunk of this size: shuffle and store it straight-line.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP, Element: Builder.CreateLoad(Ty: IntType, Ptr), ElementType: IntType, Offset);
      // The shuffle returns a 32/64-bit integer; truncate back when the
      // element's integer type is narrower.
      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
                                         Res->getType()->getScalarSizeInBits())
        Res = Builder.CreateTrunc(V: Res, DestTy: ElemType);
      Builder.CreateStore(Val: Res, Ptr: ElemPtr);
      Ptr = Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
      ElemPtr =
          Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
    }
    // Whatever this chunk size could not cover is handled by smaller sizes.
    Size = Size % IntSize;
  }
}
3004
/// Copy a Reduce list, element by element, from \p SrcBase to \p DestBase.
///
/// Two strategies are supported via \p Action:
///  - CopyAction::RemoteLaneToThread: a fresh stack temporary is allocated
///    for each destination element, the value is shuffled in from a remote
///    lane (shuffleAndStore), and the pointer slot in the destination Reduce
///    list is rewritten to point at that temporary.
///  - CopyAction::ThreadCopy: destination element storage already exists; a
///    plain scalar/complex/aggregate copy is performed.
///
/// For by-ref elements (IsByRef[i] set), the copy goes through the
/// element's descriptor: DataPtrPtrGen yields the address of the data
/// pointer inside the descriptor, the pointed-to data is shuffled into a
/// local buffer, and a new descriptor is generated whose base pointer is
/// redirected to that buffer (generateReductionDescriptor).
///
/// \param AllocaIP        Insertion point at which temporaries are allocated.
/// \param Action          Copy strategy (see above).
/// \param ReductionArrayTy Array-of-pointers type of a Reduce list.
/// \param ReductionInfos  Per-element reduction descriptions.
/// \param SrcBase         Base address of the source Reduce list.
/// \param DestBase        Base address of the destination Reduce list.
/// \param IsByRef         Per-element by-ref flags; empty means all by-value.
/// \param CopyOptions     Only RemoteLaneOffset is consumed here.
/// \returns Error::success(), or the first error produced by a by-ref
///          callback (DataPtrPtrGen / generateReductionDescriptor).
Error OpenMPIRBuilder::emitReductionListCopy(
    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
    ArrayRef<bool> IsByRef, CopyOptionsTy CopyOptions) {
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;

  // Iterates, element-by-element, through the source Reduce list and
  // make a copy.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *SrcElementAddr = nullptr;
    AllocaInst *DestAlloca = nullptr;
    Value *DestElementAddr = nullptr;
    Value *DestElementPtrAddr = nullptr;
    // Should we shuffle in an element from a remote lane?
    bool ShuffleInElement = false;
    // Set to true to update the pointer in the dest Reduce list to a
    // newly created element.
    bool UpdateDestListPtr = false;

    // Step 1.1: Get the address for the src element in the Reduce list.
    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
        Ty: ReductionArrayTy, Ptr: SrcBase,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    SrcElementAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrAddr);

    // Step 1.2: Create a temporary to store the element in the destination
    // Reduce list.
    DestElementPtrAddr = Builder.CreateInBoundsGEP(
        Ty: ReductionArrayTy, Ptr: DestBase,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // Empty IsByRef means every element is copied by value.
    bool IsByRefElem = (!IsByRef.empty() && IsByRef[En.index()]);
    switch (Action) {
    case CopyAction::RemoteLaneToThread: {
      // Allocate the destination temporary at AllocaIP, then return to the
      // current code-gen position.
      InsertPointTy CurIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      // By-ref elements store a descriptor, so the temporary must be sized
      // for the descriptor type rather than the element type.
      Type *DestAllocaType =
          IsByRefElem ? RI.ByRefAllocatedType : RI.ElementType;
      DestAlloca = Builder.CreateAlloca(Ty: DestAllocaType, ArraySize: nullptr,
                                        Name: ".omp.reduction.element");
      DestAlloca->setAlignment(
          M.getDataLayout().getPrefTypeAlign(Ty: DestAllocaType));
      DestElementAddr = DestAlloca;
      DestElementAddr =
          Builder.CreateAddrSpaceCast(V: DestElementAddr, DestTy: Builder.getPtrTy(),
                                      Name: DestElementAddr->getName() + ".ascast");
      Builder.restoreIP(IP: CurIP);
      ShuffleInElement = true;
      UpdateDestListPtr = true;
      break;
    }
    case CopyAction::ThreadCopy: {
      // Destination storage already exists; just read its address out of the
      // destination Reduce list.
      DestElementAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DestElementPtrAddr);
      break;
    }
    }

    // Now that all active lanes have read the element in the
    // Reduce list, shuffle over the value from the remote lane.
    if (ShuffleInElement) {
      Type *ShuffleType = RI.ElementType;
      Value *ShuffleSrcAddr = SrcElementAddr;
      Value *ShuffleDestAddr = DestElementAddr;
      AllocaInst *LocalStorage = nullptr;

      if (IsByRefElem) {
        assert(RI.ByRefElementType && "Expected by-ref element type to be set");
        assert(RI.ByRefAllocatedType &&
               "Expected by-ref allocated type to be set");
        // For by-ref reductions, we need to copy from the remote lane the
        // actual value of the partial reduction computed by that remote lane;
        // rather than, for example, a pointer to that data or, even worse, a
        // pointer to the descriptor of the by-ref reduction element.
        ShuffleType = RI.ByRefElementType;

        // DataPtrPtrGen rewrites ShuffleSrcAddr in place to the address of
        // the data-pointer field inside the source descriptor.
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ShuffleSrcAddr, ShuffleSrcAddr);

        if (!GenResult)
          return GenResult.takeError();

        // Dereference once more to reach the actual data.
        ShuffleSrcAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ShuffleSrcAddr);

        {
          // Shuffle into a dedicated local buffer; the descriptor in
          // DestAlloca is patched to point at it below.
          InsertPointTy OldIP = Builder.saveIP();
          Builder.restoreIP(IP: AllocaIP);

          LocalStorage = Builder.CreateAlloca(Ty: ShuffleType);
          Builder.restoreIP(IP: OldIP);
          ShuffleDestAddr = LocalStorage;
        }
      }

      shuffleAndStore(AllocaIP, SrcAddr: ShuffleSrcAddr, DstAddr: ShuffleDestAddr, ElemType: ShuffleType,
                      Offset: RemoteLaneOffset, ReductionArrayTy, IsByRefElem);

      if (IsByRefElem) {
        // Copy descriptor from source and update base_ptr to shuffled data
        Value *DestDescriptorAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
            V: DestAlloca, DestTy: Builder.getPtrTy(), Name: ".ascast");

        InsertPointOrErrorTy GenResult = generateReductionDescriptor(
            DescriptorAddr: DestDescriptorAddr, DataPtr: LocalStorage, SrcDescriptorAddr: SrcElementAddr,
            DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);

        if (!GenResult)
          return GenResult.takeError();
      }
    } else {
      // Plain (non-shuffled) copy, dispatched on how the element is
      // evaluated.
      switch (RI.EvaluationKind) {
      case EvalKind::Scalar: {
        Value *Elem = Builder.CreateLoad(Ty: RI.ElementType, Ptr: SrcElementAddr);
        // Store the source element value to the dest element address.
        Builder.CreateStore(Val: Elem, Ptr: DestElementAddr);
        break;
      }
      case EvalKind::Complex: {
        // Copy real and imaginary parts field by field.
        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
        Value *SrcReal = Builder.CreateLoad(
            Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
        Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
        Value *SrcImg = Builder.CreateLoad(
            Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");

        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
        Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
            Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
        Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
        Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
        break;
      }
      case EvalKind::Aggregate: {
        // Aggregates are copied bitwise with a memcpy of the store size.
        Value *SizeVal = Builder.getInt64(
            C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
        Builder.CreateMemCpy(
            Dst: DestElementAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
            Src: SrcElementAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
            Size: SizeVal, isVolatile: false);
        break;
      }
      };
    }

    // Step 3.1: Modify reference in dest Reduce list as needed.
    // Modifying the reference in Reduce list to point to the newly
    // created element. The element is live in the current function
    // scope and that of functions it invokes (i.e., reduce_function).
    // RemoteReduceData[i] = (void*)&RemoteElem
    if (UpdateDestListPtr) {
      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: DestElementAddr, DestTy: Builder.getPtrTy(),
          Name: DestElementAddr->getName() + ".ascast");
      Builder.CreateStore(Val: CastDestAddr, Ptr: DestElementPtrAddr);
    }
  }

  return Error::success();
}
3170
/// Emit the helper function `_omp_reduction_inter_warp_copy_func(ptr
/// ReduceList, i32 NumWarps)` used during GPU reductions to move partial
/// results across warps.
///
/// The emitted function copies each reduce element, in chunks of up to 4
/// bytes, from the first lane of every warp into a shared-memory transfer
/// array, and then back out into the Reduce lists of the (up to 32) active
/// threads of warp 0. Barriers separate the write and read phases of each
/// chunk. By-ref elements are copied through their descriptor's data
/// pointer (RI.DataPtrPtrGen).
///
/// \param Loc            Source-location information for emitted barriers.
/// \param ReductionInfos Per-element reduction descriptions.
/// \param FuncAttrs      Attributes applied to the generated function.
/// \param IsByRef        Per-element by-ref flags; empty means all by-value.
/// \returns the generated function, or an error from barrier emission or a
///          by-ref DataPtrPtrGen callback.
Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  InsertPointTy SavedIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(), Params: {Builder.getPtrTy(), Builder.getInt32Ty()},
      /* IsVarArg */ isVarArg: false);
  Function *WcFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_inter_warp_copy_func", M: &M);
  WcFunc->setAttributes(FuncAttrs);
  WcFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  WcFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: WcFunc);
  Builder.SetInsertPoint(EntryBB);

  // ReduceList: thread local Reduce list.
  // At the stage of the computation when this function is called, partially
  // aggregated values reside in the first lane of every active warp.
  Argument *ReduceListArg = WcFunc->getArg(i: 0);
  // NumWarps: number of warps active in the parallel region. This could
  // be smaller than 32 (max warps in a CTA) for partial block reduction.
  Argument *NumWarpsArg = WcFunc->getArg(i: 1);

  // This array is used as a medium to transfer, one reduce element at a time,
  // the data from the first lane of every warp to lanes in the first warp
  // in order to perform the final step of a reduction in a parallel region
  // (reduction across warps). The array is placed in NVPTX __shared__ memory
  // for reduced latency, as well as to have a distinct copy for concurrently
  // executing target regions. The array is declared with common linkage so
  // as to be shared across compilation units.
  StringRef TransferMediumName =
      "__openmp_nvptx_data_transfer_temporary_storage";
  GlobalVariable *TransferMedium = M.getGlobalVariable(Name: TransferMediumName);
  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
  ArrayType *ArrayTy = ArrayType::get(ElementType: Builder.getInt32Ty(), NumElements: WarpSize);
  if (!TransferMedium) {
    // Created lazily: one i32 slot per warp, in address space 3 (shared).
    TransferMedium = new GlobalVariable(
        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
        UndefValue::get(T: ArrayTy), TransferMediumName,
        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
        /*AddressSpace=*/3);
  }

  // Get the CUDA thread id of the current OpenMP thread on the GPU.
  Value *GPUThreadID = getGPUThreadID();
  // nvptx_lane_id = nvptx_id % warpsize
  Value *LaneID = getNVPTXLaneID();
  // nvptx_warp_id = nvptx_id / warpsize
  Value *WarpID = getNVPTXWarpID();

  // Spill both arguments to allocas and keep address-space-cast handles, so
  // later loads go through generic pointers.
  InsertPointTy AllocaIP =
      InsertPointTy(Builder.GetInsertBlock(),
                    Builder.GetInsertBlock()->getFirstInsertionPt());
  Type *Arg0Type = ReduceListArg->getType();
  Type *Arg1Type = NumWarpsArg->getType();
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
      Ty: Arg0Type, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  AllocaInst *NumWarpsAlloca =
      Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: NumWarpsArg->getName() + ".addr");
  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListAlloca, DestTy: Arg0Type, Name: ReduceListAlloca->getName() + ".ascast");
  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: NumWarpsAlloca, DestTy: Builder.getPtrTy(AddrSpace: 0),
      Name: NumWarpsAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
  Builder.CreateStore(Val: NumWarpsArg, Ptr: NumWarpsAddrCast);
  // New allocas go after NumWarpsAlloca; code generation continues after the
  // last instruction emitted so far.
  AllocaIP = getInsertPointAfterInstr(I: NumWarpsAlloca);
  InsertPointTy CodeGenIP =
      getInsertPointAfterInstr(I: &Builder.GetInsertBlock()->back());
  Builder.restoreIP(IP: CodeGenIP);

  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListAddrCast);

  for (auto En : enumerate(First&: ReductionInfos)) {
    //
    // Warp master copies reduce element to transfer medium in __shared__
    // memory.
    //
    const ReductionInfo &RI = En.value();
    bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(
        Ty: IsByRefElem ? RI.ByRefElementType : RI.ElementType);
    // Transfer the element in power-of-two chunks: 4, 2, then 1 byte(s).
    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
      Type *CType = Builder.getIntNTy(N: TySize * 8);

      unsigned NumIters = RealTySize / TySize;
      if (NumIters == 0)
        continue;
      Value *Cnt = nullptr;
      Value *CntAddr = nullptr;
      BasicBlock *PrecondBB = nullptr;
      BasicBlock *ExitBB = nullptr;
      if (NumIters > 1) {
        // More than one chunk of this size: emit a counter-driven loop
        // (precond -> body -> ... -> precond / exit).
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(IP: AllocaIP);
        CntAddr =
            Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr, Name: ".cnt.addr");

        CntAddr = Builder.CreateAddrSpaceCast(V: CntAddr, DestTy: Builder.getPtrTy(),
                                              Name: CntAddr->getName() + ".ascast");
        Builder.restoreIP(IP: CodeGenIP);
        Builder.CreateStore(Val: Constant::getNullValue(Ty: Builder.getInt32Ty()),
                            Ptr: CntAddr,
                            /*Volatile=*/isVolatile: false);
        PrecondBB = BasicBlock::Create(Context&: Ctx, Name: "precond");
        ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit");
        BasicBlock *BodyBB = BasicBlock::Create(Context&: Ctx, Name: "body");
        emitBlock(BB: PrecondBB, CurFn: Builder.GetInsertBlock()->getParent());
        Cnt = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: CntAddr,
                                 /*Volatile=*/isVolatile: false);
        Value *Cmp = Builder.CreateICmpULT(
            LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), V: NumIters));
        Builder.CreateCondBr(Cond: Cmp, True: BodyBB, False: ExitBB);
        emitBlock(BB: BodyBB, CurFn: Builder.GetInsertBlock()->getParent());
      }

      // kmpc_barrier.
      InsertPointOrErrorTy BarrierIP1 =
          createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        Kind: omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP1)
        return BarrierIP1.takeError();
      BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
      BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
      BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

      // if (lane_id == 0)
      Value *IsWarpMaster = Builder.CreateIsNull(Arg: LaneID, Name: "warp_master");
      Builder.CreateCondBr(Cond: IsWarpMaster, True: ThenBB, False: ElseBB);
      emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());

      // Reduce element = LocalReduceList[i]
      auto *RedListArrayTy =
          ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
      Type *IndexTy = Builder.getIndexTy(
          DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
      Value *ElemPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      // elemptr = ((CopyType*)(elemptrptr)) + I
      Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);

      if (IsByRefElem) {
        // By-ref: follow the descriptor's data pointer to the actual data.
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
      }

      if (NumIters > 1)
        ElemPtr = Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: ElemPtr, IdxList: Cnt);

      // Get pointer to location in transfer medium.
      // MediumPtr = &medium[warp_id]
      Value *MediumPtr = Builder.CreateInBoundsGEP(
          Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), WarpID});
      // elem = *elemptr
      //*MediumPtr = elem
      Value *Elem = Builder.CreateLoad(Ty: CType, Ptr: ElemPtr);
      // Store the source element value to the dest element address.
      Builder.CreateStore(Val: Elem, Ptr: MediumPtr,
                          /*IsVolatile*/ isVolatile: true);
      Builder.CreateBr(Dest: MergeBB);

      // else
      emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(Dest: MergeBB);

      // endif
      emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
      // Second barrier: all warp masters have written before warp 0 reads.
      InsertPointOrErrorTy BarrierIP2 =
          createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        Kind: omp::Directive::OMPD_unknown,
                        /* ForceSimpleCall */ false,
                        /* CheckCancelFlag */ true);
      if (!BarrierIP2)
        return BarrierIP2.takeError();

      // Warp 0 copies reduce element from transfer medium
      BasicBlock *W0ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
      BasicBlock *W0ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
      BasicBlock *W0MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

      Value *NumWarpsVal =
          Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: NumWarpsAddrCast);
      // Up to 32 threads in warp 0 are active.
      Value *IsActiveThread =
          Builder.CreateICmpULT(LHS: GPUThreadID, RHS: NumWarpsVal, Name: "is_active_thread");
      Builder.CreateCondBr(Cond: IsActiveThread, True: W0ThenBB, False: W0ElseBB);

      emitBlock(BB: W0ThenBB, CurFn: Builder.GetInsertBlock()->getParent());

      // SecMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
          Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), GPUThreadID});
      // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
      Value *TargetElemPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *TargetElemPtrVal =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtrPtr);
      Value *TargetElemPtr = TargetElemPtrVal;

      if (IsByRefElem) {
        // By-ref: destination is also reached through the descriptor.
        InsertPointOrErrorTy GenRes =
            RI.DataPtrPtrGen(Builder.saveIP(), TargetElemPtr, TargetElemPtr);

        if (!GenRes)
          return GenRes.takeError();

        TargetElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtr);
      }

      if (NumIters > 1)
        TargetElemPtr =
            Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: TargetElemPtr, IdxList: Cnt);

      // *TargetElemPtr = SrcMediumVal;
      Value *SrcMediumValue =
          Builder.CreateLoad(Ty: CType, Ptr: SrcMediumPtrVal, /*IsVolatile*/ isVolatile: true);
      Builder.CreateStore(Val: SrcMediumValue, Ptr: TargetElemPtr);
      Builder.CreateBr(Dest: W0MergeBB);

      emitBlock(BB: W0ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(Dest: W0MergeBB);

      emitBlock(BB: W0MergeBB, CurFn: Builder.GetInsertBlock()->getParent());

      if (NumIters > 1) {
        // Increment the loop counter and branch back to the precondition.
        Cnt = Builder.CreateNSWAdd(
            LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), /*V=*/1));
        Builder.CreateStore(Val: Cnt, Ptr: CntAddr, /*Volatile=*/isVolatile: false);

        auto *CurFn = Builder.GetInsertBlock()->getParent();
        emitBranch(Target: PrecondBB);
        emitBlock(BB: ExitBB, CurFn);
      }
      // Remaining bytes are handled by the next (smaller) chunk size.
      RealTySize %= TySize;
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: SavedIP);

  return WcFunc;
}
3429
/// Emit the helper `_omp_reduction_shuffle_and_reduce_func(ptr ReduceList,
/// i16 LaneId, i16 RemoteLaneOffset, i16 AlgoVer)` used in GPU tree
/// reductions.
///
/// The emitted function: (1) copies each reduce element from a remote lane
/// into a stack-local remote Reduce list (emitReductionListCopy with
/// RemoteLaneToThread); (2) conditionally calls \p ReduceFn to fold the
/// remote list into the local one, the condition depending on AlgoVer (see
/// the inline comment below); (3) for AlgoVer==1 lanes with
/// LaneId >= Offset, copies the remote list back over the local list so the
/// reduced value propagates.
///
/// \param ReductionInfos Per-element reduction descriptions.
/// \param ReduceFn       Function folding (local, remote) Reduce lists.
/// \param FuncAttrs      Attributes applied to the generated function.
/// \param IsByRef        Per-element by-ref flags; empty means all by-value.
/// \returns the generated function, or an error propagated from
///          emitReductionListCopy.
Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy =
      FunctionType::get(Result: Builder.getVoidTy(),
                        Params: {Builder.getPtrTy(), Builder.getInt16Ty(),
                         Builder.getInt16Ty(), Builder.getInt16Ty()},
                        /* IsVarArg */ isVarArg: false);
  Function *SarFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_shuffle_and_reduce_func", M: &M);
  SarFunc->setAttributes(FuncAttrs);
  SarFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
  SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
  // The three i16 arguments are sign-extended per the device ABI.
  SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::SExt);
  SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::SExt);
  SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::SExt);
  BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: SarFunc);
  Builder.SetInsertPoint(EntryBB);

  // Thread local Reduce list used to host the values of data to be reduced.
  Argument *ReduceListArg = SarFunc->getArg(i: 0);
  // Current lane id; could be logical.
  Argument *LaneIDArg = SarFunc->getArg(i: 1);
  // Offset of the remote source lane relative to the current lane.
  Argument *RemoteLaneOffsetArg = SarFunc->getArg(i: 2);
  // Algorithm version. This is expected to be known at compile time.
  Argument *AlgoVerArg = SarFunc->getArg(i: 3);

  // Spill all arguments to allocas (with generic-address-space casts), then
  // immediately reload them; this matches the clang CodeGen pattern.
  Type *ReduceListArgType = ReduceListArg->getType();
  Type *LaneIDArgType = LaneIDArg->getType();
  Type *LaneIDArgPtrType = Builder.getPtrTy(AddrSpace: 0);
  Value *ReduceListAlloca = Builder.CreateAlloca(
      Ty: ReduceListArgType, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  Value *LaneIdAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
                                             Name: LaneIDArg->getName() + ".addr");
  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
      Ty: LaneIDArgType, ArraySize: nullptr, Name: RemoteLaneOffsetArg->getName() + ".addr");
  Value *AlgoVerAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
                                              Name: AlgoVerArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // Create a local thread-private variable to host the Reduce list
  // from a remote lane.
  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
      Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.remote_reduce_list");

  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListAlloca, DestTy: ReduceListArgType,
      Name: ReduceListAlloca->getName() + ".ascast");
  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LaneIdAlloca, DestTy: LaneIDArgPtrType, Name: LaneIdAlloca->getName() + ".ascast");
  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteLaneOffsetAlloca, DestTy: LaneIDArgPtrType,
      Name: RemoteLaneOffsetAlloca->getName() + ".ascast");
  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: AlgoVerAlloca, DestTy: LaneIDArgPtrType, Name: AlgoVerAlloca->getName() + ".ascast");
  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteReductionListAlloca, DestTy: Builder.getPtrTy(),
      Name: RemoteReductionListAlloca->getName() + ".ascast");

  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
  Builder.CreateStore(Val: LaneIDArg, Ptr: LaneIdAddrCast);
  Builder.CreateStore(Val: RemoteLaneOffsetArg, Ptr: RemoteLaneOffsetAddrCast);
  Builder.CreateStore(Val: AlgoVerArg, Ptr: AlgoVerAddrCast);

  Value *ReduceList = Builder.CreateLoad(Ty: ReduceListArgType, Ptr: ReduceListAddrCast);
  Value *LaneId = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: LaneIdAddrCast);
  Value *RemoteLaneOffset =
      Builder.CreateLoad(Ty: LaneIDArgType, Ptr: RemoteLaneOffsetAddrCast);
  Value *AlgoVer = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: AlgoVerAddrCast);

  // Further temporaries (from emitReductionListCopy) go right after the last
  // alloca above.
  InsertPointTy AllocaIP = getInsertPointAfterInstr(I: RemoteReductionListAlloca);

  // This loop iterates through the list of reduce elements and copies,
  // element by element, from a remote lane in the warp to RemoteReduceList,
  // hosted on the thread's stack.
  Error EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, Action: CopyAction::RemoteLaneToThread, ReductionArrayTy: RedListArrayTy, ReductionInfos,
      SrcBase: ReduceList, DestBase: RemoteListAddrCast, IsByRef,
      CopyOptions: {.RemoteLaneOffset: RemoteLaneOffset, .ScratchpadIndex: nullptr, .ScratchpadWidth: nullptr});

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  // The actions to be performed on the Remote Reduce list is dependent
  // on the algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread local variable Reduce list is mutated in place to host the
  // reduced data, which is the aggregated value produced from local and
  // remote lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first conjunction evaluates to true, making
  // the entire predicate true during compile time.
  // When AlgoVer==1, the second conjunction has only the second part to be
  // evaluated during runtime. Other conjunctions evaluates to false
  // during compile time.
  // When AlgoVer==2, the third conjunction has only the second part to be
  // evaluated during runtime. Other conjunctions evaluates to false
  // during compile time.
  Value *CondAlgo0 = Builder.CreateIsNull(Arg: AlgoVer);
  Value *Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
  Value *LaneComp = Builder.CreateICmpULT(LHS: LaneId, RHS: RemoteLaneOffset);
  Value *CondAlgo1 = Builder.CreateAnd(LHS: Algo1, RHS: LaneComp);
  Value *Algo2 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 2));
  Value *LaneIdAnd1 = Builder.CreateAnd(LHS: LaneId, RHS: Builder.getInt16(C: 1));
  Value *LaneIdComp = Builder.CreateIsNull(Arg: LaneIdAnd1);
  Value *Algo2AndLaneIdComp = Builder.CreateAnd(LHS: Algo2, RHS: LaneIdComp);
  Value *RemoteOffsetComp =
      Builder.CreateICmpSGT(LHS: RemoteLaneOffset, RHS: Builder.getInt16(C: 0));
  Value *CondAlgo2 = Builder.CreateAnd(LHS: Algo2AndLaneIdComp, RHS: RemoteOffsetComp);
  Value *CA0OrCA1 = Builder.CreateOr(LHS: CondAlgo0, RHS: CondAlgo1);
  Value *CondReduce = Builder.CreateOr(LHS: CA0OrCA1, RHS: CondAlgo2);

  BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
  BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
  BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");

  // then: fold the remote Reduce list into the local one via ReduceFn.
  Builder.CreateCondBr(Cond: CondReduce, True: ThenBB, False: ElseBB);
  emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceList, DestTy: Builder.getPtrTy());
  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RemoteListAddrCast, DestTy: Builder.getPtrTy());
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListPtr, RemoteReduceListPtr})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateBr(Dest: MergeBB);

  emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(Dest: MergeBB);

  emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());

  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
  // Reduce list.
  Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LHS: LaneId, RHS: RemoteLaneOffset);
  Value *CondCopy = Builder.CreateAnd(LHS: Algo1, RHS: LaneIdGtOffset);

  BasicBlock *CpyThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
  BasicBlock *CpyElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
  BasicBlock *CpyMergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
  Builder.CreateCondBr(Cond: CondCopy, True: CpyThenBB, False: CpyElseBB);

  emitBlock(BB: CpyThenBB, CurFn: Builder.GetInsertBlock()->getParent());

  // Plain thread-to-thread copy: destination storage already exists.
  EmitRedLsCpRes = emitReductionListCopy(
      AllocaIP, Action: CopyAction::ThreadCopy, ReductionArrayTy: RedListArrayTy, ReductionInfos,
      SrcBase: RemoteListAddrCast, DestBase: ReduceList, IsByRef);

  if (EmitRedLsCpRes)
    return EmitRedLsCpRes;

  Builder.CreateBr(Dest: CpyMergeBB);

  emitBlock(BB: CpyElseBB, CurFn: Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(Dest: CpyMergeBB);

  emitBlock(BB: CpyMergeBB, CurFn: Builder.GetInsertBlock()->getParent());

  Builder.CreateRetVoid();

  return SarFunc;
}
3604
3605OpenMPIRBuilder::InsertPointOrErrorTy
3606OpenMPIRBuilder::generateReductionDescriptor(
3607 Value *DescriptorAddr, Value *DataPtr, Value *SrcDescriptorAddr,
3608 Type *DescriptorType,
3609 function_ref<InsertPointOrErrorTy(InsertPointTy, Value *, Value *&)>
3610 DataPtrPtrGen) {
3611
3612 // Copy the source descriptor to preserve all metadata (rank, extents,
3613 // strides, etc.)
3614 Value *DescriptorSize =
3615 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: DescriptorType));
3616 Builder.CreateMemCpy(
3617 Dst: DescriptorAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
3618 Src: SrcDescriptorAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
3619 Size: DescriptorSize);
3620
3621 // Update the base pointer field to point to the local shuffled data
3622 Value *DataPtrField;
3623 InsertPointOrErrorTy GenResult =
3624 DataPtrPtrGen(Builder.saveIP(), DescriptorAddr, DataPtrField);
3625
3626 if (!GenResult)
3627 return GenResult.takeError();
3628
3629 Builder.CreateStore(Val: Builder.CreatePointerBitCastOrAddrSpaceCast(
3630 V: DataPtr, DestTy: Builder.getPtrTy(), Name: ".ascast"),
3631 Ptr: DataPtrField);
3632
3633 return Builder.saveIP();
3634}
3635
3636Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
3637 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3638 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3639 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3640 LLVMContext &Ctx = M.getContext();
3641 FunctionType *FuncTy = FunctionType::get(
3642 Result: Builder.getVoidTy(),
3643 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3644 /* IsVarArg */ isVarArg: false);
3645 Function *LtGCFunc =
3646 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3647 N: "_omp_reduction_list_to_global_copy_func", M: &M);
3648 LtGCFunc->setAttributes(FuncAttrs);
3649 LtGCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3650 LtGCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3651 LtGCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3652
3653 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGCFunc);
3654 Builder.SetInsertPoint(EntryBlock);
3655
3656 // Buffer: global reduction buffer.
3657 Argument *BufferArg = LtGCFunc->getArg(i: 0);
3658 // Idx: index of the buffer.
3659 Argument *IdxArg = LtGCFunc->getArg(i: 1);
3660 // ReduceList: thread local Reduce list.
3661 Argument *ReduceListArg = LtGCFunc->getArg(i: 2);
3662
3663 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3664 Name: BufferArg->getName() + ".addr");
3665 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3666 Name: IdxArg->getName() + ".addr");
3667 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3668 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3669 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3670 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3671 Name: BufferArgAlloca->getName() + ".ascast");
3672 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3673 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3674 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3675 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3676 Name: ReduceListArgAlloca->getName() + ".ascast");
3677
3678 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3679 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3680 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3681
3682 Value *LocalReduceList =
3683 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3684 Value *BufferArgVal =
3685 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3686 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3687 Type *IndexTy = Builder.getIndexTy(
3688 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3689 for (auto En : enumerate(First&: ReductionInfos)) {
3690 const ReductionInfo &RI = En.value();
3691 auto *RedListArrayTy =
3692 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3693 // Reduce element = LocalReduceList[i]
3694 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3695 Ty: RedListArrayTy, Ptr: LocalReduceList,
3696 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3697 // elemptr = ((CopyType*)(elemptrptr)) + I
3698 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3699
3700 // Global = Buffer.VD[Idx];
3701 Value *BufferVD =
3702 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferArgVal, IdxList: Idxs);
3703 Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3704 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3705
3706 switch (RI.EvaluationKind) {
3707 case EvalKind::Scalar: {
3708 Value *TargetElement;
3709
3710 if (IsByRef.empty() || !IsByRef[En.index()]) {
3711 TargetElement = Builder.CreateLoad(Ty: RI.ElementType, Ptr: ElemPtr);
3712 } else {
3713 InsertPointOrErrorTy GenResult =
3714 RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
3715
3716 if (!GenResult)
3717 return GenResult.takeError();
3718
3719 ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
3720 TargetElement = Builder.CreateLoad(Ty: RI.ByRefElementType, Ptr: ElemPtr);
3721 }
3722
3723 Builder.CreateStore(Val: TargetElement, Ptr: GlobVal);
3724 break;
3725 }
3726 case EvalKind::Complex: {
3727 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3728 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
3729 Value *SrcReal = Builder.CreateLoad(
3730 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
3731 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3732 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
3733 Value *SrcImg = Builder.CreateLoad(
3734 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
3735
3736 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3737 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 0, Name: ".realp");
3738 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3739 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 1, Name: ".imagp");
3740 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
3741 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
3742 break;
3743 }
3744 case EvalKind::Aggregate: {
3745 Value *SizeVal =
3746 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
3747 Builder.CreateMemCpy(
3748 Dst: GlobVal, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Src: ElemPtr,
3749 SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Size: SizeVal, isVolatile: false);
3750 break;
3751 }
3752 }
3753 }
3754
3755 Builder.CreateRetVoid();
3756 Builder.restoreIP(IP: OldIP);
3757 return LtGCFunc;
3758}
3759
/// Emits the runtime helper `_omp_reduction_list_to_global_reduce_func` with
/// signature `void(ptr Buffer, i32 Idx, ptr ReduceList)`:
///   Buffer     - global (per-team) reduction buffer,
///   Idx        - index of this team's slot within the buffer,
///   ReduceList - thread-local array of pointers to reduction variables.
/// The helper materializes a local reduce list whose entries point into
/// Buffer[Idx] and then calls \p ReduceFn to combine the thread-local list
/// into the global slot. For by-ref reductions an extra descriptor is
/// allocated in the entry block and filled via generateReductionDescriptor so
/// that its data pointer targets the global buffer element instead of the
/// thread-local storage.
Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // Remember the caller's insert point; it is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *LtGRFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_list_to_global_reduce_func", M: &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  LtGRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  LtGRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(i: 2);

  // Spill the incoming arguments to stack slots (mirrors Clang's codegen for
  // these helpers; later passes clean this up).
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  auto *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");

  // Insertion point at the top of the entry block for by-ref descriptor
  // allocas, so they are hoisted out of the per-element loop below.
  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  // Address-space-cast all stack slots to generic pointers before use.
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LocalReduceList, DestTy: Builder.getPtrTy(),
      Name: LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Populate the local reduce list: entry i points at element i of the
  // global buffer slot (directly for by-value, via a descriptor for by-ref).
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Allocate the by-ref descriptor copy in the entry block, not at the
      // current insert point, to avoid dynamic stack allocation.
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(IP: OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: LocalReduceListAddrCast,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    // Global = Buffer.VD[Idx];
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Get source descriptor from the reduce list argument
      Value *ReduceList =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
      Value *SrcElementPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *SrcDescriptorAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);

      // Copy descriptor from source and update base_ptr to global buffer data
      InsertPointOrErrorTy GenResult =
          generateReductionDescriptor(DescriptorAddr: ByRefAlloc, DataPtr: GlobValPtr, SrcDescriptorAddr,
                                      DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);

      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
    } else {
      Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
    }
  }

  // Call reduce_function(GlobalReduceList, ReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListAddrCast, ReduceList})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return LtGRFunc;
}
3881
/// Emits the runtime helper `_omp_reduction_global_to_list_copy_func` with
/// signature `void(ptr Buffer, i32 Idx, ptr ReduceList)`. It copies each
/// reduction element from global buffer slot Buffer[Idx] back into the
/// thread-local reduce list, dispatching on the element's evaluation kind
/// (scalar load/store, complex real/imag pair, or aggregate memcpy). This is
/// the inverse direction of the list-to-global copy helper.
Expected<Function *> OpenMPIRBuilder::emitGlobalToListCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // Remember the caller's insert point; it is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *GtLCFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_global_to_list_copy_func", M: &M);
  GtLCFunc->setAttributes(FuncAttrs);
  GtLCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  GtLCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  GtLCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLCFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLCFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLCFunc->getArg(i: 2);

  // Spill the incoming arguments to stack slots and address-space-cast them
  // to generic pointers (mirrors Clang's codegen for these helpers).
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Copy every element from the global buffer slot into the local list.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: LocalReduceList,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Type *ElemType = RI.ElementType;

      if (!IsByRef.empty() && IsByRef[En.index()]) {
        // By-ref: chase the descriptor's data pointer first, then store the
        // underlying element type rather than the descriptor type.
        ElemType = RI.ByRefElementType;
        InsertPointOrErrorTy GenResult =
            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);

        if (!GenResult)
          return GenResult.takeError();

        ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
      }

      Value *TargetElement = Builder.CreateLoad(Ty: ElemType, Ptr: GlobValPtr);
      Builder.CreateStore(Val: TargetElement, Ptr: ElemPtr);
      break;
    }
    case EvalKind::Complex: {
      // Copy the real and imaginary parts member-wise.
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 0, Name: ".realp");
      Value *SrcReal = Builder.CreateLoad(
          Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
      Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
      Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      // Aggregates are copied wholesale by size.
      Value *SizeVal =
          Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
      Builder.CreateMemCpy(
          Dst: ElemPtr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
          Src: GlobValPtr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
          Size: SizeVal, isVolatile: false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return GtLCFunc;
}
4002
/// Emits the runtime helper `_omp_reduction_global_to_list_reduce_func` with
/// signature `void(ptr Buffer, i32 Idx, ptr ReduceList)`. It builds a local
/// reduce list whose entries point into global buffer slot Buffer[Idx] and
/// then calls \p ReduceFn as reduce_function(ReduceList, GlobalReduceList),
/// i.e. the global slot is reduced INTO the thread-local list (argument order
/// is swapped relative to emitListToGlobalReduceFunction). By-ref elements get
/// an entry-block descriptor whose data pointer is redirected to the global
/// buffer via generateReductionDescriptor.
Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
  // Remember the caller's insert point; it is restored before returning.
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  auto *FuncTy = FunctionType::get(
      Result: Builder.getVoidTy(),
      Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ isVarArg: false);
  Function *GtLRFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
                       N: "_omp_reduction_global_to_list_reduce_func", M: &M);
  GtLRFunc->setAttributes(FuncAttrs);
  GtLRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  GtLRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  GtLRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = GtLRFunc->getArg(i: 0);
  // Idx: index of the buffer.
  Argument *IdxArg = GtLRFunc->getArg(i: 1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = GtLRFunc->getArg(i: 2);

  // Spill the incoming arguments to stack slots (mirrors Clang's codegen).
  Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
                                                Name: BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
                                             Name: IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");

  // Insertion point at the top of the entry block for by-ref descriptor
  // allocas, so they are hoisted out of the per-element loop below.
  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};

  // Address-space-cast all stack slots to generic pointers before use.
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
      Name: BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
      Name: ReduceListArgAlloca->getName() + ".ascast");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LocalReduceList, DestTy: Builder.getPtrTy(),
      Name: LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
  Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
  Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  // Populate the local reduce list: entry i points at element i of the
  // global buffer slot (directly for by-value, via a descriptor for by-ref).
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ByRefAlloc;

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Allocate the by-ref descriptor copy in the entry block, not at the
      // current insert point, to avoid dynamic stack allocation.
      InsertPointTy OldIP = Builder.saveIP();
      Builder.restoreIP(IP: AllocaIP);

      ByRefAlloc = Builder.CreateAlloca(Ty: RI.ByRefAllocatedType);
      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
          V: ByRefAlloc, DestTy: Builder.getPtrTy(), Name: ByRefAlloc->getName() + ".ascast");

      Builder.restoreIP(IP: OldIP);
    }

    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedListArrayTy, Ptr: ReductionList,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());

    if (!IsByRef.empty() && IsByRef[En.index()]) {
      // Get source descriptor from the reduce list
      Value *ReduceListVal =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
      Value *SrcElementPtrPtr =
          Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceListVal,
                                    IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
                                     ConstantInt::get(Ty: IndexTy, V: En.index())});
      Value *SrcDescriptorAddr =
          Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);

      // Copy descriptor from source and update base_ptr to global buffer data
      InsertPointOrErrorTy GenResult =
          generateReductionDescriptor(DescriptorAddr: ByRefAlloc, DataPtr: GlobValPtr, SrcDescriptorAddr,
                                      DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);
      if (!GenResult)
        return GenResult.takeError();

      Builder.CreateStore(Val: ByRefAlloc, Ptr: TargetElementPtrPtr);
    } else {
      Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
    }
  }

  // Call reduce_function(ReduceList, GlobalReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
  createRuntimeFunctionCall(Callee: ReduceFn, Args: {ReduceList, ReductionList})
      ->addFnAttr(Kind: Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(IP: OldIP);
  return GtLRFunc;
}
4123
4124std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
4125 std::string Suffix =
4126 createPlatformSpecificName(Parts: {"omp", "reduction", "reduction_func"});
4127 return (Name + Suffix).str();
4128}
4129
/// Emits the reduction combiner function `void(ptr LHSArray, ptr RHSArray)`,
/// named via getReductionFuncName(\p ReducerName), where both arguments are
/// arrays of pointers to the per-variable reduction elements. For the Clang
/// callback flavor (\p ReductionGenCBKind == Clang) the per-element code is
/// generated through RI.ReductionGenClang and the placeholder LHS/RHS values
/// it produced are patched afterwards; otherwise RI.ReductionGen emits the
/// combination directly. For by-value elements (IsByRef false) the scalar is
/// loaded before and stored back after the combination; by-ref elements are
/// combined through their pointers.
Expected<Function *> OpenMPIRBuilder::createReductionFunction(
    StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
    ArrayRef<bool> IsByRef, ReductionGenCBKind ReductionGenCBKind,
    AttributeList FuncAttrs) {
  auto *FuncTy = FunctionType::get(Result: Builder.getVoidTy(),
                                   Params: {Builder.getPtrTy(), Builder.getPtrTy()},
                                   /* IsVarArg */ isVarArg: false);
  std::string Name = getReductionFuncName(Name: ReducerName);
  Function *ReductionFunc =
      Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage, N: Name, M: &M);
  ReductionFunc->setAttributes(FuncAttrs);
  ReductionFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
  ReductionFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
  BasicBlock *EntryBB =
      BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: ReductionFunc);
  Builder.SetInsertPoint(EntryBB);

  // Need to alloca memory here and deal with the pointers before getting
  // LHS/RHS pointers out
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  Argument *Arg0 = ReductionFunc->getArg(i: 0);
  Argument *Arg1 = ReductionFunc->getArg(i: 1);
  Type *Arg0Type = Arg0->getType();
  Type *Arg1Type = Arg1->getType();

  Value *LHSAlloca =
      Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
  Value *RHSAlloca =
      Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
  Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: LHSAlloca, DestTy: Arg0Type, Name: LHSAlloca->getName() + ".ascast");
  Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      V: RHSAlloca, DestTy: Arg1Type, Name: RHSAlloca->getName() + ".ascast");
  Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
  Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
  LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
  RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);

  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  SmallVector<Value *> LHSPtrs, RHSPtrs;
  // Extract per-element LHS/RHS pointers; for the Clang flavor they are only
  // collected here and consumed by the fixup loop below.
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedArrayTy, Ptr: RHSArrayPtr,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType(),
        Name: RHSI8Ptr->getName() + ".ascast");

    Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
        Ty: RedArrayTy, Ptr: LHSArrayPtr,
        IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
    Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: LHSI8Ptr, DestTy: RI.Variable->getType(), Name: LHSI8Ptr->getName() + ".ascast");

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      LHSPtrs.emplace_back(Args&: LHSPtr);
      RHSPtrs.emplace_back(Args&: RHSPtr);
    } else {
      Value *LHS = LHSPtr;
      Value *RHS = RHSPtr;

      // By-value elements are combined as loaded scalars; by-ref elements
      // are handed to the generator as pointers.
      if (!IsByRef.empty() && !IsByRef[En.index()]) {
        LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
        RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
      }

      Value *Reduced;
      InsertPointOrErrorTy AfterIP =
          RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
      if (!AfterIP)
        return AfterIP.takeError();
      // The generator may clear the insert block to signal early exit; the
      // partially-built function is returned as-is in that case.
      if (!Builder.GetInsertBlock())
        return ReductionFunc;

      Builder.restoreIP(IP: *AfterIP);

      if (!IsByRef.empty() && !IsByRef[En.index()])
        Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
    }
  }

  if (ReductionGenCBKind == ReductionGenCBKind::Clang)
    for (auto En : enumerate(First&: ReductionInfos)) {
      unsigned Index = En.index();
      const ReductionInfo &RI = En.value();
      Value *LHSFixupPtr, *RHSFixupPtr;
      Builder.restoreIP(IP: RI.ReductionGenClang(
          Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));

      // Fix the CallBack code generated to use the correct Values for the LHS
      // and RHS. Only uses inside this function are rewritten.
      LHSFixupPtr->replaceUsesWithIf(
          New: LHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
            return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
      RHSFixupPtr->replaceUsesWithIf(
          New: RHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
            return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
    }

  Builder.CreateRetVoid();
  // Compiling with `-O0`, `alloca`s emitted in non-entry blocks are not hoisted
  // to the entry block (this is done for higher opt levels by later passes in
  // the pipeline). This has caused issues because non-entry `alloca`s force the
  // function to use dynamic stack allocations and we might run out of scratch
  // memory.
  hoistNonEntryAllocasToEntryBlock(Func: ReductionFunc);

  return ReductionFunc;
}
4249
/// Asserts (debug builds only) that every ReductionInfo is well-formed:
/// non-null original and private variables, at least one reduction generator
/// callback, and pointer-typed variables. The original/private type-equality
/// check is skipped on GPU (\p IsGPU), where privatization may place the copy
/// in a different address space. No-op in NDEBUG builds.
static void
checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
                    bool IsGPU) {
  for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
    (void)RI; // Silence unused-variable warning in NDEBUG builds.
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert((RI.ReductionGen || RI.ReductionGenClang) &&
           "expected non-null reduction generator callback");
    if (!IsGPU) {
      assert(
          RI.Variable->getType() == RI.PrivateVariable->getType() &&
          "expected variables and their private equivalents to have the same "
          "type");
    }
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }
}
4269
4270OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
4271 const LocationDescription &Loc, InsertPointTy AllocaIP,
4272 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
4273 ArrayRef<bool> IsByRef, bool IsNoWait, bool IsTeamsReduction,
4274 ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
4275 unsigned ReductionBufNum, Value *SrcLocInfo) {
4276 if (!updateToLocation(Loc))
4277 return InsertPointTy();
4278 Builder.restoreIP(IP: CodeGenIP);
4279 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
4280 LLVMContext &Ctx = M.getContext();
4281
4282 // Source location for the ident struct
4283 if (!SrcLocInfo) {
4284 uint32_t SrcLocStrSize;
4285 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4286 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4287 }
4288
4289 if (ReductionInfos.size() == 0)
4290 return Builder.saveIP();
4291
4292 BasicBlock *ContinuationBlock = nullptr;
4293 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
4294 // Copied code from createReductions
4295 BasicBlock *InsertBlock = Loc.IP.getBlock();
4296 ContinuationBlock =
4297 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4298 InsertBlock->getTerminator()->eraseFromParent();
4299 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4300 }
4301
4302 Function *CurFunc = Builder.GetInsertBlock()->getParent();
4303 AttributeList FuncAttrs;
4304 AttrBuilder AttrBldr(Ctx);
4305 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
4306 AttrBldr.addAttribute(A: Attr);
4307 AttrBldr.removeAttribute(Val: Attribute::OptimizeNone);
4308 FuncAttrs = FuncAttrs.addFnAttributes(C&: Ctx, B: AttrBldr);
4309
4310 CodeGenIP = Builder.saveIP();
4311 Expected<Function *> ReductionResult = createReductionFunction(
4312 ReducerName: Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos, IsByRef,
4313 ReductionGenCBKind, FuncAttrs);
4314 if (!ReductionResult)
4315 return ReductionResult.takeError();
4316 Function *ReductionFunc = *ReductionResult;
4317 Builder.restoreIP(IP: CodeGenIP);
4318
4319 // Set the grid value in the config needed for lowering later on
4320 if (GridValue.has_value())
4321 Config.setGridValue(GridValue.value());
4322 else
4323 Config.setGridValue(getGridValue(T, Kernel: ReductionFunc));
4324
4325 // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
4326 // RedList, shuffle_reduce_func, interwarp_copy_func);
4327 // or
4328 // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
4329 Value *Res;
4330
4331 // 1. Build a list of reduction variables.
4332 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4333 auto Size = ReductionInfos.size();
4334 Type *PtrTy = PointerType::get(C&: Ctx, AddressSpace: Config.getDefaultTargetAS());
4335 Type *FuncPtrTy =
4336 Builder.getPtrTy(AddrSpace: M.getDataLayout().getProgramAddressSpace());
4337 Type *RedArrayTy = ArrayType::get(ElementType: PtrTy, NumElements: Size);
4338 CodeGenIP = Builder.saveIP();
4339 Builder.restoreIP(IP: AllocaIP);
4340 Value *ReductionListAlloca =
4341 Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4342 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
4343 V: ReductionListAlloca, DestTy: PtrTy, Name: ReductionListAlloca->getName() + ".ascast");
4344 Builder.restoreIP(IP: CodeGenIP);
4345 Type *IndexTy = Builder.getIndexTy(
4346 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4347 for (auto En : enumerate(First&: ReductionInfos)) {
4348 const ReductionInfo &RI = En.value();
4349 Value *ElemPtr = Builder.CreateInBoundsGEP(
4350 Ty: RedArrayTy, Ptr: ReductionList,
4351 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4352
4353 Value *PrivateVar = RI.PrivateVariable;
4354 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
4355 if (IsByRefElem)
4356 PrivateVar = Builder.CreateLoad(Ty: RI.ElementType, Ptr: PrivateVar);
4357
4358 Value *CastElem =
4359 Builder.CreatePointerBitCastOrAddrSpaceCast(V: PrivateVar, DestTy: PtrTy);
4360 Builder.CreateStore(Val: CastElem, Ptr: ElemPtr);
4361 }
4362 CodeGenIP = Builder.saveIP();
4363 Expected<Function *> SarFunc = emitShuffleAndReduceFunction(
4364 ReductionInfos, ReduceFn: ReductionFunc, FuncAttrs, IsByRef);
4365
4366 if (!SarFunc)
4367 return SarFunc.takeError();
4368
4369 Expected<Function *> CopyResult =
4370 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs, IsByRef);
4371 if (!CopyResult)
4372 return CopyResult.takeError();
4373 Function *WcFunc = *CopyResult;
4374 Builder.restoreIP(IP: CodeGenIP);
4375
4376 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(V: ReductionList, DestTy: PtrTy);
4377
4378 unsigned MaxDataSize = 0;
4379 SmallVector<Type *> ReductionTypeArgs;
4380 for (auto En : enumerate(First&: ReductionInfos)) {
4381 auto Size = M.getDataLayout().getTypeStoreSize(Ty: En.value().ElementType);
4382 if (Size > MaxDataSize)
4383 MaxDataSize = Size;
4384 Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
4385 ? En.value().ByRefElementType
4386 : En.value().ElementType;
4387 ReductionTypeArgs.emplace_back(Args&: RedTypeArg);
4388 }
4389 Value *ReductionDataSize =
4390 Builder.getInt64(C: MaxDataSize * ReductionInfos.size());
4391 if (!IsTeamsReduction) {
4392 Value *SarFuncCast =
4393 Builder.CreatePointerBitCastOrAddrSpaceCast(V: *SarFunc, DestTy: FuncPtrTy);
4394 Value *WcFuncCast =
4395 Builder.CreatePointerBitCastOrAddrSpaceCast(V: WcFunc, DestTy: FuncPtrTy);
4396 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
4397 WcFuncCast};
4398 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
4399 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
4400 Res = createRuntimeFunctionCall(Callee: Pv2Ptr, Args);
4401 } else {
4402 CodeGenIP = Builder.saveIP();
4403 StructType *ReductionsBufferTy = StructType::create(
4404 Context&: Ctx, Elements: ReductionTypeArgs, Name: "struct._globalized_locals_ty");
4405 Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
4406 FnID: RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
4407
4408 Expected<Function *> LtGCFunc = emitListToGlobalCopyFunction(
4409 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4410 if (!LtGCFunc)
4411 return LtGCFunc.takeError();
4412
4413 Expected<Function *> LtGRFunc = emitListToGlobalReduceFunction(
4414 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4415 if (!LtGRFunc)
4416 return LtGRFunc.takeError();
4417
4418 Expected<Function *> GtLCFunc = emitGlobalToListCopyFunction(
4419 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4420 if (!GtLCFunc)
4421 return GtLCFunc.takeError();
4422
4423 Expected<Function *> GtLRFunc = emitGlobalToListReduceFunction(
4424 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4425 if (!GtLRFunc)
4426 return GtLRFunc.takeError();
4427
4428 Builder.restoreIP(IP: CodeGenIP);
4429
4430 Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
4431 Callee: RedFixedBufferFn, Args: {}, Name: "_openmp_teams_reductions_buffer_$_$ptr");
4432
4433 Value *Args3[] = {SrcLocInfo,
4434 KernelTeamsReductionPtr,
4435 Builder.getInt32(C: ReductionBufNum),
4436 ReductionDataSize,
4437 RL,
4438 *SarFunc,
4439 WcFunc,
4440 *LtGCFunc,
4441 *LtGRFunc,
4442 *GtLCFunc,
4443 *GtLRFunc};
4444
4445 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
4446 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
4447 Res = createRuntimeFunctionCall(Callee: TeamsReduceFn, Args: Args3);
4448 }
4449
4450 // 5. Build if (res == 1)
4451 BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.done");
4452 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.then");
4453 Value *Cond = Builder.CreateICmpEQ(LHS: Res, RHS: Builder.getInt32(C: 1));
4454 Builder.CreateCondBr(Cond, True: ThenBB, False: ExitBB);
4455
4456 // 6. Build then branch: where we have reduced values in the master
4457 // thread in each team.
4458 // __kmpc_end_reduce{_nowait}(<gtid>);
4459 // break;
4460 emitBlock(BB: ThenBB, CurFn: CurFunc);
4461
4462 // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
4463 for (auto En : enumerate(First&: ReductionInfos)) {
4464 const ReductionInfo &RI = En.value();
4465 Type *ValueType = RI.ElementType;
4466 Value *RedValue = RI.Variable;
4467 Value *RHS =
4468 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
4469
4470 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4471 Value *LHSPtr, *RHSPtr;
4472 Builder.restoreIP(IP: RI.ReductionGenClang(Builder.saveIP(), En.index(),
4473 &LHSPtr, &RHSPtr, CurFunc));
4474
4475 // Fix the CallBack code genereated to use the correct Values for the LHS
4476 // and RHS
4477 LHSPtr->replaceUsesWithIf(New: RedValue, ShouldReplace: [ReductionFunc](const Use &U) {
4478 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4479 ReductionFunc;
4480 });
4481 RHSPtr->replaceUsesWithIf(New: RHS, ShouldReplace: [ReductionFunc](const Use &U) {
4482 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4483 ReductionFunc;
4484 });
4485 } else {
4486 if (IsByRef.empty() || !IsByRef[En.index()]) {
4487 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
4488 Name: "red.value." + Twine(En.index()));
4489 }
4490 Value *PrivateRedValue = Builder.CreateLoad(
4491 Ty: ValueType, Ptr: RHS, Name: "red.private.value" + Twine(En.index()));
4492 Value *Reduced;
4493 InsertPointOrErrorTy AfterIP =
4494 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
4495 if (!AfterIP)
4496 return AfterIP.takeError();
4497 Builder.restoreIP(IP: *AfterIP);
4498
4499 if (!IsByRef.empty() && !IsByRef[En.index()])
4500 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
4501 }
4502 }
4503 emitBlock(BB: ExitBB, CurFn: CurFunc);
4504 if (ContinuationBlock) {
4505 Builder.CreateBr(Dest: ContinuationBlock);
4506 Builder.SetInsertPoint(ContinuationBlock);
4507 }
4508 Config.setEmitLLVMUsed();
4509
4510 return Builder.saveIP();
4511}
4512
4513static Function *getFreshReductionFunc(Module &M) {
4514 Type *VoidTy = Type::getVoidTy(C&: M.getContext());
4515 Type *Int8PtrTy = PointerType::getUnqual(C&: M.getContext());
4516 auto *FuncTy =
4517 FunctionType::get(Result: VoidTy, Params: {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ isVarArg: false);
4518 return Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4519 N: ".omp.reduction.func", M: &M);
4520}
4521
/// Populate \p ReductionFunc (created via getFreshReductionFunc) with the
/// elementwise reduction logic.
///
/// The function receives two type-erased arrays of pointers: argument 0 holds
/// pointers to the original (LHS) variables, argument 1 pointers to the
/// private (RHS) copies. For every entry in \p ReductionInfos the operands
/// are loaded and combined via RI.ReductionGen; for non-by-ref reductions the
/// combined value is stored back through the LHS pointer.
///
/// NOTE(review): IsByRef is indexed unconditionally here (IsByRef[En.index()])
/// — callers must supply one flag per reduction; verify against call sites
/// that use the `!IsByRef.empty() && ...` guard instead.
static Error populateReductionFunction(
    Function *ReductionFunc,
    ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
  Module *Module = ReductionFunc->getParent();
  // Single-block body; all IR below is appended in order.
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "", Parent: ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  if (IsGPU) {
    // Need to alloca memory here and deal with the pointers before getting
    // LHS/RHS pointers out
    //
    // On the device the incoming argument pointers are spilled into local
    // allocas (cast to the arguments' address space) and re-loaded, so the
    // subsequent GEPs operate on generic pointers.
    Argument *Arg0 = ReductionFunc->getArg(i: 0);
    Argument *Arg1 = ReductionFunc->getArg(i: 1);
    Type *Arg0Type = Arg0->getType();
    Type *Arg1Type = Arg1->getType();

    Value *LHSAlloca =
        Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
    Value *RHSAlloca =
        Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
    Value *LHSAddrCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(V: LHSAlloca, DestTy: Arg0Type);
    Value *RHSAddrCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHSAlloca, DestTy: Arg1Type);
    Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
    Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
    LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
    RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
  } else {
    // On the host the arguments can be used directly as the array bases.
    LHSArrayPtr = ReductionFunc->getArg(i: 0);
    RHSArrayPtr = ReductionFunc->getArg(i: 1);
  }

  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);

  for (auto En : enumerate(First&: ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    // Fetch the i-th type-erased pointer from each array, cast it to the
    // variable's pointer type and load the element to reduce.
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: LHSArrayPtr, Idx0: 0, Idx1: En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: LHSI8Ptr, DestTy: RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: RHSArrayPtr, Idx0: 0, Idx1: En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
    Value *Reduced;
    // Let the frontend callback emit the combining operation.
    OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();

    Builder.restoreIP(IP: *AfterIP);
    // TODO: Consider flagging an error.
    // The callback may have cleared the insert block; bail out quietly.
    if (!Builder.GetInsertBlock())
      return Error::success();

    // store is inside of the reduction region when using by-ref
    if (!IsByRef[En.index()])
      Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
  }
  Builder.CreateRetVoid();
  return Error::success();
}
4593
/// Emit the host-side lowering of an OpenMP reduction clause.
///
/// Builds an array of type-erased pointers to the private copies, calls
/// __kmpc_reduce{_nowait} and dispatches on its result: 1 selects the
/// non-atomic elementwise path, 2 the atomic path (only when every reduction
/// provides an AtomicReductionGen and none is by-ref). The outlined
/// elementwise function handed to the runtime is filled in by
/// populateReductionFunction. GPU compilation is delegated to
/// createReductionsGPU.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
    bool IsNoWait, bool IsTeamsReduction) {
  // One by-ref flag per reduction is required below.
  assert(ReductionInfos.size() == IsByRef.size());
  if (Config.isGPU())
    return createReductionsGPU(Loc, AllocaIP, CodeGenIP: Builder.saveIP(), ReductionInfos,
                               IsByRef, IsNoWait, IsTeamsReduction);

  checkReductionInfos(ReductionInfos, /*IsGPU*/ false);

  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Nothing to reduce: leave the IR untouched.
  if (ReductionInfos.size() == 0)
    return Builder.saveIP();

  // Split off a continuation block; the reduction CFG is built in between.
  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
  // The array itself lives at the function-entry alloca point.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
  Value *RedArray = Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: "red.array");

  Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());

  for (auto En : enumerate(First&: ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        Ty: RedArrayTy, Ptr: RedArray, Idx0: 0, Idx1: Index, Name: "red.array.elem." + Twine(Index));
    Builder.CreateStore(Val: RI.PrivateVariable, Ptr: RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Type *IndexTy = Builder.getIndexTy(
      DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // The atomic path may only be advertised to the runtime when every
  // reduction can provide an atomic combiner.
  bool CanGenerateAtomic = all_of(Range&: ReductionInfos, P: [](const ReductionInfo &RI) {
    return RI.AtomicReductionGen;
  });
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                  LocFlags: CanGenerateAtomic
                                      ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                      : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(C: NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(Ty: RedArrayTy);
  Constant *RedArraySize = ConstantInt::get(Ty: IndexTy, V: RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(M&: *Module);
  Value *Lock = getOMPCriticalRegionLock(CriticalName: ".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
                : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      createRuntimeFunctionCall(Callee: ReduceFunc,
                                Args: {Ident, ThreadId, NumVariables, RedArraySize,
                                 RedArray, ReductionFunc, Lock},
                                Name: "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.nonatomic", Parent: Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.atomic", Parent: Func);
  // Any other return value (e.g. 0: this thread does not reduce) falls
  // through to the continuation block.
  SwitchInst *Switch =
      Builder.CreateSwitch(V: ReduceCall, Dest: ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(OnVal: Builder.getInt32(C: 1), Dest: NonAtomicRedBlock);
  Switch->addCase(OnVal: Builder.getInt32(C: 2), Dest: AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(First&: ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    // We have one less load for by-ref case because that load is now inside of
    // the reduction region
    Value *RedValue = RI.Variable;
    if (!IsByRef[En.index()]) {
      RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
                                    Name: "red.value." + Twine(En.index()));
    }
    Value *PrivateRedValue =
        Builder.CreateLoad(Ty: ValueType, Ptr: RI.PrivateVariable,
                           Name: "red.private.value." + Twine(En.index()));
    Value *Reduced;
    InsertPointOrErrorTy AfterIP =
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);

    // The callback may have cleared the insert block; stop emitting.
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // for by-ref case, the load is inside of the reduction region
    if (!IsByRef[En.index()])
      Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  createRuntimeFunctionCall(Callee: EndReduceFunc, Args: {Ident, ThreadId, Lock});
  Builder.CreateBr(Dest: ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic && llvm::none_of(Range&: IsByRef, P: [](bool P) { return P; })) {
    for (const ReductionInfo &RI : ReductionInfos) {
      InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
          Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(Dest: ContinuationBlock);
  } else {
    // The runtime was told atomics are unsupported, so case 2 is dead.
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
                                        IsByRef, /*isGPU=*/IsGPU: false);
  if (Err)
    return Err;

  if (!Builder.GetInsertBlock())
    return InsertPointTy();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}
4745
4746OpenMPIRBuilder::InsertPointOrErrorTy
4747OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
4748 BodyGenCallbackTy BodyGenCB,
4749 FinalizeCallbackTy FiniCB) {
4750 if (!updateToLocation(Loc))
4751 return Loc.IP;
4752
4753 Directive OMPD = Directive::OMPD_master;
4754 uint32_t SrcLocStrSize;
4755 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4756 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4757 Value *ThreadId = getOrCreateThreadID(Ident);
4758 Value *Args[] = {Ident, ThreadId};
4759
4760 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_master);
4761 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4762
4763 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_master);
4764 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
4765
4766 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4767 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4768}
4769
4770OpenMPIRBuilder::InsertPointOrErrorTy
4771OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
4772 BodyGenCallbackTy BodyGenCB,
4773 FinalizeCallbackTy FiniCB, Value *Filter) {
4774 if (!updateToLocation(Loc))
4775 return Loc.IP;
4776
4777 Directive OMPD = Directive::OMPD_masked;
4778 uint32_t SrcLocStrSize;
4779 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4780 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4781 Value *ThreadId = getOrCreateThreadID(Ident);
4782 Value *Args[] = {Ident, ThreadId, Filter};
4783 Value *ArgsEnd[] = {Ident, ThreadId};
4784
4785 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_masked);
4786 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
4787
4788 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_masked);
4789 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args: ArgsEnd);
4790
4791 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4792 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
4793}
4794
4795static llvm::CallInst *emitNoUnwindRuntimeCall(IRBuilder<> &Builder,
4796 llvm::FunctionCallee Callee,
4797 ArrayRef<llvm::Value *> Args,
4798 const llvm::Twine &Name) {
4799 llvm::CallInst *Call = Builder.CreateCall(
4800 Callee, Args, OpBundles: SmallVector<llvm::OperandBundleDef, 1>(), Name);
4801 Call->setDoesNotThrow();
4802 return Call;
4803}
4804
// Expects input basic block is dominated by BeforeScanBB.
// Once Scan directive is encountered, the code after scan directive should be
// dominated by AfterScanBB. Scan directive splits the code sequence to
// scan and input phase. Based on whether inclusive or exclusive
// clause is used in the scan directive and whether input loop or scan loop
// is lowered, it adds jumps to input and scan phase. First Scan loop is the
// input loop and second is the scan loop. The code generated handles only
// inclusive scans now.
//
// ScanRedInfo->OMPFirstScanLoop selects which of the two generated loops is
// currently being lowered (true = input loop, false = scan loop).
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
    bool IsInclusive, ScanInfo *ScanRedInfo) {
  // The per-variable temporary buffers are only allocated once, while
  // lowering the input loop.
  if (ScanRedInfo->OMPFirstScanLoop) {
    llvm::Error Err = emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars,
                                                    ScanVarsType, ScanRedInfo);
    if (Err)
      return Err;
  }
  if (!updateToLocation(Loc))
    return Loc.IP;

  llvm::Value *IV = ScanRedInfo->IV;

  if (ScanRedInfo->OMPFirstScanLoop) {
    // Emit buffer[i] = red; at the end of the input phase.
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = ScanVarsType[i];
      Value *Val = Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: ScanVars[i]);

      Builder.CreateStore(Val: Src, Ptr: Val);
    }
  }
  // Close the current phase and continue emitting in the dispatch block.
  Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
  emitBlock(BB: ScanRedInfo->OMPScanDispatch,
            CurFn: Builder.GetInsertBlock()->getParent());

  if (!ScanRedInfo->OMPFirstScanLoop) {
    IV = ScanRedInfo->IV;
    // Emit red = buffer[i]; at the entrance to the scan phase.
    // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = ScanVarsType[i];
      Value *SrcPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: SrcPtr);
      Builder.CreateStore(Val: Src, Ptr: ScanVars[i]);
    }
  }

  // TODO: Update it to CreateBr and remove dead blocks
  // A constant-true conditional branch: the "dead" successor keeps the
  // before/after scan blocks wired into the CFG until cleanup happens.
  llvm::Value *CmpI = Builder.getInt1(V: true);
  if (ScanRedInfo->OMPFirstScanLoop == IsInclusive) {
    Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPBeforeScanBlock,
                         False: ScanRedInfo->OMPAfterScanBlock);
  } else {
    Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPAfterScanBlock,
                         False: ScanRedInfo->OMPBeforeScanBlock);
  }
  emitBlock(BB: ScanRedInfo->OMPAfterScanBlock,
            CurFn: Builder.GetInsertBlock()->getParent());
  Builder.SetInsertPoint(ScanRedInfo->OMPAfterScanBlock);
  return Builder.saveIP();
}
4873
/// Allocate the per-variable scan buffers used by createScan.
///
/// For every scan variable a pointer slot is created at \p AllocaIP; the
/// buffer itself (span+1 elements) is malloc'ed by the primary thread only,
/// inside a masked region followed by a barrier so all threads observe the
/// stored buffer pointers.
Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
    InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
    ArrayRef<Type *> ScanVarsType, ScanInfo *ScanRedInfo) {

  Builder.restoreIP(IP: AllocaIP);
  // Create the shared pointer at alloca IP.
  for (size_t i = 0; i < ScanVars.size(); i++) {
    llvm::Value *BuffPtr =
        Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: "vla");
    (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]] = BuffPtr;
  }

  // Allocate temporary buffer by master thread
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    // One extra element beyond the iteration span.
    Value *AllocSpan =
        Builder.CreateAdd(LHS: ScanRedInfo->Span, RHS: Builder.getInt32(C: 1));
    for (size_t i = 0; i < ScanVars.size(); i++) {
      Type *IntPtrTy = Builder.getInt32Ty();
      Constant *Allocsize = ConstantExpr::getSizeOf(Ty: ScanVarsType[i]);
      Allocsize = ConstantExpr::getTruncOrBitCast(C: Allocsize, Ty: IntPtrTy);
      Value *Buff = Builder.CreateMalloc(IntPtrTy, AllocTy: ScanVarsType[i], AllocSize: Allocsize,
                                         ArraySize: AllocSpan, MallocF: nullptr, Name: "arr");
      Builder.CreateStore(Val: Buff, Ptr: (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]]);
    }
    return Error::success();
  };
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Restrict the allocation to thread 0 via masked(filter=0).
  Builder.SetInsertPoint(ScanRedInfo->OMPScanInit->getTerminator());
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  BasicBlock *InputBB = Builder.GetInsertBlock();
  // If the block is already terminated, emit the barrier before the
  // terminator rather than after it.
  if (InputBB->getTerminator())
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);

  return Error::success();
}
4924
/// Finalize a scan-based directive: copy the last buffer element (at index
/// span) into each original reduction variable and free the scan buffers.
/// Runs in a masked(filter=0) region followed by a barrier, mirroring the
/// allocation in emitScanBasedDirectiveDeclsIR.
Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
    ArrayRef<ReductionInfo> ReductionInfos, ScanInfo *ScanRedInfo) {
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    for (ReductionInfo RedInfo : ReductionInfos) {
      Value *PrivateVar = RedInfo.PrivateVariable;
      Value *OrigVar = RedInfo.Variable;
      // Scan buffers are keyed by the private copy (see createScan).
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[PrivateVar];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);

      Type *SrcTy = RedInfo.ElementType;
      // Element at index `span` holds the final reduction result.
      Value *Val = Builder.CreateInBoundsGEP(Ty: SrcTy, Ptr: Buff, IdxList: ScanRedInfo->Span,
                                             Name: "arrayOffset");
      Value *Src = Builder.CreateLoad(Ty: SrcTy, Ptr: Val);

      Builder.CreateStore(Val: Src, Ptr: OrigVar);
      Builder.CreateFree(Source: Buff);
    }
    return Error::success();
  };
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // Emit before an existing terminator if there is one.
  if (ScanRedInfo->OMPScanFinish->getTerminator())
    Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish->getTerminator());
  else
    Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish);

  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  BasicBlock *InputBB = Builder.GetInsertBlock();
  if (InputBB->getTerminator())
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  // Barrier so no thread reads a freed buffer.
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  return Error::success();
}
4971
/// Emit the in-place parallel prefix (inclusive scan) over the shared scan
/// buffers, performed by the primary thread inside a masked(filter=0) region
/// and followed by a barrier.
///
/// The generated IR implements the classic doubling algorithm:
///   for (k = 0; k != ceil(log2(n)); ++k)       // outer loop, pow2k = 1<<k
///     for (i = n-1; i >= pow2k; --i)           // inner loop, descending
///       buffer[i+1] op= buffer[i+1 - pow2k];
/// Afterwards emitScanBasedDirectiveFinalsIR copies out the results and
/// frees the buffers.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
    const LocationDescription &Loc,
    ArrayRef<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    ScanInfo *ScanRedInfo) {

  if (!updateToLocation(Loc))
    return Loc.IP;
  auto BodyGenCB = [&](InsertPointTy AllocaIP,
                       InsertPointTy CodeGenIP) -> Error {
    Builder.restoreIP(IP: CodeGenIP);
    Function *CurFn = Builder.GetInsertBlock()->getParent();
    // for (int k = 0; k <= ceil(log2(n)); ++k)
    llvm::BasicBlock *LoopBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.outer.log.scan.body");
    llvm::BasicBlock *ExitBB =
        splitBB(Builder, CreateBranch: false, Name: "omp.outer.log.scan.exit");
    // Compute the trip count ceil(log2(span)) via double-precision
    // log2 + ceil intrinsics, then truncate back to i32.
    llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
        M: Builder.GetInsertBlock()->getModule(),
        id: (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Tys: Builder.getDoubleTy());
    llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
    llvm::Value *Arg =
        Builder.CreateUIToFP(V: ScanRedInfo->Span, DestTy: Builder.getDoubleTy());
    llvm::Value *LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: Arg, Name: "");
    F = llvm::Intrinsic::getOrInsertDeclaration(
        M: Builder.GetInsertBlock()->getModule(),
        id: (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Tys: Builder.getDoubleTy());
    LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: LogVal, Name: "");
    LogVal = Builder.CreateFPToUI(V: LogVal, DestTy: Builder.getInt32Ty());
    llvm::Value *NMin1 = Builder.CreateNUWSub(
        LHS: ScanRedInfo->Span,
        RHS: llvm::ConstantInt::get(Ty: ScanRedInfo->Span->getType(), V: 1));
    Builder.SetInsertPoint(InputBB);
    Builder.CreateBr(Dest: LoopBB);
    emitBlock(BB: LoopBB, CurFn);
    Builder.SetInsertPoint(LoopBB);

    // Outer-loop PHIs: iteration counter k and stride pow2k = 2^k.
    PHINode *Counter = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    // size pow2k = 1;
    PHINode *Pow2K = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    Counter->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
                         BB: InputBB);
    Pow2K->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1),
                       BB: InputBB);
    // for (size i = n - 1; i >= 2 ^ k; --i)
    //   tmp[i] op= tmp[i-pow2k];
    llvm::BasicBlock *InnerLoopBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.body");
    llvm::BasicBlock *InnerExitBB =
        BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.exit");
    llvm::Value *CmpI = Builder.CreateICmpUGE(LHS: NMin1, RHS: Pow2K);
    Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
    emitBlock(BB: InnerLoopBB, CurFn);
    Builder.SetInsertPoint(InnerLoopBB);
    PHINode *IVal = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
    IVal->addIncoming(V: NMin1, BB: LoopBB);
    for (ReductionInfo RedInfo : ReductionInfos) {
      Value *ReductionVal = RedInfo.PrivateVariable;
      Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ReductionVal];
      Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
      Type *DestTy = RedInfo.ElementType;
      // Combine buffer[i+1] with buffer[i+1-pow2k], storing into the former.
      Value *IV = Builder.CreateAdd(LHS: IVal, RHS: Builder.getInt32(C: 1));
      Value *LHSPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
      Value *OffsetIval = Builder.CreateNUWSub(LHS: IV, RHS: Pow2K);
      Value *RHSPtr =
          Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: OffsetIval, Name: "arrayOffset");
      Value *LHS = Builder.CreateLoad(Ty: DestTy, Ptr: LHSPtr);
      Value *RHS = Builder.CreateLoad(Ty: DestTy, Ptr: RHSPtr);
      llvm::Value *Result;
      InsertPointOrErrorTy AfterIP =
          RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
      if (!AfterIP)
        return AfterIP.takeError();
      // NOTE(review): unlike other ReductionGen call sites in this file,
      // *AfterIP is not restored into the builder before the store below —
      // verify the callback leaves the insert point where emission continues.
      Builder.CreateStore(Val: Result, Ptr: LHSPtr);
    }
    llvm::Value *NextIVal = Builder.CreateNUWSub(
        LHS: IVal, RHS: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1));
    IVal->addIncoming(V: NextIVal, BB: Builder.GetInsertBlock());
    CmpI = Builder.CreateICmpUGE(LHS: NextIVal, RHS: Pow2K);
    Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
    emitBlock(BB: InnerExitBB, CurFn);
    llvm::Value *Next = Builder.CreateNUWAdd(
        LHS: Counter, RHS: llvm::ConstantInt::get(Ty: Counter->getType(), V: 1));
    Counter->addIncoming(V: Next, BB: Builder.GetInsertBlock());
    // pow2k <<= 1;
    llvm::Value *NextPow2K = Builder.CreateShl(LHS: Pow2K, RHS: 1, Name: "", /*HasNUW=*/true);
    Pow2K->addIncoming(V: NextPow2K, BB: Builder.GetInsertBlock());
    llvm::Value *Cmp = Builder.CreateICmpNE(LHS: Next, RHS: LogVal);
    Builder.CreateCondBr(Cond: Cmp, True: LoopBB, False: ExitBB);
    Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
    return Error::success();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };

  // The scan computation runs on thread 0 only.
  llvm::Value *FilterVal = Builder.getInt32(C: 0);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);

  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  // Copy out the final values and release the buffers.
  Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos, ScanRedInfo);
  if (Err)
    return Err;

  return AfterIP;
}
5087
5088Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
5089 llvm::function_ref<Error()> InputLoopGen,
5090 llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen,
5091 ScanInfo *ScanRedInfo) {
5092
5093 {
5094 // Emit loop with input phase:
5095 // for (i: 0..<num_iters>) {
5096 // <input phase>;
5097 // buffer[i] = red;
5098 // }
5099 ScanRedInfo->OMPFirstScanLoop = true;
5100 Error Err = InputLoopGen();
5101 if (Err)
5102 return Err;
5103 }
5104 {
5105 // Emit loop with scan phase:
5106 // for (i: 0..<num_iters>) {
5107 // red = buffer[i];
5108 // <scan phase>;
5109 // }
5110 ScanRedInfo->OMPFirstScanLoop = false;
5111 Error Err = ScanLoopGen(Builder.saveIP());
5112 if (Err)
5113 return Err;
5114 }
5115 return Error::success();
5116}
5117
5118void OpenMPIRBuilder::createScanBBs(ScanInfo *ScanRedInfo) {
5119 Function *Fun = Builder.GetInsertBlock()->getParent();
5120 ScanRedInfo->OMPScanDispatch =
5121 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.inscan.dispatch");
5122 ScanRedInfo->OMPAfterScanBlock =
5123 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.after.scan.bb");
5124 ScanRedInfo->OMPBeforeScanBlock =
5125 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.before.scan.bb");
5126 ScanRedInfo->OMPScanLoopExit =
5127 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.scan.loop.exit");
5128}
/// Create the CFG skeleton of a canonical loop:
///
///   preheader -> header -> cond --true--> body -> latch -> header (backedge)
///                          cond --false-> exit -> after
///
/// The induction variable is a PHI of \p TripCount's type counting from 0
/// while it is unsigned-less-than \p TripCount, incremented by 1 in the
/// latch. Pre-body blocks are inserted before \p PreInsertBefore and
/// post-body blocks before \p PostInsertBefore; the caller is responsible
/// for connecting the preheader and after blocks to the surrounding CFG.
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
    DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
    BasicBlock *PostInsertBefore, const Twine &Name) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  Type *IndVarTy = TripCount->getType();

  // Create the basic block structure.
  BasicBlock *Preheader =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".preheader", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Header =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".header", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Cond =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".cond", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Body =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".body", Parent: F, InsertBefore: PreInsertBefore);
  BasicBlock *Latch =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".inc", Parent: F, InsertBefore: PostInsertBefore);
  BasicBlock *Exit =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".exit", Parent: F, InsertBefore: PostInsertBefore);
  BasicBlock *After =
      BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".after", Parent: F, InsertBefore: PostInsertBefore);

  // Use specified DebugLoc for new instructions.
  Builder.SetCurrentDebugLocation(DL);

  Builder.SetInsertPoint(Preheader);
  Builder.CreateBr(Dest: Header);

  // Header: materialize the induction variable, entering at 0 from the
  // preheader; the latch incoming value is added below.
  Builder.SetInsertPoint(Header);
  PHINode *IndVarPHI = Builder.CreatePHI(Ty: IndVarTy, NumReservedValues: 2, Name: "omp_" + Name + ".iv");
  IndVarPHI->addIncoming(V: ConstantInt::get(Ty: IndVarTy, V: 0), BB: Preheader);
  Builder.CreateBr(Dest: Cond);

  // Cond: keep iterating while IV u< TripCount.
  Builder.SetInsertPoint(Cond);
  Value *Cmp =
      Builder.CreateICmpULT(LHS: IndVarPHI, RHS: TripCount, Name: "omp_" + Name + ".cmp");
  Builder.CreateCondBr(Cond: Cmp, True: Body, False: Exit);

  Builder.SetInsertPoint(Body);
  Builder.CreateBr(Dest: Latch);

  // Latch: increment the IV; NUW holds since the IV never exceeds TripCount.
  Builder.SetInsertPoint(Latch);
  Value *Next = Builder.CreateAdd(LHS: IndVarPHI, RHS: ConstantInt::get(Ty: IndVarTy, V: 1),
                                  Name: "omp_" + Name + ".next", /*HasNUW=*/true);
  Builder.CreateBr(Dest: Header);
  IndVarPHI->addIncoming(V: Next, BB: Latch);

  Builder.SetInsertPoint(Exit);
  Builder.CreateBr(Dest: After);

  // Remember and return the canonical control flow.
  LoopInfos.emplace_front();
  CanonicalLoopInfo *CL = &LoopInfos.front();

  CL->Header = Header;
  CL->Cond = Cond;
  CL->Latch = Latch;
  CL->Exit = Exit;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
5194
/// Create a canonical loop executing \p TripCount iterations at \p Loc. The
/// loop skeleton is built right after the insert block; if \p Loc is set,
/// everything following the insert point is split off into the loop's
/// "after" block and the insert block branches to the preheader. \p
/// BodyGenCB is then invoked with the body insert point and the induction
/// variable.
Expected<CanonicalLoopInfo *>
OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
                                     LoopBodyGenCallbackTy BodyGenCB,
                                     Value *TripCount, const Twine &Name) {
  BasicBlock *BB = Loc.IP.getBlock();
  BasicBlock *NextBB = BB->getNextNode();

  CanonicalLoopInfo *CL = createLoopSkeleton(DL: Loc.DL, TripCount, F: BB->getParent(),
                                             PreInsertBefore: NextBB, PostInsertBefore: NextBB, Name);
  BasicBlock *After = CL->getAfter();

  // If location is not set, don't connect the loop.
  if (updateToLocation(Loc)) {
    // Split the loop at the insertion point: Branch to the preheader and move
    // every following instruction to after the loop (the After BB). Also, the
    // new successor is the loop's after block.
    spliceBB(Builder, New: After, /*CreateBranch=*/false);
    Builder.CreateBr(Dest: CL->getPreheader());
  }

  // Emit the body content. We do it after connecting the loop to the CFG to
  // avoid that the callback encounters degenerate BBs.
  if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
    return Err;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
5225
5226Expected<ScanInfo *> OpenMPIRBuilder::scanInfoInitialize() {
5227 ScanInfos.emplace_front();
5228 ScanInfo *Result = &ScanInfos.front();
5229 return Result;
5230}
5231
/// Create the pair of canonical loops (input phase, then scan phase) needed
/// to lower a worksharing loop with an inscan reduction. Both loops share
/// the same bounds; the trip count is recorded in \p ScanRedInfo->Span and a
/// shared body generator reroutes each loop body through the scan dispatch /
/// before-scan / scan-loop-exit blocks created by createScanBBs. On success,
/// the result holds the input-phase loop followed by the scan-phase loop.
Expected<SmallVector<llvm::CanonicalLoopInfo *>>
OpenMPIRBuilder::createCanonicalScanLoops(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name, ScanInfo *ScanRedInfo) {
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
  updateToLocation(Loc: ComputeLoc);

  SmallVector<CanonicalLoopInfo *> Result;

  Value *TripCount = calculateCanonicalLoopTripCount(
      Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
  ScanRedInfo->Span = TripCount;
  ScanRedInfo->OMPScanInit = splitBB(Builder, CreateBranch: true, Name: "scan.init");
  Builder.SetInsertPoint(ScanRedInfo->OMPScanInit);

  // Shared body generator for both loops: divert the body's single
  // fallthrough edge through the scan dispatch block, then emit the
  // before-scan and scan-loop-exit blocks, and finally invoke the user
  // callback at the start of the before-scan block.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(IP: CodeGenIP);
    ScanRedInfo->IV = IV;
    createScanBBs(ScanRedInfo);
    BasicBlock *InputBlock = Builder.GetInsertBlock();
    Instruction *Terminator = InputBlock->getTerminator();
    assert(Terminator->getNumSuccessors() == 1);
    BasicBlock *ContinueBlock = Terminator->getSuccessor(Idx: 0);
    Terminator->setSuccessor(Idx: 0, BB: ScanRedInfo->OMPScanDispatch);
    emitBlock(BB: ScanRedInfo->OMPBeforeScanBlock,
              CurFn: Builder.GetInsertBlock()->getParent());
    Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
    emitBlock(BB: ScanRedInfo->OMPScanLoopExit,
              CurFn: Builder.GetInsertBlock()->getParent());
    Builder.CreateBr(Dest: ContinueBlock);
    Builder.SetInsertPoint(
        ScanRedInfo->OMPBeforeScanBlock->getFirstInsertionPt());
    return BodyGenCB(Builder.saveIP(), IV);
  };

  // Generator for the first (input-phase) loop.
  const auto &&InputLoopGen = [&]() -> Error {
    Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
        Loc: Builder.saveIP(), BodyGenCB: BodyGen, Start, Stop, Step, IsSigned, InclusiveStop,
        ComputeIP, Name, InScan: true, ScanRedInfo);
    if (!LoopInfo)
      return LoopInfo.takeError();
    Result.push_back(Elt: *LoopInfo);
    Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
    return Error::success();
  };
  // Generator for the second (scan-phase) loop; additionally records the
  // block after the scan loop in OMPScanFinish.
  const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
    Expected<CanonicalLoopInfo *> LoopInfo =
        createCanonicalLoop(Loc, BodyGenCB: BodyGen, Start, Stop, Step, IsSigned,
                            InclusiveStop, ComputeIP, Name, InScan: true, ScanRedInfo);
    if (!LoopInfo)
      return LoopInfo.takeError();
    Result.push_back(Elt: *LoopInfo);
    Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
    ScanRedInfo->OMPScanFinish = Builder.GetInsertBlock();
    return Error::success();
  };
  Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen, ScanRedInfo);
  if (Err)
    return Err;
  return Result;
}
5295
/// Compute the number of iterations of a loop running from \p Start to
/// \p Stop (exclusive, or inclusive if \p InclusiveStop) with increment
/// \p Step, taking care to avoid overflow near the limits of the
/// induction-variable type. All three values must share the same integer
/// type; the returned trip count has that type as well.
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
    bool IsSigned, bool InclusiveStop, const Twine &Name) {

  // Consider the following difficulties (assuming 8-bit signed integers):
  // * Adding \p Step to the loop counter which passes \p Stop may overflow:
  //     DO I = 1, 100, 50
  // * A \p Step of INT_MIN cannot be normalized to a positive direction:
  //     DO I = 100, 0, -128

  // Start, Stop and Step must be of the same integer type.
  auto *IndVarTy = cast<IntegerType>(Val: Start->getType());
  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
  assert(IndVarTy == Step->getType() && "Step type mismatch");

  updateToLocation(Loc);

  ConstantInt *Zero = ConstantInt::get(Ty: IndVarTy, V: 0);
  ConstantInt *One = ConstantInt::get(Ty: IndVarTy, V: 1);

  // Like Step, but always positive.
  Value *Incr = Step;

  // Distance between Start and Stop; always positive.
  Value *Span;

  // Condition whether no iterations are executed at all, e.g. because
  // UB < LB.
  Value *ZeroCmp;

  if (IsSigned) {
    // Ensure that increment is positive. If not, negate and invert LB and UB.
    Value *IsNeg = Builder.CreateICmpSLT(LHS: Step, RHS: Zero);
    Incr = Builder.CreateSelect(C: IsNeg, True: Builder.CreateNeg(V: Step), False: Step);
    Value *LB = Builder.CreateSelect(C: IsNeg, True: Stop, False: Start);
    Value *UB = Builder.CreateSelect(C: IsNeg, True: Start, False: Stop);
    Span = Builder.CreateSub(LHS: UB, RHS: LB, Name: "", HasNUW: false, HasNSW: true);
    ZeroCmp = Builder.CreateICmp(
        P: InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, LHS: UB, RHS: LB);
  } else {
    // Unsigned ranges need no normalization; Stop - Start cannot wrap when
    // the loop executes at least one iteration (NUW).
    Span = Builder.CreateSub(LHS: Stop, RHS: Start, Name: "", HasNUW: true);
    ZeroCmp = Builder.CreateICmp(
        P: InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, LHS: Stop, RHS: Start);
  }

  Value *CountIfLooping;
  if (InclusiveStop) {
    CountIfLooping = Builder.CreateAdd(LHS: Builder.CreateUDiv(LHS: Span, RHS: Incr), RHS: One);
  } else {
    // Avoid incrementing past stop since it could overflow.
    Value *CountIfTwo = Builder.CreateAdd(
        LHS: Builder.CreateUDiv(LHS: Builder.CreateSub(LHS: Span, RHS: One), RHS: Incr), RHS: One);
    Value *OneCmp = Builder.CreateICmp(P: CmpInst::ICMP_ULE, LHS: Span, RHS: Incr);
    CountIfLooping = Builder.CreateSelect(C: OneCmp, True: One, False: CountIfTwo);
  }

  // A loop with no iterations gets trip count 0 regardless of the above.
  return Builder.CreateSelect(C: ZeroCmp, True: Zero, False: CountIfLooping,
                              Name: "omp_" + Name + ".tripcount");
}
5355
/// Create a canonical loop over the range from \p Start to \p Stop with
/// increment \p Step. The trip count is materialized at \p ComputeIP when it
/// is set (at \p Loc otherwise) and the body callback receives the
/// user-level induction variable Start + IV * Step. When \p InScan is set,
/// that value is also recorded in \p ScanRedInfo->IV.
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name, bool InScan,
    ScanInfo *ScanRedInfo) {
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

  Value *TripCount = calculateCanonicalLoopTripCount(
      Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

  // Map the logical (0-based) IV of the canonical loop back to the
  // user-level IV before invoking the user's body callback.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(IP: CodeGenIP);
    Value *Span = Builder.CreateMul(LHS: IV, RHS: Step);
    Value *IndVar = Builder.CreateAdd(LHS: Span, RHS: Start);
    if (InScan)
      ScanRedInfo->IV = IndVar;
    return BodyGenCB(Builder.saveIP(), IndVar);
  };
  LocationDescription LoopLoc =
      ComputeIP.isSet()
          ? Loc
          : LocationDescription(Builder.saveIP(),
                                Builder.getCurrentDebugLocation());
  return createCanonicalLoop(Loc: LoopLoc, BodyGenCB: BodyGen, TripCount, Name);
}
5382
5383// Returns an LLVM function to call for initializing loop bounds using OpenMP
5384// static scheduling for composite `distribute parallel for` depending on
5385// `type`. Only i32 and i64 are supported by the runtime. Always interpret
5386// integers as unsigned similarly to CanonicalLoopInfo.
5387static FunctionCallee
5388getKmpcDistForStaticInitForType(Type *Ty, Module &M,
5389 OpenMPIRBuilder &OMPBuilder) {
5390 unsigned Bitwidth = Ty->getIntegerBitWidth();
5391 if (Bitwidth == 32)
5392 return OMPBuilder.getOrCreateRuntimeFunction(
5393 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
5394 if (Bitwidth == 64)
5395 return OMPBuilder.getOrCreateRuntimeFunction(
5396 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
5397 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5398}
5399
5400// Returns an LLVM function to call for initializing loop bounds using OpenMP
5401// static scheduling depending on `type`. Only i32 and i64 are supported by the
5402// runtime. Always interpret integers as unsigned similarly to
5403// CanonicalLoopInfo.
5404static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
5405 OpenMPIRBuilder &OMPBuilder) {
5406 unsigned Bitwidth = Ty->getIntegerBitWidth();
5407 if (Bitwidth == 32)
5408 return OMPBuilder.getOrCreateRuntimeFunction(
5409 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
5410 if (Bitwidth == 64)
5411 return OMPBuilder.getOrCreateRuntimeFunction(
5412 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
5413 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5414}
5415
/// Lower \p CLI to a statically-scheduled workshare loop (no per-thread
/// chunking: the init call's chunk argument is 0): emit the runtime "init"
/// call in the preheader to compute this thread's bounds, rewrite the loop's
/// trip count and induction variable to iterate over them, and call
/// __kmpc_for_static_fini in the exit block. When \p NeedsBarrier is set, an
/// implicit 'for' barrier follows the loop. Returns the insertion point
/// after the loop; \p CLI is invalidated.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
    OMPScheduleType DistScheduleSchedType) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit =
      LoopType == WorksharingLoopType::DistributeForStaticLoop
          ? getKmpcDistForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this)
          : getKmpcForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());

  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(Ty: IVTy, V: 0);
  Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
  Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
  Value *UpperBound = Builder.CreateSub(LHS: CLI->getTripCount(), RHS: One);
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);

  // Distribute loops use the ordered-distribute schedule; all other
  // worksharing loops use unordered static.
  OMPScheduleType SchedType =
      (LoopType == WorksharingLoopType::DistributeStaticLoop)
          ? OMPScheduleType::OrderedDistribute
          : OMPScheduleType::UnorderedStatic;
  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced. Dist-for loops pass an extra dist-upper-bound
  // out-parameter; all variants end with {stride, incr=1, chunk=0}.
  auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
                        PUpperBound, IVTy, PStride, One, Zero, StaticInit,
                        this](Value *SchedulingType, auto &Builder) {
    SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
                                   PLowerBound, PUpperBound});
    if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
      Value *PDistUpperBound =
          Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
      Args.push_back(Elt: PDistUpperBound);
    }
    Args.append(IL: {PStride, One, Zero});
    createRuntimeFunctionCall(Callee: StaticInit, Args);
  };
  BuildInitCall(SchedulingType, Builder);
  if (HasDistSchedule &&
      LoopType != WorksharingLoopType::DistributeStaticLoop) {
    // NOTE(review): this local shadows the DistScheduleSchedType parameter,
    // which is otherwise unused in this function — confirm the hard-coded
    // OrderedDistribute is intended.
    Constant *DistScheduleSchedType = ConstantInt::get(
        Ty: I32Type, V: static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
    // We want to emit a second init function call for the dist_schedule clause
    // to the Distribute construct. This should only be done however if a
    // Workshare Loop is nested within a Distribute Construct
    BuildInitCall(DistScheduleSchedType, Builder);
  }
  // Recompute the trip count from the bounds the runtime handed back; the
  // upper bound is inclusive, hence the +1.
  Value *LowerBound = Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(LHS: InclusiveUpperBound, RHS: LowerBound);
  Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar(Updater: [&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(TheBB: CLI->getBody(),
                           IP: CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(LHS: OldIV, RHS: LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(TheBB: CLI->getExit(),
                         IP: CLI->getExit()->getTerminator()->getIterator());
  createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy BarrierIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
                      Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
5535
// Forward declarations of metadata helpers defined later in this file.
static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
                                   LoopInfo &LI);
static void addLoopMetadata(CanonicalLoopInfo *Loop,
                            ArrayRef<Metadata *> Properties);
5540
5541static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
5542 LLVMContext &Ctx, Loop *Loop,
5543 LoopInfo &LoopInfo,
5544 SmallVector<Metadata *> &LoopMDList) {
5545 SmallSet<BasicBlock *, 8> Reachable;
5546
5547 // Get the basic blocks from the loop in which memref instructions
5548 // can be found.
5549 // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
5550 // preferably without running any passes.
5551 for (BasicBlock *Block : Loop->getBlocks()) {
5552 if (Block == CLI->getCond() || Block == CLI->getHeader())
5553 continue;
5554 Reachable.insert(Ptr: Block);
5555 }
5556
5557 // Add access group metadata to memory-access instructions.
5558 MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
5559 for (BasicBlock *BB : Reachable)
5560 addAccessGroupMetadata(Block: BB, AccessGroup, LI&: LoopInfo);
5561 // TODO: If the loop has existing parallel access metadata, have
5562 // to combine two lists.
5563 LoopMDList.push_back(Elt: MDNode::get(
5564 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
5565}
5566
/// Lower \p CLI to a chunked, statically-scheduled workshare loop. The
/// runtime "init" call yields the first chunk's bounds and the stride
/// between chunk starts; an outer "dispatch" loop is created to enumerate
/// the chunks and the original loop is rewired to iterate within the current
/// chunk. __kmpc_for_static_fini is called in the dispatch exit block,
/// followed by an optional 'for' barrier. Returns the insertion point after
/// the dispatch loop.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  // The runtime only provides i32/i64 entry points; widen narrower IV types.
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(C&: Ctx)
                                                        : Type::getInt64Ty(C&: Ctx);
  Type *I32Type = Type::getInt32Ty(C&: M.getContext());
  Constant *Zero = ConstantInt::get(Ty: InternalIVTy, V: 0);
  Constant *One = ConstantInt::get(Ty: InternalIVTy, V: 1);

  // Run loop analysis so parallel-accesses metadata can be attached to the
  // loop's memory operations.
  Function *F = CLI->getFunction();
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
  Loop *L = LI.getLoopFor(BB: CLI->getHeader());
  SmallVector<Metadata *> LoopMDList;
  if (ChunkSize || DistScheduleChunkSize)
    applyParallelAccessesMetadata(CLI, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
  addLoopMetadata(Loop: CLI, Properties: LoopMDList);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(Ty: InternalIVTy, M, OMPBuilder&: *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(IP: AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.upperbound");
  Value *PStride = Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.stride");
  CLI->setLastIter(PLastIter);

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
      V: ChunkSize ? ChunkSize : Zero, DestTy: InternalIVTy, Name: "chunksize");
  Value *CastedDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
      V: DistScheduleChunkSize ? DistScheduleChunkSize : Zero, DestTy: InternalIVTy,
      Name: "distschedulechunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(V: OrigTripCount, DestTy: InternalIVTy, Name: "tripcount");

  Constant *SchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
  Constant *DistSchedulingType =
      ConstantInt::get(Ty: I32Type, V: static_cast<int>(DistScheduleSchedType));
  // Store the inclusive bounds for the "init" call; guard against a zero
  // trip count, which would otherwise underflow the upper bound.
  Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(LHS: CastedTripCount, RHS: One);
  Value *IsTripCountZero = Builder.CreateICmpEQ(LHS: CastedTripCount, RHS: Zero);
  Value *UpperBound =
      Builder.CreateSelect(C: IsTripCountZero, True: Zero, False: OrigUpperBound);
  Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
  Builder.CreateStore(Val: One, Ptr: PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(Ident: SrcLoc);
  auto BuildInitCall = [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound,
                        PUpperBound, PStride, One,
                        this](Value *SchedulingType, Value *ChunkSize,
                              auto &Builder) {
    createRuntimeFunctionCall(
        Callee: StaticInit, Args: {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                     /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                     /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                     /*pstride=*/PStride, /*incr=*/One,
                     /*chunk=*/ChunkSize});
  };
  BuildInitCall(SchedulingType, CastedChunkSize, Builder);
  if (DistScheduleSchedType != OMPScheduleType::None &&
      SchedType != OMPScheduleType::OrderedDistributeChunked &&
      SchedType != OMPScheduleType::OrderedDistribute) {
    // We want to emit a second init function call for the dist_schedule clause
    // to the Distribute construct. This should only be done however if a
    // Workshare Loop is nested within a Distribute Construct
    BuildInitCall(DistSchedulingType, CastedDistScheduleChunkSize, Builder);
  }

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PLowerBound, Name: "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PUpperBound, Name: "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(LHS: FirstChunkStop, RHS: One);
  Value *ChunkRange =
      Builder.CreateSub(LHS: FirstChunkEnd, RHS: FirstChunkStart, Name: "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(Ty: InternalIVTy, Ptr: PStride, Name: "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, CreateBranch: true);
  // Set by the dispatch loop's body callback below.
  Value *DispatchCounter;

  // It is safe to assume this didn't return an error because the callback
  // passed into createCanonicalLoop is the only possible error source, and it
  // always returns success.
  CanonicalLoopInfo *DispatchCLI = cantFail(ValOrErr: createCanonicalLoop(
      Loc: {Builder.saveIP(), DL},
      BodyGenCB: [&](InsertPointTy BodyIP, Value *Counter) {
        DispatchCounter = Counter;
        return Error::success();
      },
      Start: FirstChunkStart, Stop: CastedTripCount, Step: NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      Name: "dispatch"));

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
  // not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(Source: DispatchAfter, Target: CLI->getAfter(), DL);
  redirectTo(Source: CLI->getExit(), Target: DispatchLatch, DL);
  redirectTo(Source: DispatchBody, Target: DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(IP: CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop: the full chunk range
  // unless this is the last chunk, which may be shorter.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(LHS: DispatchCounter, RHS: ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(LHS: ChunkEnd, RHS: CastedTripCount, Name: "omp_chunk.is_last");
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(LHS: CastedTripCount, RHS: DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      C: IsLastChunk, True: CountUntilOrigTripCount, False: ChunkRange, Name: "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(V: ChunkTripCount, DestTy: IVTy, Name: "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(V: DispatchCounter, DestTy: IVTy, Name: "omp_dispatch.iv.trunc");
  CLI->mapIndVar(Updater: [&](Instruction *) -> Value * {
    Builder.restoreIP(IP: CLI->getBodyIP());
    return Builder.CreateAdd(LHS: IV, RHS: BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(TheBB: DispatchExit, IP: DispatchExit->getFirstInsertionPt());
  createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), DL), Kind: OMPD_for,
                      /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
    if (!AfterIP)
      return AfterIP.takeError();
  }

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
}
5757
5758// Returns an LLVM function to call for executing an OpenMP static worksharing
5759// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
5760// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
5761static FunctionCallee
5762getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
5763 WorksharingLoopType LoopType) {
5764 unsigned Bitwidth = Ty->getIntegerBitWidth();
5765 Module &M = OMPBuilder->M;
5766 switch (LoopType) {
5767 case WorksharingLoopType::ForStaticLoop:
5768 if (Bitwidth == 32)
5769 return OMPBuilder->getOrCreateRuntimeFunction(
5770 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
5771 if (Bitwidth == 64)
5772 return OMPBuilder->getOrCreateRuntimeFunction(
5773 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
5774 break;
5775 case WorksharingLoopType::DistributeStaticLoop:
5776 if (Bitwidth == 32)
5777 return OMPBuilder->getOrCreateRuntimeFunction(
5778 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
5779 if (Bitwidth == 64)
5780 return OMPBuilder->getOrCreateRuntimeFunction(
5781 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
5782 break;
5783 case WorksharingLoopType::DistributeForStaticLoop:
5784 if (Bitwidth == 32)
5785 return OMPBuilder->getOrCreateRuntimeFunction(
5786 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
5787 if (Bitwidth == 64)
5788 return OMPBuilder->getOrCreateRuntimeFunction(
5789 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
5790 break;
5791 }
5792 if (Bitwidth != 32 && Bitwidth != 64) {
5793 llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
5794 }
5795 llvm_unreachable("Unknown type of OpenMP worksharing loop");
5796}
5797
// Inserts a call to the proper OpenMP Device RTL function which handles loop
// worksharing (__kmpc_{for,distribute,distribute_for}_static_loop_{4u,8u}).
// The call is emitted just before the last instruction of `InsertBlock`.
static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
                                          WorksharingLoopType LoopType,
                                          BasicBlock *InsertBlock, Value *Ident,
                                          Value *LoopBodyArg, Value *TripCount,
                                          Function &LoopBodyFn, bool NoLoop) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(Ty: TripCountTy, OMPBuilder, LoopType);
  // Arguments common to all variants: ident, outlined loop body and its
  // argument, and the trip count.
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Elt: Ident);
  RealArgs.push_back(Elt: &LoopBodyFn);
  RealArgs.push_back(Elt: LoopBodyArg);
  RealArgs.push_back(Elt: TripCount);
  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
    // The distribute-only variant takes two trailing zero constants and does
    // not query the thread count.
    RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
    RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
    Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
    OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
    return;
  }
  // The remaining variants additionally receive the number of threads.
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, FnID: omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
  Value *NumThreads = OMPBuilder->createRuntimeFunctionCall(Callee: RTLNumThreads, Args: {});

  RealArgs.push_back(
      Elt: Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: TripCountTy, Name: "num.threads.cast"));
  RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
    // The combined distribute-for variant forwards the NoLoop flag.
    RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
    RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: NoLoop));
  } else {
    RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
  }

  OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
}
5839
// Post-outlining callback for device worksharing loops: once the loop body
// has been extracted into \p OutlinedFn, the original loop control flow is
// deleted and replaced by a single call into the OpenMP device runtime,
// which takes over the iteration scheduling.
static void workshareLoopTargetCallback(
    OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
    Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
    WorksharingLoopType LoopType, bool NoLoop) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup
  // of the loop body argument structure and the call to the outlined
  // loop body function. Firstly, we need to move the setup of the loop body
  // args into the loop preheader.
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop. We do not need it anymore.
  // That's why we make an unconditional branch from the loop preheader to the
  // loop exit block.
  Builder.restoreIP({Preheader, Preheader->end()});
  Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete the now-unreachable loop blocks between header and exit.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the instruction which corresponds to the loop body argument
  // structure and remove the call to the loop body function instruction.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed.
  // Operand 0 is the loop counter; operand 1 (if present) is the aggregate
  // argument struct forwarded to the runtime.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  // Emit the device-runtime worksharing call that replaces the loop.
  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, TripCount, OutlinedFn, NoLoop);

  // Erase the scaffolding instructions recorded during preparation.
  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
5896
// Device-side lowering of a worksharing loop: prepares the canonical loop's
// body for extraction by CodeExtractor and registers a post-outline callback
// (workshareLoopTargetCallback) that replaces the loop with an OpenMP device
// runtime call. Returns the insertion point after the loop.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    WorksharingLoopType LoopType, bool NoLoop) {
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  OutlineInfo OI;
  OI.OuterAllocaBB = CLI->getPreheader();
  Function *OuterFn = CLI->getPreheader()->getParent();

  // Instructions which need to be deleted at the end of code generation.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Note: overrides the preheader assignment above with the caller-provided
  // alloca block.
  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Mark the loop body as the region which needs to be extracted.
  OI.EntryBB = CLI->getBody();
  OI.ExitBB = CLI->getLatch()->splitBasicBlockBefore(CLI->getLatch()->begin(),
                                                     "omp.prelatch");

  // Prepare loop body for extraction.
  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

  // Insert a new loop counter variable which will be used only in the loop
  // body; the load becomes the explicit counter argument of the outlined
  // function.
  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
  Instruction *NewLoopCntLoad =
      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
  // New loop counter instructions are redundant in the loop preheader when
  // code generation for the workshare loop is finished. That's why we mark
  // them as ready for deletion.
  ToBeDeleted.push_back(NewLoopCntLoad);
  ToBeDeleted.push_back(NewLoopCnt);

  // Analyse the loop body region. Find all input variables which are used
  // inside the loop body region.
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks,
                          /* DominatorTree */ nullptr,
                          /* AggregateArgs */ true,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ CLI->getPreheader(),
                          /* Suffix */ ".omp_wsloop",
                          /* AggrArgsIn0AddrSpace */ true);

  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> SinkingCands, HoistingCands;

  // Find allocas outside the loop body region which are used inside the loop
  // body.
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

  // We need to model the loop body region as the function f(cnt, loop_arg).
  // That's why we replace the loop induction variable by the new counter
  // which will be one of the loop body function's arguments.
  SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
                            CLI->getIndVar()->user_end());
  for (auto Use : Users) {
    if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
      if (ParallelRegionBlockSet.count(Inst->getParent())) {
        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
      }
    }
  }
  // Make sure that the loop counter variable is not merged into the loop body
  // function argument structure and that it is passed as a separate variable.
  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);

  // The PostOutline CB is invoked when the loop body function is outlined and
  // the loop body is replaced by a call to the outlined function. We need to
  // add a call to the OpenMP device rtl inside the loop preheader. The OpenMP
  // device rtl function will handle the loop control logic.
  //
  OI.PostOutlineCB = [=, ToBeDeletedVec =
                             std::move(ToBeDeleted)](Function &OutlinedFn) {
    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
                                LoopType, NoLoop);
  };
  addOutlineInfo(std::move(OI));
  return CLI->getAfterIP();
}
5987
// Top-level dispatch for lowering a worksharing loop: on device targets the
// whole job is delegated to applyWorkshareLoopTarget; on the host the clause
// combination is folded into an OMPScheduleType and dispatched to the
// static, static-chunked, or dynamic lowering routine.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
    bool HasSimdModifier, bool HasMonotonicModifier,
    bool HasNonmonotonicModifier, bool HasOrderedClause,
    WorksharingLoopType LoopType, bool NoLoop, bool HasDistSchedule,
    Value *DistScheduleChunkSize) {
  if (Config.isTargetDevice())
    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop);
  OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
      SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause, DistScheduleChunkSize);

  bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                   OMPScheduleType::ModifierOrdered;
  // When dist_schedule is present, pick the distribute schedule variant based
  // on whether a dist_schedule chunk size was given.
  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
  if (HasDistSchedule) {
    DistScheduleSchedType = DistScheduleChunkSize
                                ? OMPScheduleType::OrderedDistributeChunked
                                : OMPScheduleType::OrderedDistribute;
  }
  switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
  case OMPScheduleType::BaseStatic:
  case OMPScheduleType::BaseDistribute:
    // NOTE(review): the condition permits at most one of the two chunk sizes,
    // but the message reads like neither is allowed — confirm wording/intent.
    assert((!ChunkSize || !DistScheduleChunkSize) &&
           "No chunk size with static-chunked schedule");
    if (IsOrdered && !HasDistSchedule)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    if (DistScheduleChunkSize)
      return applyStaticChunkedWorkshareLoop(
          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
          DistScheduleChunkSize, DistScheduleSchedType);
    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
                                    HasDistSchedule);

  case OMPScheduleType::BaseStaticChunked:
  case OMPScheduleType::BaseDistributeChunked:
    if (IsOrdered && !HasDistSchedule)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticChunkedWorkshareLoop(
        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
        DistScheduleChunkSize, DistScheduleSchedType);

  case OMPScheduleType::BaseRuntime:
  case OMPScheduleType::BaseAuto:
  case OMPScheduleType::BaseGreedy:
  case OMPScheduleType::BaseBalanced:
  case OMPScheduleType::BaseSteal:
  case OMPScheduleType::BaseRuntimeSimd:
    // These schedules determine chunking at runtime, so a user-provided chunk
    // size must not reach this point.
    assert(!ChunkSize &&
           "schedule type does not support user-defined chunk sizes");
    [[fallthrough]];
  case OMPScheduleType::BaseGuidedSimd:
  case OMPScheduleType::BaseDynamicChunked:
  case OMPScheduleType::BaseGuidedChunked:
  case OMPScheduleType::BaseGuidedIterativeChunked:
  case OMPScheduleType::BaseGuidedAnalyticalChunked:
  case OMPScheduleType::BaseStaticBalancedChunked:
    return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                     NeedsBarrier, ChunkSize);

  default:
    llvm_unreachable("Unknown/unimplemented schedule kind");
  }
}
6057
6058/// Returns an LLVM function to call for initializing loop bounds using OpenMP
6059/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6060/// the runtime. Always interpret integers as unsigned similarly to
6061/// CanonicalLoopInfo.
6062static FunctionCallee
6063getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6064 unsigned Bitwidth = Ty->getIntegerBitWidth();
6065 if (Bitwidth == 32)
6066 return OMPBuilder.getOrCreateRuntimeFunction(
6067 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
6068 if (Bitwidth == 64)
6069 return OMPBuilder.getOrCreateRuntimeFunction(
6070 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
6071 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6072}
6073
6074/// Returns an LLVM function to call for updating the next loop using OpenMP
6075/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6076/// the runtime. Always interpret integers as unsigned similarly to
6077/// CanonicalLoopInfo.
6078static FunctionCallee
6079getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6080 unsigned Bitwidth = Ty->getIntegerBitWidth();
6081 if (Bitwidth == 32)
6082 return OMPBuilder.getOrCreateRuntimeFunction(
6083 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
6084 if (Bitwidth == 64)
6085 return OMPBuilder.getOrCreateRuntimeFunction(
6086 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
6087 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6088}
6089
6090/// Returns an LLVM function to call for finalizing the dynamic loop using
6091/// depending on `type`. Only i32 and i64 are supported by the runtime. Always
6092/// interpret integers as unsigned similarly to CanonicalLoopInfo.
6093static FunctionCallee
6094getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6095 unsigned Bitwidth = Ty->getIntegerBitWidth();
6096 if (Bitwidth == 32)
6097 return OMPBuilder.getOrCreateRuntimeFunction(
6098 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
6099 if (Bitwidth == 64)
6100 return OMPBuilder.getOrCreateRuntimeFunction(
6101 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
6102 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6103}
6104
// Lowers a canonical loop to a dynamically scheduled worksharing loop by
// wrapping it in an outer dispatch loop: __kmpc_dispatch_init sets up the
// schedule, then each __kmpc_dispatch_next call hands the thread one chunk
// [lb, ub] until no work remains. The CanonicalLoopInfo is invalidated.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                           InsertPointTy AllocaIP,
                                           OMPScheduleType SchedType,
                                           bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for the OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init"
  // function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
  CLI->setLastIter(PLastIter);

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Cache the loop structure before it gets rewired below.
  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  createRuntimeFunctionCall(DynamicInit, {SrcLoc, ThreadNum, SchedulingType,
                                          /* LowerBound */ One, UpperBound,
                                          /* step */ One, Chunk});

  // An outer loop around the existing one: each round asks the runtime for
  // the next chunk and runs the inner loop over it.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  // This needs to be 32-bit always, so can't use the IVTy Zero above.
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res = createRuntimeFunctionCall(
      DynamicNext,
      {SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound, PStride});
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  // "init" was given the 1-based inclusive range [1, TripCount]; subtract one
  // to translate the returned lower bound back to the 0-based canonical IV.
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change the PHI-node in the loop header to use the outer cond rather than
  // the preheader, and set the IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the pre-header to jump to the OuterCond.
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * jump to the outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to the outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Call the "fini" function if "ordered" is present in the wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    createRuntimeFunctionCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    InsertPointOrErrorTy BarrierIP =
        createBarrier(LocationDescription(Builder.saveIP(), DL),
                      omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!BarrierIP)
      return BarrierIP.takeError();
  }

  CLI->invalidate();
  return AfterIP;
}
6237
6238/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
6239/// after this \p OldTarget will be orphaned.
6240static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
6241 BasicBlock *NewTarget, DebugLoc DL) {
6242 for (BasicBlock *Pred : make_early_inc_range(Range: predecessors(BB: OldTarget)))
6243 redirectTo(Source: Pred, Target: NewTarget, DL);
6244}
6245
6246/// Determine which blocks in \p BBs are reachable from outside and remove the
6247/// ones that are not reachable from the function.
6248static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
6249 SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
6250 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
6251 for (Use &U : BB->uses()) {
6252 auto *UseInst = dyn_cast<Instruction>(Val: U.getUser());
6253 if (!UseInst)
6254 continue;
6255 if (BBsToErase.count(Ptr: UseInst->getParent()))
6256 continue;
6257 return true;
6258 }
6259 return false;
6260 };
6261
6262 while (BBsToErase.remove_if(P: HasRemainingUses)) {
6263 // Try again if anything was removed.
6264 }
6265
6266 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
6267 DeleteDeadBlocks(BBs: BBVec);
6268}
6269
// Collapses a perfect (or near-perfect) loop nest into a single canonical
// loop whose trip count is the product of the input trip counts. Original
// induction variables are recovered from the collapsed one via div/mod, and
// in-between code is sunk into the new body. All input CLIs are invalidated;
// the returned CLI describes the collapsed loop.
CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Setup the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count (product of all trip counts).
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateNUWMul(CollapsedTripCount,
                                              OrigTripCount);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following in direction of
  // the control flow, from the leading in-between code, the loop nest body,
  // the trailing in-between code, and rejoining the collapsed loop's latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of next edge.
  // If the ContinueBlock is set, continue with that block. If ContinuePred,
  // use its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(ContinueBlock, Dest, DL);
    else
      redirectAllPredecessorsTo(ContinuePred, Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than the original loop. More sophisticated schemes could keep track of
  // what the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
6397
6398std::vector<CanonicalLoopInfo *>
6399OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6400 ArrayRef<Value *> TileSizes) {
6401 assert(TileSizes.size() == Loops.size() &&
6402 "Must pass as many tile sizes as there are loops");
6403 int NumLoops = Loops.size();
6404 assert(NumLoops >= 1 && "At least one loop to tile required");
6405
6406 CanonicalLoopInfo *OutermostLoop = Loops.front();
6407 CanonicalLoopInfo *InnermostLoop = Loops.back();
6408 Function *F = OutermostLoop->getBody()->getParent();
6409 BasicBlock *InnerEnter = InnermostLoop->getBody();
6410 BasicBlock *InnerLatch = InnermostLoop->getLatch();
6411
6412 // Loop control blocks that may become orphaned later.
6413 SmallVector<BasicBlock *, 12> OldControlBBs;
6414 OldControlBBs.reserve(N: 6 * Loops.size());
6415 for (CanonicalLoopInfo *Loop : Loops)
6416 Loop->collectControlBlocks(BBs&: OldControlBBs);
6417
6418 // Collect original trip counts and induction variable to be accessible by
6419 // index. Also, the structure of the original loops is not preserved during
6420 // the construction of the tiled loops, so do it before we scavenge the BBs of
6421 // any original CanonicalLoopInfo.
6422 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
6423 for (CanonicalLoopInfo *L : Loops) {
6424 assert(L->isValid() && "All input loops must be valid canonical loops");
6425 OrigTripCounts.push_back(Elt: L->getTripCount());
6426 OrigIndVars.push_back(Elt: L->getIndVar());
6427 }
6428
6429 // Collect the code between loop headers. These may contain SSA definitions
6430 // that are used in the loop nest body. To be usable with in the innermost
6431 // body, these BasicBlocks will be sunk into the loop nest body. That is,
6432 // these instructions may be executed more often than before the tiling.
6433 // TODO: It would be sufficient to only sink them into body of the
6434 // corresponding tile loop.
6435 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
6436 for (int i = 0; i < NumLoops - 1; ++i) {
6437 CanonicalLoopInfo *Surrounding = Loops[i];
6438 CanonicalLoopInfo *Nested = Loops[i + 1];
6439
6440 BasicBlock *EnterBB = Surrounding->getBody();
6441 BasicBlock *ExitBB = Nested->getHeader();
6442 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
6443 }
6444
6445 // Compute the trip counts of the floor loops.
6446 Builder.SetCurrentDebugLocation(DL);
6447 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
6448 SmallVector<Value *, 4> FloorCompleteCount, FloorCount, FloorRems;
6449 for (int i = 0; i < NumLoops; ++i) {
6450 Value *TileSize = TileSizes[i];
6451 Value *OrigTripCount = OrigTripCounts[i];
6452 Type *IVType = OrigTripCount->getType();
6453
6454 Value *FloorCompleteTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
6455 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
6456
6457 // 0 if tripcount divides the tilesize, 1 otherwise.
6458 // 1 means we need an additional iteration for a partial tile.
6459 //
6460 // Unfortunately we cannot just use the roundup-formula
6461 // (tripcount + tilesize - 1)/tilesize
6462 // because the summation might overflow. We do not want introduce undefined
6463 // behavior when the untiled loop nest did not.
6464 Value *FloorTripOverflow =
6465 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
6466
6467 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
6468 Value *FloorTripCount =
6469 Builder.CreateAdd(LHS: FloorCompleteTripCount, RHS: FloorTripOverflow,
6470 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
6471
6472 // Remember some values for later use.
6473 FloorCompleteCount.push_back(Elt: FloorCompleteTripCount);
6474 FloorCount.push_back(Elt: FloorTripCount);
6475 FloorRems.push_back(Elt: FloorTripRem);
6476 }
6477
6478 // Generate the new loop nest, from the outermost to the innermost.
6479 std::vector<CanonicalLoopInfo *> Result;
6480 Result.reserve(n: NumLoops * 2);
6481
6482 // The basic block of the surrounding loop that enters the nest generated
6483 // loop.
6484 BasicBlock *Enter = OutermostLoop->getPreheader();
6485
6486 // The basic block of the surrounding loop where the inner code should
6487 // continue.
6488 BasicBlock *Continue = OutermostLoop->getAfter();
6489
6490 // Where the next loop basic block should be inserted.
6491 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
6492
6493 auto EmbeddNewLoop =
6494 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
6495 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
6496 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
6497 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
6498 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
6499 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
6500
6501 // Setup the position where the next embedded loop connects to this loop.
6502 Enter = EmbeddedLoop->getBody();
6503 Continue = EmbeddedLoop->getLatch();
6504 OutroInsertBefore = EmbeddedLoop->getLatch();
6505 return EmbeddedLoop;
6506 };
6507
6508 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
6509 const Twine &NameBase) {
6510 for (auto P : enumerate(First&: TripCounts)) {
6511 CanonicalLoopInfo *EmbeddedLoop =
6512 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
6513 Result.push_back(x: EmbeddedLoop);
6514 }
6515 };
6516
6517 EmbeddNewLoops(FloorCount, "floor");
6518
6519 // Within the innermost floor loop, emit the code that computes the tile
6520 // sizes.
6521 Builder.SetInsertPoint(Enter->getTerminator());
6522 SmallVector<Value *, 4> TileCounts;
6523 for (int i = 0; i < NumLoops; ++i) {
6524 CanonicalLoopInfo *FloorLoop = Result[i];
6525 Value *TileSize = TileSizes[i];
6526
6527 Value *FloorIsEpilogue =
6528 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorCompleteCount[i]);
6529 Value *TileTripCount =
6530 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
6531
6532 TileCounts.push_back(Elt: TileTripCount);
6533 }
6534
6535 // Create the tile loops.
6536 EmbeddNewLoops(TileCounts, "tile");
6537
6538 // Insert the inbetween code into the body.
6539 BasicBlock *BodyEnter = Enter;
6540 BasicBlock *BodyEntered = nullptr;
6541 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
6542 BasicBlock *EnterBB = P.first;
6543 BasicBlock *ExitBB = P.second;
6544
6545 if (BodyEnter)
6546 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
6547 else
6548 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
6549
6550 BodyEnter = nullptr;
6551 BodyEntered = ExitBB;
6552 }
6553
6554 // Append the original loop nest body into the generated loop nest body.
6555 if (BodyEnter)
6556 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
6557 else
6558 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
6559 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
6560
6561 // Replace the original induction variable with an induction variable computed
6562 // from the tile and floor induction variables.
6563 Builder.restoreIP(IP: Result.back()->getBodyIP());
6564 for (int i = 0; i < NumLoops; ++i) {
6565 CanonicalLoopInfo *FloorLoop = Result[i];
6566 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
6567 Value *OrigIndVar = OrigIndVars[i];
6568 Value *Size = TileSizes[i];
6569
6570 Value *Scale =
6571 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6572 Value *Shift =
6573 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6574 OrigIndVar->replaceAllUsesWith(V: Shift);
6575 }
6576
6577 // Remove unused parts of the original loops.
6578 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6579
6580 for (CanonicalLoopInfo *L : Loops)
6581 L->invalidate();
6582
6583#ifndef NDEBUG
6584 for (CanonicalLoopInfo *GenL : Result)
6585 GenL->assertOK();
6586#endif
6587 return Result;
6588}
6589
/// Attach metadata \p Properties to the basic block described by \p BB. If the
/// basic block already has metadata, the basic block properties are appended.
static void addBasicBlockMetadata(BasicBlock *BB,
                                  ArrayRef<Metadata *> Properties) {
  // Nothing to do if no property to attach.
  if (Properties.empty())
    return;

  LLVMContext &Ctx = BB->getContext();
  SmallVector<Metadata *> NewProperties;
  // Operand 0 is reserved for the loop-ID self-reference; it is patched in
  // below once the node exists.
  NewProperties.push_back(Elt: nullptr);

  // If the basic block already has metadata, prepend it to the new metadata.
  // drop_begin skips the existing node's own self-reference operand.
  MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
  if (Existing)
    append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));

  append_range(C&: NewProperties, R&: Properties);
  // Loop metadata must be a distinct node whose first operand refers to the
  // node itself, so it must be created first and then self-patched.
  MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
  BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);

  BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
}
6613
6614/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
6615/// loop already has metadata, the loop properties are appended.
6616static void addLoopMetadata(CanonicalLoopInfo *Loop,
6617 ArrayRef<Metadata *> Properties) {
6618 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
6619
6620 // Attach metadata to the loop's latch
6621 BasicBlock *Latch = Loop->getLatch();
6622 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
6623 addBasicBlockMetadata(BB: Latch, Properties);
6624}
6625
6626/// Attach llvm.access.group metadata to the memref instructions of \p Block
6627static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
6628 LoopInfo &LI) {
6629 for (Instruction &I : *Block) {
6630 if (I.mayReadOrWriteMemory()) {
6631 // TODO: This instruction may already have access group from
6632 // other pragmas e.g. #pragma clang loop vectorize. Append
6633 // so that the existing metadata is not overwritten.
6634 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
6635 }
6636 }
6637}
6638
/// Fuse the given canonical loops into a single canonical loop.
///
/// The fused loop iterates from 0 to the maximum of the original trip counts.
/// Each original loop body is guarded by a comparison of the fused induction
/// variable against that loop's original trip count, so loops with differing
/// trip counts remain correct.
CanonicalLoopInfo *
OpenMPIRBuilder::fuseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops) {
  CanonicalLoopInfo *firstLoop = Loops.front();
  CanonicalLoopInfo *lastLoop = Loops.back();
  Function *F = firstLoop->getPreheader()->getParent();

  // Loop control blocks that will become orphaned later
  SmallVector<BasicBlock *> oldControlBBs;
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(BBs&: oldControlBBs);

  // Collect original trip counts
  SmallVector<Value *> origTripCounts;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    origTripCounts.push_back(Elt: L->getTripCount());
  }

  Builder.SetCurrentDebugLocation(DL);

  // Compute max trip count.
  // The fused loop will be from 0 to max(origTripCounts)
  BasicBlock *TCBlock = BasicBlock::Create(Context&: F->getContext(), Name: "omp.fuse.comp.tc",
                                           Parent: F, InsertBefore: firstLoop->getHeader());
  Builder.SetInsertPoint(TCBlock);
  Value *fusedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All loops to fuse must be valid canonical loops");
    Value *origTripCount = L->getTripCount();
    if (!fusedTripCount) {
      fusedTripCount = origTripCount;
      continue;
    }
    // NOTE(review): a signed compare is used to compute the maximum; assumes
    // trip counts never exceed the signed range of the IV type — confirm
    // against the (unsigned) canonical trip-count convention.
    Value *condTP = Builder.CreateICmpSGT(LHS: fusedTripCount, RHS: origTripCount);
    fusedTripCount = Builder.CreateSelect(C: condTP, True: fusedTripCount, False: origTripCount,
                                          Name: ".omp.fuse.tc");
  }

  // Generate new loop
  CanonicalLoopInfo *fused =
      createLoopSkeleton(DL, TripCount: fusedTripCount, F, PreInsertBefore: firstLoop->getBody(),
                         PostInsertBefore: lastLoop->getLatch(), Name: "fused");

  // Replace original loops with the fused loop
  // Preheader and After are not considered inside the CLI.
  // These are used to compute the individual TCs of the loops
  // so they have to be put before the resulting fused loop.
  // Moving them up for readability.
  for (size_t i = 0; i < Loops.size() - 1; ++i) {
    Loops[i]->getPreheader()->moveBefore(MovePos: TCBlock);
    Loops[i]->getAfter()->moveBefore(MovePos: TCBlock);
  }
  lastLoop->getPreheader()->moveBefore(MovePos: TCBlock);

  // Chain all preheader/after blocks in sequence, then fall through into the
  // trip-count computation block and from there into the fused loop.
  for (size_t i = 0; i < Loops.size() - 1; ++i) {
    redirectTo(Source: Loops[i]->getPreheader(), Target: Loops[i]->getAfter(), DL);
    redirectTo(Source: Loops[i]->getAfter(), Target: Loops[i + 1]->getPreheader(), DL);
  }
  redirectTo(Source: lastLoop->getPreheader(), Target: TCBlock, DL);
  redirectTo(Source: TCBlock, Target: fused->getPreheader(), DL);
  redirectTo(Source: fused->getAfter(), Target: lastLoop->getAfter(), DL);

  // Build the fused body
  // Create new Blocks with conditions that jump to the original loop bodies
  SmallVector<BasicBlock *> condBBs;
  SmallVector<Value *> condValues;
  for (size_t i = 0; i < Loops.size(); ++i) {
    BasicBlock *condBlock = BasicBlock::Create(
        Context&: F->getContext(), Name: "omp.fused.inner.cond", Parent: F, InsertBefore: Loops[i]->getBody());
    Builder.SetInsertPoint(condBlock);
    // Guard: only execute loop i's body while the fused IV is still below
    // loop i's original trip count.
    Value *condValue =
        Builder.CreateICmpSLT(LHS: fused->getIndVar(), RHS: origTripCounts[i]);
    condBBs.push_back(Elt: condBlock);
    condValues.push_back(Elt: condValue);
  }
  // Join the condition blocks with the bodies of the original loops
  redirectTo(Source: fused->getBody(), Target: condBBs[0], DL);
  for (size_t i = 0; i < Loops.size() - 1; ++i) {
    Builder.SetInsertPoint(condBBs[i]);
    Builder.CreateCondBr(Cond: condValues[i], True: Loops[i]->getBody(), False: condBBs[i + 1]);
    redirectAllPredecessorsTo(OldTarget: Loops[i]->getLatch(), NewTarget: condBBs[i + 1], DL);
    // Replace the IV with the fused IV
    Loops[i]->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());
  }
  // Last body jumps to the created end body block
  Builder.SetInsertPoint(condBBs.back());
  Builder.CreateCondBr(Cond: condValues.back(), True: lastLoop->getBody(),
                       False: fused->getLatch());
  redirectAllPredecessorsTo(OldTarget: lastLoop->getLatch(), NewTarget: fused->getLatch(), DL);
  // Replace the IV with the fused IV
  lastLoop->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());

  // The loop latch must have only one predecessor. Currently it is branched to
  // from both the last condition block and the last loop body
  fused->getLatch()->splitBasicBlockBefore(I: fused->getLatch()->begin(),
                                           BBName: "omp.fused.pre_latch");

  // Remove unused parts
  removeUnusedBlocksFromParent(BBs: oldControlBBs);

  // Invalidate old CLIs
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  fused->assertOK();
#endif
  return fused;
}
6748
6749void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
6750 LLVMContext &Ctx = Builder.getContext();
6751 addLoopMetadata(
6752 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6753 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
6754}
6755
6756void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
6757 LLVMContext &Ctx = Builder.getContext();
6758 addLoopMetadata(
6759 Loop, Properties: {
6760 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
6761 });
6762}
6763
6764void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
6765 Value *IfCond, ValueToValueMapTy &VMap,
6766 LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
6767 const Twine &NamePrefix) {
6768 Function *F = CanonicalLoop->getFunction();
6769
6770 // We can't do
6771 // if (cond) {
6772 // simd_loop;
6773 // } else {
6774 // non_simd_loop;
6775 // }
6776 // because then the CanonicalLoopInfo would only point to one of the loops:
6777 // leading to other constructs operating on the same loop to malfunction.
6778 // Instead generate
6779 // while (...) {
6780 // if (cond) {
6781 // simd_body;
6782 // } else {
6783 // not_simd_body;
6784 // }
6785 // }
6786 // At least for simple loops, LLVM seems able to hoist the if out of the loop
6787 // body at -O3
6788
6789 // Define where if branch should be inserted
6790 auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();
6791
6792 // Create additional blocks for the if statement
6793 BasicBlock *Cond = SplitBeforeIt->getParent();
6794 llvm::LLVMContext &C = Cond->getContext();
6795 llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
6796 Context&: C, Name: NamePrefix + ".if.then", Parent: Cond->getParent(), InsertBefore: Cond->getNextNode());
6797 llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
6798 Context&: C, Name: NamePrefix + ".if.else", Parent: Cond->getParent(), InsertBefore: CanonicalLoop->getExit());
6799
6800 // Create if condition branch.
6801 Builder.SetInsertPoint(SplitBeforeIt);
6802 Instruction *BrInstr =
6803 Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
6804 InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
6805 // Then block contains branch to omp loop body which needs to be vectorized
6806 spliceBB(IP, New: ThenBlock, CreateBranch: false, DL: Builder.getCurrentDebugLocation());
6807 ThenBlock->replaceSuccessorsPhiUsesWith(Old: Cond, New: ThenBlock);
6808
6809 Builder.SetInsertPoint(ElseBlock);
6810
6811 // Clone loop for the else branch
6812 SmallVector<BasicBlock *, 8> NewBlocks;
6813
6814 SmallVector<BasicBlock *, 8> ExistingBlocks;
6815 ExistingBlocks.reserve(N: L->getNumBlocks() + 1);
6816 ExistingBlocks.push_back(Elt: ThenBlock);
6817 ExistingBlocks.append(in_start: L->block_begin(), in_end: L->block_end());
6818 // Cond is the block that has the if clause condition
6819 // LoopCond is omp_loop.cond
6820 // LoopHeader is omp_loop.header
6821 BasicBlock *LoopCond = Cond->getUniquePredecessor();
6822 BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
6823 assert(LoopCond && LoopHeader && "Invalid loop structure");
6824 for (BasicBlock *Block : ExistingBlocks) {
6825 if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
6826 Block == LoopHeader || Block == LoopCond || Block == Cond) {
6827 continue;
6828 }
6829 BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);
6830
6831 // fix name not to be omp.if.then
6832 if (Block == ThenBlock)
6833 NewBB->setName(NamePrefix + ".if.else");
6834
6835 NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
6836 VMap[Block] = NewBB;
6837 NewBlocks.push_back(Elt: NewBB);
6838 }
6839 remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
6840 Builder.CreateBr(Dest: NewBlocks.front());
6841
6842 // The loop latch must have only one predecessor. Currently it is branched to
6843 // from both the 'then' and 'else' branches.
6844 L->getLoopLatch()->splitBasicBlockBefore(I: L->getLoopLatch()->begin(),
6845 BBName: NamePrefix + ".pre_latch");
6846
6847 // Ensure that the then block is added to the loop so we add the attributes in
6848 // the next step
6849 L->addBasicBlockToLoop(NewBB: ThenBlock, LI);
6850}
6851
6852unsigned
6853OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
6854 const StringMap<bool> &Features) {
6855 if (TargetTriple.isX86()) {
6856 if (Features.lookup(Key: "avx512f"))
6857 return 512;
6858 else if (Features.lookup(Key: "avx"))
6859 return 256;
6860 return 128;
6861 }
6862 if (TargetTriple.isPPC())
6863 return 128;
6864 if (TargetTriple.isWasm())
6865 return 128;
6866 return 0;
6867}
6868
/// Annotate \p CanonicalLoop for vectorization: emit alignment assumptions for
/// \p AlignedVars, optionally version the loop body on \p IfCond, and attach
/// the llvm.loop.vectorize / parallel-access metadata derived from \p Order,
/// \p Simdlen and \p Safelen.
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                MapVector<Value *, Value *> AlignedVars,
                                Value *IfCond, OrderKind Order,
                                ConstantInt *Simdlen, ConstantInt *Safelen) {
  LLVMContext &Ctx = Builder.getContext();

  Function *F = CanonicalLoop->getFunction();

  // TODO: We should not rely on pass manager. Currently we use pass manager
  // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
  // object. We should have a method which returns all blocks between
  // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });

  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);

  Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
  if (AlignedVars.size()) {
    InsertPointTy IP = Builder.saveIP();
    for (auto &AlignedItem : AlignedVars) {
      Value *AlignedPtr = AlignedItem.first;
      Value *Alignment = AlignedItem.second;
      // NOTE(review): the dyn_cast result is dereferenced without a null
      // check; this assumes every aligned pointer is defined by an
      // Instruction (not e.g. a function Argument) — confirm with callers.
      Instruction *loadInst = dyn_cast<Instruction>(Val: AlignedPtr);
      Builder.SetInsertPoint(loadInst->getNextNode());
      Builder.CreateAlignmentAssumption(DL: F->getDataLayout(), PtrValue: AlignedPtr,
                                        Alignment);
    }
    Builder.restoreIP(IP);
  }

  if (IfCond) {
    ValueToValueMapTy VMap;
    createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, NamePrefix: "simd");
  }

  SmallPtrSet<BasicBlock *, 8> Reachable;

  // Get the basic blocks from the loop in which memref instructions
  // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
  // preferably without running any passes.
  for (BasicBlock *Block : L->getBlocks()) {
    if (Block == CanonicalLoop->getCond() ||
        Block == CanonicalLoop->getHeader())
      continue;
    Reachable.insert(Ptr: Block);
  }

  SmallVector<Metadata *> LoopMDList;

  // In presence of finite 'safelen', it may be unsafe to mark all
  // the memory instructions parallel, because loop-carried
  // dependences of 'safelen' iterations are possible.
  // If clause order(concurrent) is specified then the memory instructions
  // are marked parallel even if 'safelen' is finite.
  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
    applyParallelAccessesMetadata(CLI: CanonicalLoop, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);

  // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
  // versions so we can't add the loop attributes in that case.
  if (IfCond) {
    // we can still add llvm.loop.parallel_access
    addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
    return;
  }

  // Use the above access group metadata to create loop level
  // metadata, which should be distinct for each loop.
  ConstantAsMetadata *BoolConst =
      ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
  LoopMDList.push_back(Elt: MDNode::get(
      Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));

  if (Simdlen || Safelen) {
    // If both simdlen and safelen clauses are specified, the value of the
    // simdlen parameter must be less than or equal to the value of the safelen
    // parameter. Therefore, use safelen only in the absence of simdlen.
    ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
    LoopMDList.push_back(
        Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
                          ConstantAsMetadata::get(C: VectorizeWidth)}));
  }

  addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
}
6958
/// Create the TargetMachine object to query the backend for optimization
/// preferences.
///
/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
/// needed for the LLVM pass pipeline. We use some default options to avoid
/// having to pass too many settings from the frontend that probably do not
/// matter.
///
/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
/// method. If we are going to use TargetMachine for more purposes, especially
/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
/// might become worth requiring front-ends to pass on their TargetMachine,
/// or at least cache it between methods. Note that while front-ends such as
/// Clang have just a single main TargetMachine per translation unit,
/// "target-cpu" and "target-features" that determine the TargetMachine are
/// per-function and can be overridden using
/// __attribute__((target("OPTIONS"))).
///
/// \returns the TargetMachine, or null if the target is not registered (the
/// caller must handle null).
static std::unique_ptr<TargetMachine>
createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
  Module *M = F->getParent();

  // CPU/features come from the per-function attributes (see note above).
  StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
  StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
  const llvm::Triple &Triple = M->getTargetTriple();

  std::string Error;
  const llvm::Target *TheTarget = TargetRegistry::lookupTarget(TheTriple: Triple, Error);
  // Not an error: the backend for this triple may simply not be linked in.
  if (!TheTarget)
    return {};

  llvm::TargetOptions Options;
  return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
      TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
      /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
}
6994
/// Heuristically determine the best-performant unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
///
/// \returns the suggested unroll factor; 1 means "do not unroll".
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Set up a minimal analysis manager providing just the analyses the unroll
  // heuristics below query.
  FunctionAnalysisManager FAM;
  FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
  TargetIRAnalysis TIRA;
  // If no TargetMachine is available (target not registered), fall back to the
  // default-constructed TargetIRAnalysis.
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass(PassBuilder: [&]() { return TIRA; });

  TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(F&: *F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(F&: *F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(BB: CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Query the same unrolling preferences the LoopUnrollPass would use.
  TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
      L, SE, TTI,
      /*BlockFrequencyInfo=*/BFI: nullptr,
      /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
      /*UserThreshold=*/std::nullopt,
      /*UserCount=*/std::nullopt,
      /*UserAllowPartial=*/true,
      /*UserAllowRuntime=*/UserRuntime: true,
      /*UserUpperBound=*/std::nullopt,
      /*UserFullUnrollMaxCount=*/std::nullopt);

  // Force the unrolling decision (see UnrollingPreferences::Force).
  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      // Only allocas in the entry block count as "stack variables" here.
      if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(Ptr: &I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // Loop is not unrollable if the loop contains certain instructions.
  if (!UCE.canUnroll()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
7125
/// Partially unroll \p Loop by \p Factor (0 = choose heuristically). If
/// \p UnrolledCLI is non-null, the unrolled loop must remain addressable by
/// other directives, so it is realized by tiling + full unrolling of the inner
/// loop, and *UnrolledCLI receives the resulting outer loop.
void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
                                        int32_t Factor,
                                        CanonicalLoopInfo **UnrolledCLI) {
  assert(Factor >= 0 && "Unroll factor must not be negative");

  Function *F = Loop->getFunction();
  LLVMContext &Ctx = F->getContext();

  // If the unrolled loop is not used for another loop-associated directive, it
  // is sufficient to add metadata for the LoopUnrollPass.
  if (!UnrolledCLI) {
    SmallVector<Metadata *, 2> LoopMetadata;
    LoopMetadata.push_back(
        Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));

    // With Factor == 0 only the enable property is emitted and the
    // LoopUnrollPass picks the count itself.
    if (Factor >= 1) {
      ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
          C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
      LoopMetadata.push_back(Elt: MDNode::get(
          Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
    }

    addLoopMetadata(Loop, Properties: LoopMetadata);
    return;
  }

  // Heuristically determine the unroll factor.
  if (Factor == 0)
    Factor = computeHeuristicUnrollFactor(CLI: Loop);

  // No change required with unroll factor 1.
  if (Factor == 1) {
    *UnrolledCLI = Loop;
    return;
  }

  assert(Factor >= 2 &&
         "unrolling only makes sense with a factor of 2 or larger");

  Type *IndVarTy = Loop->getIndVarType();

  // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
  // unroll the inner loop.
  Value *FactorVal =
      ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
                                       /*isSigned=*/false));
  std::vector<CanonicalLoopInfo *> LoopNest =
      tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
  assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
  *UnrolledCLI = LoopNest[0];
  CanonicalLoopInfo *InnerLoop = LoopNest[1];

  // LoopUnrollPass can only fully unroll loops with constant trip count.
  // Unroll by the unroll factor with a fallback epilog for the remainder
  // iterations if necessary.
  ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
      C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
  addLoopMetadata(
      Loop: InnerLoop,
      Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
       MDNode::get(
           Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});

#ifndef NDEBUG
  (*UnrolledCLI)->assertOK();
#endif
}
7193
7194OpenMPIRBuilder::InsertPointTy
7195OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
7196 llvm::Value *BufSize, llvm::Value *CpyBuf,
7197 llvm::Value *CpyFn, llvm::Value *DidIt) {
7198 if (!updateToLocation(Loc))
7199 return Loc.IP;
7200
7201 uint32_t SrcLocStrSize;
7202 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7203 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7204 Value *ThreadId = getOrCreateThreadID(Ident);
7205
7206 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
7207
7208 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
7209
7210 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
7211 createRuntimeFunctionCall(Callee: Fn, Args);
7212
7213 return Builder.saveIP();
7214}
7215
/// Emit an OpenMP 'single' region: only one thread executes \p BodyGenCB's
/// code, guarded by __kmpc_single/__kmpc_end_single. When copyprivate
/// variables (\p CPVars with their copy functions \p CPFuncs) are given, the
/// executing thread's values are broadcast via __kmpc_copyprivate; otherwise a
/// barrier is emitted unless \p IsNowait.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
    ArrayRef<llvm::Function *> CPFuncs) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // If needed allocate and initialize `DidIt` with 0.
  // DidIt: flag variable: 1=single thread; 0=not single thread.
  // NOTE(review): the alloca is emitted at the current insert point rather
  // than the function entry block — confirm this is intended.
  llvm::Value *DidIt = nullptr;
  if (!CPVars.empty()) {
    DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
    Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
  }

  Directive OMPD = Directive::OMPD_single;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
  Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
  Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);

  // Wrap the user finalization callback so the executing thread also records
  // itself in DidIt before leaving the region.
  auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
    if (Error Err = FiniCB(IP))
      return Err;

    // The thread that executes the single region must set `DidIt` to 1.
    // This is used by __kmpc_copyprivate, to know if the caller is the
    // single thread or not.
    if (DidIt)
      Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);

    return Error::success();
  };

  // generates the following:
  // if (__kmpc_single()) {
  //   .... single region ...
  //   __kmpc_end_single
  // }
  // __kmpc_copyprivate
  // __kmpc_barrier

  InsertPointOrErrorTy AfterIP =
      EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
                           /*Conditional*/ true,
                           /*hasFinalize*/ HasFinalize: true);
  if (!AfterIP)
    return AfterIP.takeError();

  if (DidIt) {
    for (size_t I = 0, E = CPVars.size(); I < E; ++I)
      // NOTE BufSize is currently unused, so just pass 0.
      createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                        /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
                        CpyFn: CPFuncs[I], DidIt);
    // NOTE __kmpc_copyprivate already inserts a barrier
  } else if (!IsNowait) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
                      Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!AfterIP)
      return AfterIP.takeError();
  }
  return Builder.saveIP();
}
7290
7291OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
7292 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7293 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
7294
7295 if (!updateToLocation(Loc))
7296 return Loc.IP;
7297
7298 Directive OMPD = Directive::OMPD_critical;
7299 uint32_t SrcLocStrSize;
7300 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7301 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7302 Value *ThreadId = getOrCreateThreadID(Ident);
7303 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
7304 Value *Args[] = {Ident, ThreadId, LockVar};
7305
7306 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
7307 Function *RTFn = nullptr;
7308 if (HintInst) {
7309 // Add Hint to entry Args and create call
7310 EnterArgs.push_back(Elt: HintInst);
7311 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
7312 } else {
7313 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
7314 }
7315 Instruction *EntryCall = createRuntimeFunctionCall(Callee: RTFn, Args: EnterArgs);
7316
7317 Function *ExitRTLFn =
7318 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
7319 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7320
7321 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7322 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7323}
7324
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
                                     InsertPointTy AllocaIP, unsigned NumLoops,
                                     ArrayRef<llvm::Value *> StoreValues,
                                     const Twine &Name, bool IsDependSource) {
  // The doacross runtime interface takes the dependence vector as i64
  // elements; reject anything else up front.
  assert(
      llvm::all_of(StoreValues,
                   [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
      "OpenMP runtime requires depend vec with i64 type");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Allocate space for vector and generate alloc instruction.
  auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
  Builder.restoreIP(IP: AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
  ArgsBase->setAlignment(Align(8));
  // Switch back from the alloca position to the directive's location for the
  // stores and the runtime call.
  updateToLocation(Loc);

  // Store the index value with offset in depend vector.
  for (unsigned I = 0; I < NumLoops; ++I) {
    Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
        Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
    StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
    STInst->setAlignment(Align(8));
  }

  // The runtime receives the address of the first vector element.
  Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
      Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};

  // depend(source) posts the iteration; depend(sink) waits on it.
  Function *RTLFn = nullptr;
  if (IsDependSource)
    RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
  createRuntimeFunctionCall(Callee: RTLFn, Args);

  return Builder.saveIP();
}
7371
7372OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
7373 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7374 FinalizeCallbackTy FiniCB, bool IsThreads) {
7375 if (!updateToLocation(Loc))
7376 return Loc.IP;
7377
7378 Directive OMPD = Directive::OMPD_ordered;
7379 Instruction *EntryCall = nullptr;
7380 Instruction *ExitCall = nullptr;
7381
7382 if (IsThreads) {
7383 uint32_t SrcLocStrSize;
7384 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7385 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7386 Value *ThreadId = getOrCreateThreadID(Ident);
7387 Value *Args[] = {Ident, ThreadId};
7388
7389 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
7390 EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7391
7392 Function *ExitRTLFn =
7393 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
7394 ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7395 }
7396
7397 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7398 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7399}
7400
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  // Push the finalization callback for this directive; it is popped again in
  // emitCommonDirectiveExit.
  if (HasFinalize)
    FinalizationStack.push_back(Elt: {FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  // If the entry block does not end in a branch, insert a temporary
  // unreachable to split at; it is erased again at the bottom.
  if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");

  // Emit the entry call; when Conditional is set this also builds the guard
  // branch around the body (see emitCommonDirectiveEntry).
  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // generate body
  if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
                            /* CodeGenIP */ Builder.saveIP()))
    return Err;

  // emit exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  InsertPointOrErrorTy AfterIP =
      emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  if (!AfterIP)
    return AfterIP.takeError();

  // If we are skipping the region of a non conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(BB: ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  // Drop the temporary unreachable inserted above, if we created one.
  if (!isa_and_nonnull<BranchInst>(Val: SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
7450
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // if nothing to do, Return current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  // A non-zero runtime return value means "this thread executes the body".
  Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
  auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
  // Temporary terminator; replaced below by the entry block's old branch.
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit thenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);

  // Move Entry branch to end of ThenBB, and replace with conditional
  // branch (If-stmt)
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(I: EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}
7480
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(IP: FinIP);

  // If there is finalization to do, emit it before the exit call
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    // Pop the callback pushed by EmitOMPInlinedRegion; it must belong to the
    // same directive.
    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    if (Error Err = Fi.mergeFiniBB(Builder, OtherFiniBB: FinIP.getBlock()))
      return std::move(Err);

    // Exit condition: insertion point is before the terminator of the new Fini
    // block
    Builder.SetInsertPoint(FinIP.getBlock()->getTerminator());
  }

  if (!ExitCall)
    return Builder.saveIP();

  // place the Exitcall as last instruction before Finalization block terminator
  ExitCall->removeFromParent();
  Builder.Insert(I: ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}
7513
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // creates the following CFG structure
  //   OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |     copin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //       |
  //       v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If entry block is terminated, split to preserve the branch to following
  // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
  if (isa_and_nonnull<BranchInst>(Val: OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
                                         BBName: "copyin.not.master.end");
    // Drop the unconditional branch created by the split; it is replaced by
    // the conditional branch emitted below.
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
  }

  // Compare the two addresses as integers; only when the private copy is a
  // distinct object from the master copy does the copy-in block run.
  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
  Builder.CreateCondBr(Cond: cmp, True: CopyBegin, False: CopyEnd);

  // Leave the caller positioned inside the copy block (before the optional
  // branch to the end block) so it can emit the actual copy code.
  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));

  return Builder.saveIP();
}
7563
7564CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
7565 Value *Size, Value *Allocator,
7566 std::string Name) {
7567 IRBuilder<>::InsertPointGuard IPG(Builder);
7568 updateToLocation(Loc);
7569
7570 uint32_t SrcLocStrSize;
7571 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7572 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7573 Value *ThreadId = getOrCreateThreadID(Ident);
7574 Value *Args[] = {ThreadId, Size, Allocator};
7575
7576 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
7577
7578 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7579}
7580
7581CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
7582 Value *Addr, Value *Allocator,
7583 std::string Name) {
7584 IRBuilder<>::InsertPointGuard IPG(Builder);
7585 updateToLocation(Loc);
7586
7587 uint32_t SrcLocStrSize;
7588 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7589 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7590 Value *ThreadId = getOrCreateThreadID(Ident);
7591 Value *Args[] = {ThreadId, Addr, Allocator};
7592 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
7593 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
7594}
7595
7596CallInst *OpenMPIRBuilder::createOMPInteropInit(
7597 const LocationDescription &Loc, Value *InteropVar,
7598 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
7599 Value *DependenceAddress, bool HaveNowaitClause) {
7600 IRBuilder<>::InsertPointGuard IPG(Builder);
7601 updateToLocation(Loc);
7602
7603 uint32_t SrcLocStrSize;
7604 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7605 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7606 Value *ThreadId = getOrCreateThreadID(Ident);
7607 if (Device == nullptr)
7608 Device = Constant::getAllOnesValue(Ty: Int32);
7609 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
7610 if (NumDependences == nullptr) {
7611 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7612 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7613 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7614 }
7615 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7616 Value *Args[] = {
7617 Ident, ThreadId, InteropVar, InteropTypeVal,
7618 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
7619
7620 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
7621
7622 return createRuntimeFunctionCall(Callee: Fn, Args);
7623}
7624
7625CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
7626 const LocationDescription &Loc, Value *InteropVar, Value *Device,
7627 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
7628 IRBuilder<>::InsertPointGuard IPG(Builder);
7629 updateToLocation(Loc);
7630
7631 uint32_t SrcLocStrSize;
7632 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7633 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7634 Value *ThreadId = getOrCreateThreadID(Ident);
7635 if (Device == nullptr)
7636 Device = Constant::getAllOnesValue(Ty: Int32);
7637 if (NumDependences == nullptr) {
7638 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7639 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7640 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7641 }
7642 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7643 Value *Args[] = {
7644 Ident, ThreadId, InteropVar, Device,
7645 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7646
7647 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
7648
7649 return createRuntimeFunctionCall(Callee: Fn, Args);
7650}
7651
7652CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
7653 Value *InteropVar, Value *Device,
7654 Value *NumDependences,
7655 Value *DependenceAddress,
7656 bool HaveNowaitClause) {
7657 IRBuilder<>::InsertPointGuard IPG(Builder);
7658 updateToLocation(Loc);
7659 uint32_t SrcLocStrSize;
7660 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7661 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7662 Value *ThreadId = getOrCreateThreadID(Ident);
7663 if (Device == nullptr)
7664 Device = Constant::getAllOnesValue(Ty: Int32);
7665 if (NumDependences == nullptr) {
7666 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
7667 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
7668 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
7669 }
7670 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
7671 Value *Args[] = {
7672 Ident, ThreadId, InteropVar, Device,
7673 NumDependences, DependenceAddress, HaveNowaitClauseVal};
7674
7675 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
7676
7677 return createRuntimeFunctionCall(Callee: Fn, Args);
7678}
7679
7680CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
7681 const LocationDescription &Loc, llvm::Value *Pointer,
7682 llvm::ConstantInt *Size, const llvm::Twine &Name) {
7683 IRBuilder<>::InsertPointGuard IPG(Builder);
7684 updateToLocation(Loc);
7685
7686 uint32_t SrcLocStrSize;
7687 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7688 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7689 Value *ThreadId = getOrCreateThreadID(Ident);
7690 Constant *ThreadPrivateCache =
7691 getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
7692 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
7693
7694 Function *Fn =
7695 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);
7696
7697 return createRuntimeFunctionCall(Callee: Fn, Args);
7698}
7699
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
    const LocationDescription &Loc,
    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
         "expected num_threads and num_teams to be specified");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Constants that feed the kernel's configuration environment. The generic
  // state machine is only requested when not executing in SPMD mode.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Constant *IsSPMDVal = ConstantInt::getSigned(Ty: Int8, V: Attrs.ExecFlags);
  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
      Ty: Int8, V: Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
  Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
  Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);

  Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
  Function *Kernel = DebugKernelWrapper;

  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix)) {
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
    Kernel = M.getFunction(Name: KernelName);
    assert(Kernel && "Expected the real kernel to exist");
  }

  // Manifest the launch configuration in the metadata matching the kernel
  // environment.
  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
    writeTeamsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinTeams, UB: Attrs.MaxTeams.front());

  // If MaxThreads not set, select the maximum between the default workgroup
  // size and the MinThreads value.
  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
  if (MaxThreadsVal < 0)
    MaxThreadsVal = std::max(
        a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), b: Attrs.MinThreads);

  if (MaxThreadsVal > 0)
    writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinThreads, UB: MaxThreadsVal);

  Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinThreads);
  Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
  Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinTeams);
  Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MaxTeams.front());
  Constant *ReductionDataSize =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionDataSize);
  Constant *ReductionBufferLength =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionBufferLength);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
  const DataLayout &DL = Fn->getDataLayout();

  // Per-kernel mutable "dynamic environment" global.
  Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
  Constant *DynamicEnvironmentInitializer =
      ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
  GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
      M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
      DynamicEnvironmentInitializer, DynamicEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  // Cast when the global ended up in a different address space than the
  // expected pointer type.
  Constant *DynamicEnvironment =
      DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
          ? DynamicEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
                                           Ty: DynamicEnvironmentPtr);

  // Note: members 7 and 8 (the reduction sizes) are patched in place later by
  // createTargetDeinit when teams reductions are present.
  Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
      T: ConfigurationEnvironment, V: {
          UseGenericStateMachineVal,
          MayUseNestedParallelismVal,
          IsSPMDVal,
          MinThreads,
          MaxThreads,
          MinTeams,
          MaxTeams,
          ReductionDataSize,
          ReductionBufferLength,
      });
  Constant *KernelEnvironmentInitializer = ConstantStruct::get(
      T: KernelEnvironment, V: {
          ConfigurationEnvironmentInitializer,
          Ident,
          DynamicEnvironment,
      });
  // The "<kernel>_kernel_environment" name is looked up again by
  // createTargetDeinit.
  std::string KernelEnvironmentName =
      (KernelName + "_kernel_environment").str();
  GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
      M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
      KernelEnvironmentInitializer, KernelEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *KernelEnvironment =
      KernelEnvironmentGV->getType() == KernelEnvironmentPtr
          ? KernelEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
                                           Ty: KernelEnvironmentPtr);
  // The launch environment arrives as the wrapper's first argument; cast it
  // to the runtime's expected parameter type if necessary.
  Value *KernelLaunchEnvironment = DebugKernelWrapper->getArg(i: 0);
  Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(i: 1);
  KernelLaunchEnvironment =
      KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
          ? KernelLaunchEnvironment
          : Builder.CreateAddrSpaceCast(V: KernelLaunchEnvironment,
                                        DestTy: KernelLaunchEnvParamTy);
  CallInst *ThreadKind = createRuntimeFunctionCall(
      Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      LHS: ThreadKind, RHS: Constant::getAllOnesValue(Ty: ThreadKind->getType()),
      Name: "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  // Split after a temporary unreachable, then rewrite the split's branch into
  // the conditional dispatch between user code and the worker exit.
  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
7846
void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
                                         int32_t TeamsReductionDataSize,
                                         int32_t TeamsReductionBufferLength) {
  if (!updateToLocation(Loc))
    return;

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

  createRuntimeFunctionCall(Callee: Fn, Args: {});

  // Without teams reductions there is nothing to patch into the kernel
  // environment.
  if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
    return;

  Function *Kernel = Builder.GetInsertBlock()->getParent();
  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix))
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
  // Find the "<kernel>_kernel_environment" global emitted by createTargetInit.
  auto *KernelEnvironmentGV =
      M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
  assert(KernelEnvironmentGV && "Expected kernel environment global\n");
  auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
  // Patch the configuration environment (member 0): indices {0,7} and {0,8}
  // select ReductionDataSize and ReductionBufferLength in the initializer
  // laid out by createTargetInit.
  auto *NewInitializer = ConstantFoldInsertValueInstruction(
      Agg: KernelEnvironmentInitializer,
      Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
  NewInitializer = ConstantFoldInsertValueInstruction(
      Agg: NewInitializer, Val: ConstantInt::get(Ty: Int32, V: TeamsReductionBufferLength),
      Idxs: {0, 8});
  KernelEnvironmentGV->setInitializer(NewInitializer);
}
7879
7880static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
7881 bool Min) {
7882 if (Kernel.hasFnAttribute(Kind: Name)) {
7883 int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Kind: Name);
7884 Value = Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value);
7885 }
7886 Kernel.addFnAttr(Kind: Name, Val: llvm::utostr(X: Value));
7887}
7888
std::pair<int32_t, int32_t>
OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
  // 0 when the attribute is absent.
  int32_t ThreadLimit =
      Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");

  if (T.isAMDGPU()) {
    // AMDGPU encodes both bounds in a single "LB,UB" string attribute.
    const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
    if (!Attr.isValid() || !Attr.isStringAttribute())
      return {0, ThreadLimit};
    auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
    int32_t LB, UB;
    if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
      return {0, ThreadLimit};
    // The OpenMP thread limit further clamps the backend upper bound.
    UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
    if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
      return {0, UB};
    return {LB, UB};
  }

  // NVPTX only carries an upper bound (maxntid); the lower bound is 0.
  if (Kernel.hasFnAttribute(Kind: "nvvm.maxntid")) {
    int32_t UB = Kernel.getFnAttributeAsParsedInteger(Kind: "nvvm.maxntid");
    return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
  }
  return {0, ThreadLimit};
}
7914
7915void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
7916 Function &Kernel, int32_t LB,
7917 int32_t UB) {
7918 Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));
7919
7920 if (T.isAMDGPU()) {
7921 Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
7922 Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
7923 return;
7924 }
7925
7926 updateNVPTXAttr(Kernel, Name: "nvvm.maxntid", Value: UB, Min: true);
7927}
7928
7929std::pair<int32_t, int32_t>
7930OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
7931 // TODO: Read from backend annotations if available.
7932 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
7933}
7934
7935void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
7936 int32_t LB, int32_t UB) {
7937 if (T.isNVPTX())
7938 if (UB > 0)
7939 Kernel.addFnAttr(Kind: "nvvm.maxclusterrank", Val: llvm::utostr(X: UB));
7940 if (T.isAMDGPU())
7941 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: LB) + ",1,1");
7942
7943 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
7944}
7945
7946void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
7947 Function *OutlinedFn) {
7948 if (Config.isTargetDevice()) {
7949 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
7950 // TODO: Determine if DSO local can be set to true.
7951 OutlinedFn->setDSOLocal(false);
7952 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
7953 if (T.isAMDGCN())
7954 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
7955 else if (T.isNVPTX())
7956 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
7957 else if (T.isSPIRV())
7958 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
7959 }
7960}
7961
7962Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
7963 StringRef EntryFnIDName) {
7964 if (Config.isTargetDevice()) {
7965 assert(OutlinedFn && "The outlined function must exist if embedded");
7966 return OutlinedFn;
7967 }
7968
7969 return new GlobalVariable(
7970 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
7971 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
7972}
7973
7974Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
7975 StringRef EntryFnName) {
7976 if (OutlinedFn)
7977 return OutlinedFn;
7978
7979 assert(!M.getGlobalVariable(EntryFnName, true) &&
7980 "Named kernel already exists?");
7981 return new GlobalVariable(
7982 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
7983 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
7984}
7985
Error OpenMPIRBuilder::emitTargetRegionFunction(
    TargetRegionEntryInfo &EntryInfo,
    FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
    Function *&OutlinedFn, Constant *&OutlinedFnID) {

  SmallString<64> EntryFnName;
  OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);

  // Generate the outlined function except on the host when offload is
  // mandatory, where only the region ID is needed.
  if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
    Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
    if (!CBResult)
      return CBResult.takeError();
    OutlinedFn = *CBResult;
  } else {
    OutlinedFn = nullptr;
  }

  // If this target outline function is not an offload entry, we don't need to
  // register it. This may be in the case of a false if clause, or if there are
  // no OpenMP targets.
  if (!IsOffloadEntry)
    return Error::success();

  // On the device the entry name itself acts as the ID; on the host a
  // separate platform-specific "region_id" symbol name is used.
  std::string EntryFnIDName =
      Config.isTargetDevice()
          ? std::string(EntryFnName)
          : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});

  OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
                                              EntryFnName, EntryFnIDName);
  return Error::success();
}
8018
8019Constant *OpenMPIRBuilder::registerTargetRegionFunction(
8020 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
8021 StringRef EntryFnName, StringRef EntryFnIDName) {
8022 if (OutlinedFn)
8023 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
8024 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
8025 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
8026 OffloadInfoManager.registerTargetRegionEntryInfo(
8027 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
8028 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
8029 return OutlinedFnID;
8030}
8031
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
    TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
    CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
    function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                      BodyGenTy BodyGenType)>
        BodyGenCB,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  // A null BodyGenCB means a "standalone" data construct (e.g. target
  // enter/exit data, target update): a single runtime call via MapperFunc
  // with no associated region body.
  bool IsStandAlone = !BodyGenCB;
  MapInfosTy *MapInfo;
  // Generate the code for the opening of the data environment. Capture all the
  // arguments of the runtime call by reference because they are used in the
  // closing of the region.
  auto BeginThenGen = [&](InsertPointTy AllocaIP,
                          InsertPointTy CodeGenIP) -> Error {
    // Materialize the offload_{baseptrs,ptrs,sizes,...} arrays from the map
    // clauses; MapInfo is captured by the closing lambda (EndThenGen) too.
    MapInfo = &GenMapInfoCB(Builder.saveIP());
    if (Error Err = emitOffloadingArrays(
            AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info, CustomMapperCB,
            /*IsNonContiguous=*/true, DeviceAddrCB))
      return Err;

    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    SmallVector<llvm::Value *, 13> OffloadingArgs = {
        SrcLocInfo,           DeviceID,
        PointerNum,           RTArgs.BasePointersArray,
        RTArgs.PointersArray, RTArgs.SizesArray,
        RTArgs.MapTypesArray, RTArgs.MapNamesArray,
        RTArgs.MappersArray};

    if (IsStandAlone) {
      assert(MapperFunc && "MapperFunc missing for standalone target data");

      auto TaskBodyCB = [&](Value *, Value *,
                            IRBuilderBase::InsertPoint) -> Error {
        // The nowait runtime variants take four additional trailing
        // arguments; pass null/zero placeholders for them.
        if (Info.HasNoWait) {
          OffloadingArgs.append(IL: {llvm::Constant::getNullValue(Ty: Int32),
                                 llvm::Constant::getNullValue(Ty: VoidPtr),
                                 llvm::Constant::getNullValue(Ty: Int32),
                                 llvm::Constant::getNullValue(Ty: VoidPtr)});
        }

        createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
                                  Args: OffloadingArgs);

        if (Info.HasNoWait) {
          BasicBlock *OffloadContBlock =
              BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
          Function *CurFn = Builder.GetInsertBlock()->getParent();
          emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
          // NOTE(review): restoring the insert point we just saved is a
          // no-op; presumably a leftover from an earlier code shape.
          Builder.restoreIP(IP: Builder.saveIP());
        }
        return Error::success();
      };

      // With nowait, the runtime call must be wrapped in a detachable target
      // task; otherwise emit the body inline at the current point.
      bool RequiresOuterTargetTask = Info.HasNoWait;
      if (!RequiresOuterTargetTask)
        cantFail(Err: TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
                             /*TargetTaskAllocaIP=*/{}));
      else
        cantFail(ValOrErr: emitTargetTask(TaskBodyCB, DeviceID, RTLoc: SrcLocInfo, AllocaIP,
                                 /*Dependencies=*/{}, RTArgs, HasNoWait: Info.HasNoWait));
    } else {
      Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
          FnID: omp::OMPRTL___tgt_target_data_begin_mapper);

      createRuntimeFunctionCall(Callee: BeginMapperFunc, Args: OffloadingArgs);

      // For use_device_ptr entries backed by an alloca, copy the translated
      // device pointer produced by the begin call into that alloca.
      for (auto DeviceMap : Info.DevicePtrInfoMap) {
        if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
          auto *LI =
              Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
          Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
        }
      }

      // If device pointer privatization is required, emit the body of the
      // region here. It will have to be duplicated: with and without
      // privatization.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
    }
    return Error::success();
  };

  // If we need device pointer privatization, we need to emit the body of the
  // region with no privatization in the 'else' branch of the conditional.
  // Otherwise, we don't have to do anything.
  auto BeginElseGen = [&](InsertPointTy AllocaIP,
                          InsertPointTy CodeGenIP) -> Error {
    InsertPointOrErrorTy AfterIP =
        BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    return Error::success();
  };

  // Generate code for the closing of the data region.
  auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    TargetDataRTArgs RTArgs;
    // Relies on MapInfo having been populated by BeginThenGen earlier.
    Info.EmitDebug = !MapInfo->Names.empty();
    emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};
    Function *EndMapperFunc =
        getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);

    createRuntimeFunctionCall(Callee: EndMapperFunc, Args: OffloadingArgs);
    return Error::success();
  };

  // We don't have to do anything to close the region if the if clause evaluates
  // to false.
  auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    return Error::success();
  };

  // Drive the begin/body/end sequence, honoring the optional 'if' clause:
  // with a body, emit begin + body + end; standalone constructs only emit
  // the begin (single runtime call) part.
  Error Err = [&]() -> Error {
    if (BodyGenCB) {
      Error Err = [&]() {
        if (IfCond)
          return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
        return BeginThenGen(AllocaIP, Builder.saveIP());
      }();

      if (Err)
        return Err;

      // If we don't require privatization of device pointers, we emit the body
      // in between the runtime calls. This avoids duplicating the body code.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
      if (!AfterIP)
        return AfterIP.takeError();
      restoreIPandDebugLoc(Builder, IP: *AfterIP);

      if (IfCond)
        return emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
      return EndThenGen(AllocaIP, Builder.saveIP());
    }
    if (IfCond)
      return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
    return BeginThenGen(AllocaIP, Builder.saveIP());
  }();

  if (Err)
    return Err;

  return Builder.saveIP();
}
8217
8218FunctionCallee
8219OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
8220 bool IsGPUDistribute) {
8221 assert((IVSize == 32 || IVSize == 64) &&
8222 "IV size is not compatible with the omp runtime");
8223 RuntimeFunction Name;
8224 if (IsGPUDistribute)
8225 Name = IVSize == 32
8226 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
8227 : omp::OMPRTL___kmpc_distribute_static_init_4u)
8228 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
8229 : omp::OMPRTL___kmpc_distribute_static_init_8u);
8230 else
8231 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
8232 : omp::OMPRTL___kmpc_for_static_init_4u)
8233 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
8234 : omp::OMPRTL___kmpc_for_static_init_8u);
8235
8236 return getOrCreateRuntimeFunction(M, FnID: Name);
8237}
8238
8239FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
8240 bool IVSigned) {
8241 assert((IVSize == 32 || IVSize == 64) &&
8242 "IV size is not compatible with the omp runtime");
8243 RuntimeFunction Name = IVSize == 32
8244 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
8245 : omp::OMPRTL___kmpc_dispatch_init_4u)
8246 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
8247 : omp::OMPRTL___kmpc_dispatch_init_8u);
8248
8249 return getOrCreateRuntimeFunction(M, FnID: Name);
8250}
8251
8252FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
8253 bool IVSigned) {
8254 assert((IVSize == 32 || IVSize == 64) &&
8255 "IV size is not compatible with the omp runtime");
8256 RuntimeFunction Name = IVSize == 32
8257 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
8258 : omp::OMPRTL___kmpc_dispatch_next_4u)
8259 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
8260 : omp::OMPRTL___kmpc_dispatch_next_8u);
8261
8262 return getOrCreateRuntimeFunction(M, FnID: Name);
8263}
8264
8265FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
8266 bool IVSigned) {
8267 assert((IVSize == 32 || IVSize == 64) &&
8268 "IV size is not compatible with the omp runtime");
8269 RuntimeFunction Name = IVSize == 32
8270 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
8271 : omp::OMPRTL___kmpc_dispatch_fini_4u)
8272 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
8273 : omp::OMPRTL___kmpc_dispatch_fini_8u);
8274
8275 return getOrCreateRuntimeFunction(M, FnID: Name);
8276}
8277
8278FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
8279 return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
8280}
8281
// After outlining a target region, debug variable records inside the new
// function still describe values (and argument numbers) of the parent
// function. Retarget them to the replacement values recorded in
// ValueReplacementMap and, on the device side, synthesize debug info for
// the implicit first argument.
static void FixupDebugInfoForOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
    DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {

  DISubprogram *NewSP = Func->getSubprogram();
  // Nothing to fix up if the outlined function carries no debug info.
  if (!NewSP)
    return;

  SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;

  // Return a DILocalVariable identical to OldVar except for its argument
  // number, creating and caching it on first use.
  auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
    DILocalVariable *&NewVar = RemappedVariables[OldVar];
    // Only use cached variable if the arg number matches. This is important
    // so that DIVariable created for privatized variables are not discarded.
    if (NewVar && (arg == NewVar->getArg()))
      return NewVar;

    NewVar = llvm::DILocalVariable::get(
        Context&: Builder.getContext(), Scope: OldVar->getScope(), Name: OldVar->getName(),
        File: OldVar->getFile(), Line: OldVar->getLine(), Type: OldVar->getType(), Arg: arg,
        Flags: OldVar->getFlags(), AlignInBits: OldVar->getAlignInBits(), Annotations: OldVar->getAnnotations());
    return NewVar;
  };

  // Swap each replaced location operand for its replacement and, when a
  // replacement occurred, renumber the record's variable to the matching
  // function argument (debug-info arg numbers are 1-based, hence the +1).
  auto UpdateDebugRecord = [&](auto *DR) {
    DILocalVariable *OldVar = DR->getVariable();
    unsigned ArgNo = 0;
    for (auto Loc : DR->location_ops()) {
      auto Iter = ValueReplacementMap.find(Loc);
      if (Iter != ValueReplacementMap.end()) {
        DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
        ArgNo = std::get<1>(Iter->second) + 1;
      }
    }
    if (ArgNo != 0)
      DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
  };

  // The location and scope of variable intrinsics and records still point to
  // the parent function of the target region. Update them.
  for (Instruction &I : instructions(F: Func)) {
    assert(!isa<llvm::DbgVariableIntrinsic>(&I) &&
           "Unexpected debug intrinsic");
    for (DbgVariableRecord &DVR : filterDbgVars(R: I.getDbgRecordRange()))
      UpdateDebugRecord(&DVR);
  }
  // An extra argument is passed to the device. Create the debug data for it.
  if (OMPBuilder.Config.isTargetDevice()) {
    DICompileUnit *CU = NewSP->getUnit();
    Module *M = Func->getParent();
    DIBuilder DB(*M, true, CU);
    DIType *VoidPtrTy =
        DB.createQualifiedType(Tag: dwarf::DW_TAG_pointer_type, FromTy: nullptr);
    DILocalVariable *Var = DB.createParameterVariable(
        Scope: NewSP, Name: "dyn_ptr", /*ArgNo*/ 1, File: NewSP->getFile(), /*LineNo=*/0,
        Ty: VoidPtrTy, /*AlwaysPreserve=*/false, Flags: DINode::DIFlags::FlagArtificial);
    auto Loc = DILocation::get(Context&: Func->getContext(), Line: 0, Column: 0, Scope: NewSP, InlinedAt: 0);
    DB.insertDeclare(Storage: &(*Func->arg_begin()), VarInfo: Var, Expr: DB.createExpression(), DL: Loc,
                     InsertAtEnd: &(*Func->begin()));
  }
}
8343
8344static Value *removeASCastIfPresent(Value *V) {
8345 if (Operator::getOpcode(V) == Instruction::AddrSpaceCast)
8346 return cast<Operator>(Val: V)->getOperand(i: 0);
8347 return V;
8348}
8349
// Outline a target region into a fresh function named FuncName: build the
// parameter list from Inputs, emit init/deinit scaffolding on the device,
// run the user body generator, then rewrite uses of each captured input to
// the value produced by ArgAccessorFuncCB for the matching parameter.
static Expected<Function *> createOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
  SmallVector<Type *> ParameterTypes;
  if (OMPBuilder.Config.isTargetDevice()) {
    // Add the "implicit" runtime argument we use to provide launch specific
    // information for target devices.
    auto *Int8PtrTy = PointerType::getUnqual(C&: Builder.getContext());
    ParameterTypes.push_back(Elt: Int8PtrTy);

    // All parameters to target devices are passed as pointers
    // or i64. This assumes 64-bit address spaces/pointers.
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
                                   ? Arg->getType()
                                   : Type::getInt64Ty(C&: Builder.getContext()));
  } else {
    // On the host, parameters keep their original types.
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Elt: Arg->getType());
  }

  auto BB = Builder.GetInsertBlock();
  auto M = BB->getModule();
  auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
                                    /*isVarArg*/ false);
  auto Func =
      Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);

  // Forward target-cpu and target-features function attributes from the
  // original function to the new outlined function.
  Function *ParentFn = Builder.GetInsertBlock()->getParent();

  auto TargetCpuAttr = ParentFn->getFnAttribute(Kind: "target-cpu");
  if (TargetCpuAttr.isStringAttribute())
    Func->addFnAttr(Attr: TargetCpuAttr);

  auto TargetFeaturesAttr = ParentFn->getFnAttribute(Kind: "target-features");
  if (TargetFeaturesAttr.isStringAttribute())
    Func->addFnAttr(Attr: TargetFeaturesAttr);

  if (OMPBuilder.Config.isTargetDevice()) {
    // Record the kernel's execution mode and keep it alive via
    // llvm.compiler.used so the device runtime can query it.
    Value *ExecMode =
        OMPBuilder.emitKernelExecutionMode(KernelName: FuncName, Mode: DefaultAttrs.ExecFlags);
    OMPBuilder.emitUsed(Name: "llvm.compiler.used", List: {ExecMode});
  }

  // Save insert point.
  IRBuilder<>::InsertPointGuard IPG(Builder);
  // We will generate the entries in the outlined function but the debug
  // location may still be pointing to the parent function. Reset it now.
  Builder.SetCurrentDebugLocation(llvm::DebugLoc());

  // Generate the region into the function.
  BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
  Builder.SetInsertPoint(EntryBB);

  // Insert target init call in the device compilation pass.
  if (OMPBuilder.Config.isTargetDevice())
    Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, Attrs: DefaultAttrs));

  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();

  // As we embed the user code in the middle of our target region after we
  // generate entry code, we must move what allocas we can into the entry
  // block to avoid possible breaking optimisations for device
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);

  // Insert target deinit call in the device compilation pass.
  BasicBlock *OutlinedBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "outlined.body");
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
      Builder.saveIP(),
      OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(IP: *AfterIP);
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.createTargetDeinit(Loc: Builder);

  // Insert return instruction.
  Builder.CreateRetVoid();

  // New Alloca IP at entry point of created device function.
  Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
  auto AllocaIP = Builder.saveIP();

  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());

  // Skip the artificial dyn_ptr on the device.
  const auto &ArgRange =
      OMPBuilder.Config.isTargetDevice()
          ? make_range(x: Func->arg_begin() + 1, y: Func->arg_end())
          : Func->args();

  // Maps each original input value to (its in-function replacement, the
  // argument number it came from); consumed by the debug-info fixup below.
  DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;

  auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
    // Things like GEP's can come in the form of Constants. Constants and
    // ConstantExpr's do not have access to the knowledge of what they're
    // contained in, so we must dig a little to find an instruction so we
    // can tell if they're used inside of the function we're outlining. We
    // also replace the original constant expression with a new instruction
    // equivalent; an instruction as it allows easy modification in the
    // following loop, as we can now know the constant (instruction) is
    // owned by our target function and replaceUsesOfWith can now be invoked
    // on it (cannot do this with constants it seems). A brand new one also
    // allows us to be cautious as it is perhaps possible the old expression
    // was used inside of the function but exists and is used externally
    // (unlikely by the nature of a Constant, but still).
    // NOTE: We cannot remove dead constants that have been rewritten to
    // instructions at this stage, we run the risk of breaking later lowering
    // by doing so as we could still be in the process of lowering the module
    // from MLIR to LLVM-IR and the MLIR lowering may still require the original
    // constants we have created rewritten versions of.
    if (auto *Const = dyn_cast<Constant>(Val: Input))
      convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);

    // Collect users before iterating over them to avoid invalidating the
    // iteration in case a user uses Input more than once (e.g. a call
    // instruction).
    SetVector<User *> Users(Input->users().begin(), Input->users().end());
    // Collect all the instructions
    for (User *User : make_early_inc_range(Range&: Users))
      if (auto *Instr = dyn_cast<Instruction>(Val: User))
        if (Instr->getFunction() == Func)
          Instr->replaceUsesOfWith(From: Input, To: InputCopy);
  };

  SmallVector<std::pair<Value *, Value *>> DeferredReplacement;

  // Rewrite uses of input valus to parameters.
  for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
    Value *Input = std::get<0>(t&: InArg);
    Argument &Arg = std::get<1>(t&: InArg);
    Value *InputCopy = nullptr;

    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    ValueReplacementMap[Input] = std::make_tuple(args&: InputCopy, args: Arg.getArgNo());

    // In certain cases a Global may be set up for replacement, however, this
    // Global may be used in multiple arguments to the kernel, just segmented
    // apart, for example, if we have a global array, that is sectioned into
    // multiple mappings (technically not legal in OpenMP, but there is a case
    // in Fortran for Common Blocks where this is neccesary), we will end up
    // with GEP's into this array inside the kernel, that refer to the Global
    // but are technically separate arguments to the kernel for all intents and
    // purposes. If we have mapped a segment that requires a GEP into the 0-th
    // index, it will fold into an referal to the Global, if we then encounter
    // this folded GEP during replacement all of the references to the
    // Global in the kernel will be replaced with the argument we have generated
    // that corresponds to it, including any other GEP's that refer to the
    // Global that may be other arguments. This will invalidate all of the other
    // preceding mapped arguments that refer to the same global that may be
    // separate segments. To prevent this, we defer global processing until all
    // other processing has been performed.
    if (llvm::isa<llvm::GlobalValue, llvm::GlobalObject, llvm::GlobalVariable>(
            Val: removeASCastIfPresent(V: Input))) {
      DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
      continue;
    }

    // ConstantData (e.g. literal constants) has no in-function uses to
    // rewrite.
    if (isa<ConstantData>(Val: Input))
      continue;

    ReplaceValue(Input, InputCopy, Func);
  }

  // Replace all of our deferred Input values, currently just Globals.
  for (auto Deferred : DeferredReplacement)
    ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);

  FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
                                    ValueReplacementMap);
  return Func;
}
8533/// Given a task descriptor, TaskWithPrivates, return the pointer to the block
8534/// of pointers containing shared data between the parent task and the created
8535/// task.
8536static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
8537 IRBuilderBase &Builder,
8538 Value *TaskWithPrivates,
8539 Type *TaskWithPrivatesTy) {
8540
8541 Type *TaskTy = OMPIRBuilder.Task;
8542 LLVMContext &Ctx = Builder.getContext();
8543 Value *TaskT =
8544 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 0);
8545 Value *Shareds = TaskT;
8546 // TaskWithPrivatesTy can be one of the following
8547 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8548 // %struct.privates }
8549 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
8550 //
8551 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
8552 // its first member has to be the task descriptor. TaskTy is the type of the
8553 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
8554 // first member of TaskT, gives us the pointer to shared data.
8555 if (TaskWithPrivatesTy != TaskTy)
8556 Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
8557 return Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
8558}
/// Create an entry point for a target task with the following.
/// It'll have the following signature
/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
/// This function is called from emitTargetTask once the
/// code to launch the target kernel has been outlined already.
/// NumOffloadingArrays is the number of offloading arrays that we need to copy
/// into the task structure so that the deferred target task can access this
/// data even after the stack frame of the generating task has been rolled
/// back. Offloading arrays contain base pointers, pointers, sizes etc
/// of the data that the target kernel will access. These in effect are the
/// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
static Function *emitTargetTaskProxyFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
    StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
    const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {

  // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
  // This is because PrivatesTy is the type of the structure in which
  // we pass the offloading arrays to the deferred target task.
  assert((!NumOffloadingArrays || PrivatesTy) &&
         "PrivatesTy cannot be nullptr when there are offloadingArrays"
         "to privatize");

  Module &M = OMPBuilder.M;
  // KernelLaunchFunction is the target launch function, i.e.
  // the function that sets up kernel arguments and calls
  // __tgt_target_kernel to launch the kernel on the device.
  //
  Function *KernelLaunchFunction = StaleCI->getCalledFunction();

  // StaleCI is the CallInst which is the call to the outlined
  // target kernel launch function. If there are local live-in values
  // that the outlined function uses then these are aggregated into a structure
  // which is passed as the second argument. If there are no local live-in
  // values or if all values used by the outlined kernel are global variables,
  // then there's only one argument, the threadID. So, StaleCI can be
  //
  // %structArg = alloca { ptr, ptr }, align 8
  // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
  // store ptr %20, ptr %gep_, align 8
  // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
  // store ptr %21, ptr %gep_8, align 8
  // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
  //
  // OR
  //
  // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
  OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
                                    StaleCI->getIterator());

  LLVMContext &Ctx = StaleCI->getParent()->getContext();

  Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
  Type *TaskPtrTy = OMPBuilder.TaskPtr;
  [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;

  // The proxy has the fixed (i32 thread.id, ptr task) signature the task
  // runtime expects for a task entry point.
  auto ProxyFnTy =
      FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
                        /* isVarArg */ false);
  auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
                                  N: ".omp_target_task_proxy_func",
                                  M: Builder.GetInsertBlock()->getModule());
  Value *ThreadId = ProxyFn->getArg(i: 0);
  Value *TaskWithPrivates = ProxyFn->getArg(i: 1);
  ThreadId->setName("thread.id");
  TaskWithPrivates->setName("task");

  bool HasShareds = SharedArgsOperandNo > 0;
  bool HasOffloadingArrays = NumOffloadingArrays > 0;
  BasicBlock *EntryBB =
      BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
  Builder.SetInsertPoint(EntryBB);

  SmallVector<Value *> KernelLaunchArgs;
  KernelLaunchArgs.reserve(N: StaleCI->arg_size());
  KernelLaunchArgs.push_back(Elt: ThreadId);

  if (HasOffloadingArrays) {
    assert(TaskTy != TaskWithPrivatesTy &&
           "If there are offloading arrays to pass to the target"
           "TaskTy cannot be the same as TaskWithPrivatesTy");
    (void)TaskTy;
    // Pass a pointer to each privatized offloading array (field of the
    // privates struct) straight through to the launch function.
    Value *Privates =
        Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 1);
    for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
      KernelLaunchArgs.push_back(
          Elt: Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i));
  }

  if (HasShareds) {
    auto *ArgStructAlloca =
        dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgsOperandNo));
    assert(ArgStructAlloca &&
           "Unable to find the alloca instruction corresponding to arguments "
           "for extracted function");
    auto *ArgStructType = cast<StructType>(Val: ArgStructAlloca->getAllocatedType());

    // Rebuild a local copy of the shared-argument struct from the task's
    // shareds block, then hand that copy to the launch function.
    AllocaInst *NewArgStructAlloca =
        Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");

    Value *SharedsSize =
        Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));

    LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
        OMPIRBuilder&: OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);

    Builder.CreateMemCpy(
        Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
        SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
    KernelLaunchArgs.push_back(Elt: NewArgStructAlloca);
  }
  OMPBuilder.createRuntimeFunctionCall(Callee: KernelLaunchFunction, Args: KernelLaunchArgs);
  Builder.CreateRetVoid();
  return ProxyFn;
}
8674static Type *getOffloadingArrayType(Value *V) {
8675
8676 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V))
8677 return GEP->getSourceElementType();
8678 if (auto *Alloca = dyn_cast<AllocaInst>(Val: V))
8679 return Alloca->getAllocatedType();
8680
8681 llvm_unreachable("Unhandled Instruction type");
8682 return nullptr;
8683}
8684// This function returns a struct that has at most two members.
8685// The first member is always %struct.kmp_task_ompbuilder_t, that is the task
8686// descriptor. The second member, if needed, is a struct containing arrays
8687// that need to be passed to the offloaded target kernel. For example,
8688// if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
8689// the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
8690// respectively, then the types created by this function are
8691//
8692// %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
8693// %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
8694// %struct.privates }
8695// %struct.task_with_privates is returned by this function.
8696// If there aren't any offloading arrays to pass to the target kernel,
8697// %struct.kmp_task_ompbuilder_t is returned.
8698static StructType *
8699createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
8700 ArrayRef<Value *> OffloadingArraysToPrivatize) {
8701
8702 if (OffloadingArraysToPrivatize.empty())
8703 return OMPIRBuilder.Task;
8704
8705 SmallVector<Type *, 4> StructFieldTypes;
8706 for (Value *V : OffloadingArraysToPrivatize) {
8707 assert(V->getType()->isPointerTy() &&
8708 "Expected pointer to array to privatize. Got a non-pointer value "
8709 "instead");
8710 Type *ArrayTy = getOffloadingArrayType(V);
8711 assert(ArrayTy && "ArrayType cannot be nullptr");
8712 StructFieldTypes.push_back(Elt: ArrayTy);
8713 }
8714 StructType *PrivatesStructTy =
8715 StructType::create(Elements: StructFieldTypes, Name: "struct.privates");
8716 return StructType::create(Elements: {OMPIRBuilder.Task, PrivatesStructTy},
8717 Name: "struct.task_with_privates");
8718}
8719static Error emitTargetOutlinedFunction(
8720 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
8721 TargetRegionEntryInfo &EntryInfo,
8722 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
8723 Function *&OutlinedFn, Constant *&OutlinedFnID,
8724 SmallVectorImpl<Value *> &Inputs,
8725 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
8726 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
8727
8728 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
8729 [&](StringRef EntryFnName) {
8730 return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
8731 FuncName: EntryFnName, Inputs, CBFunc,
8732 ArgAccessorFuncCB);
8733 };
8734
8735 return OMPBuilder.emitTargetRegionFunction(
8736 EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
8737 OutlinedFnID);
8738}
8739
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
    TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
    OpenMPIRBuilder::InsertPointTy AllocaIP,
    const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    const TargetDataRTArgs &RTArgs, bool HasNoWait) {

  // The following explains the code-gen scenario for the `target` directive. A
  // similar scenario is followed for other device-related directives (e.g.
  // `target enter data`) in a similar fashion, since we only need to emit a
  // task that encapsulates the proper runtime call.
  //
  // When we arrive at this function, the target region itself has been
  // outlined into the function OutlinedFn.
  // So at this point, for
  // --------------------------------------------------------------
  // void user_code_that_offloads(...) {
  //   omp target depend(..) map(from:a) map(to:b) private(i)
  //     do i = 1, 10
  //       a(i) = b(i) + n
  // }
  //
  // --------------------------------------------------------------
  //
  // we have
  //
  // --------------------------------------------------------------
  //
  // void user_code_that_offloads(...) {
  //   %.offload_baseptrs = alloca [2 x ptr], align 8
  //   %.offload_ptrs = alloca [2 x ptr], align 8
  //   %.offload_mappers = alloca [2 x ptr], align 8
  //   ;; target region has been outlined and now we need to
  //   ;; offload to it via a target task.
  // }
  // void outlined_device_function(ptr a, ptr b, ptr n) {
  //   n = *n_ptr;
  //   do i = 1, 10
  //     a(i) = b(i) +  n
  // }
  //
  // We have to now do the following
  // (i)   Make an offloading call to outlined_device_function using the OpenMP
  //       RTL. See 'kernel_launch_function' in the pseudo code below. This is
  //       emitted by emitKernelLaunch
  // (ii)  Create a task entry point function that calls kernel_launch_function
  //       and is the entry point for the target task. See
  //       '@.omp_target_task_proxy_func in the pseudocode below.
  // (iii) Create a task with the task entry point created in (ii)
  //
  // That is we create the following
  // struct task_with_privates {
  //    struct kmp_task_ompbuilder_t task_struct;
  //    struct privates {
  //       [2 x ptr] ; baseptrs
  //       [2 x ptr] ; ptrs
  //       [2 x i64] ; sizes
  //    }
  // }
  // void user_code_that_offloads(...) {
  //   %.offload_baseptrs = alloca [2 x ptr], align 8
  //   %.offload_ptrs = alloca [2 x ptr], align 8
  //   %.offload_sizes = alloca [2 x i64], align 8
  //
  //   %structArg = alloca { ptr, ptr, ptr }, align 8
  //   %strucArg[0] = a
  //   %strucArg[1] = b
  //   %strucArg[2] = &n
  //
  //   target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
  //                                               sizeof(kmp_task_ompbuilder_t),
  //                                               sizeof(structArg),
  //                                               @.omp_target_task_proxy_func,
  //                                               ...)
  //   memcpy(target_task_with_privates->task_struct->shareds, %structArg,
  //          sizeof(structArg))
  //   memcpy(target_task_with_privates->privates->baseptrs,
  //          offload_baseptrs, sizeof(offload_baseptrs)
  //   memcpy(target_task_with_privates->privates->ptrs,
  //          offload_ptrs, sizeof(offload_ptrs)
  //   memcpy(target_task_with_privates->privates->sizes,
  //          offload_sizes, sizeof(offload_sizes)
  //   dependencies_array = ...
  //   ;; if nowait not present
  //   call @__kmpc_omp_wait_deps(..., dependencies_array)
  //   call @__kmpc_omp_task_begin_if0(...)
  //   call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
  //   %target_task_with_privates)
  //   call @__kmpc_omp_task_complete_if0(...)
  // }
  //
  // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
  //                                                   ptr %task) {
  //   %structArg = alloca {ptr, ptr, ptr}
  //   %task_ptr = getelementptr(%task, 0, 0)
  //   %shared_data = load (getelementptr %task_ptr, 0, 0)
  //   memcpy(%structArg, %shared_data, sizeof(%structArg))
  //
  //   %offloading_arrays = getelementptr(%task, 0, 1)
  //   %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
  //   %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
  //   %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
  //   kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
  //                          %offload_sizes, %structArg)
  // }
  //
  // We need the proxy function because the signature of the task entry point
  // expected by kmpc_omp_task is always the same and will be different from
  // that of the kernel_launch function.
  //
  // kernel_launch_function is generated by emitKernelLaunch and has the
  // always_inline attribute. For this example, it'll look like so:
  // void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
  //                             %offload_sizes,  %structArg) alwaysinline {
  //   %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
  //   ; load aggregated data from %structArg
  //   ; setup kernel_args using offload_baseptrs, offload_ptrs and
  //   ; offload_sizes
  //   call i32 @__tgt_target_kernel(...,
  //                                 outlined_device_function,
  //                                 ptr %kernel_args)
  // }
  // void outlined_device_function(ptr a, ptr b, ptr n) {
  //   n = *n_ptr;
  //   do i = 1, 10
  //     a(i) = b(i) +  n
  // }
  //
  // Split off the task body and a dedicated alloca block; the region between
  // TargetTaskAllocaBB and OI.ExitBB (created below) is what gets outlined.
  BasicBlock *TargetTaskBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
  BasicBlock *TargetTaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");

  InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
                                   TargetTaskAllocaBB->begin());
  InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());

  OutlineInfo OI;
  OI.EntryBB = TargetTaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));

  // Generate the task body which will subsequently be outlined.
  Builder.restoreIP(IP: TargetTaskBodyIP);
  if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
    return Err;

  // The outliner (CodeExtractor) extracts a sequence or vector of blocks that
  // it is given. These blocks are enumerated by
  // OpenMPIRBuilder::OutlineInfo::collectBlocks which expects the OI.ExitBlock
  // to be outside the region. In other words, OI.ExitBlock is expected to be
  // the start of the region after the outlining. We used to set OI.ExitBlock
  // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
  // except when the task body is a single basic block. In that case,
  // OI.ExitBlock is set to the single task body block and will get left out of
  // the outlining process. So, simply create a new empty block to which we
  // unconditionally branch from where TaskBodyCB left off.
  OI.ExitBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "target.task.cont");
  emitBlock(BB: OI.ExitBB, CurFn: Builder.GetInsertBlock()->getParent(),
            /*IsFinished=*/true);

  // For a deferred (nowait) target task, the offloading arrays must outlive
  // the current stack frame, so copy them into the task's private area.
  // Constant/global arrays need no privatization.
  SmallVector<Value *, 2> OffloadingArraysToPrivatize;
  bool NeedsTargetTask = HasNoWait && DeviceID;
  if (NeedsTargetTask) {
    for (auto *V :
         {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
          RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
          RTArgs.SizesArray}) {
      if (V && !isa<ConstantPointerNull, GlobalVariable>(Val: V)) {
        OffloadingArraysToPrivatize.push_back(Elt: V);
        OI.ExcludeArgsFromAggregate.push_back(Elt: V);
      }
    }
  }
  // After outlining, replace the (stale) call to the outlined function with
  // the runtime calls that allocate and launch/execute the task.
  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
                      DeviceID, OffloadingArraysToPrivatize](
                         Function &OutlinedFn) mutable {
    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");

    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());

    // The first argument of StaleCI is always the thread id.
    // The next few arguments are the pointers to offloading arrays
    // if any. (see OffloadingArraysToPrivatize)
    // Finally, all other local values that are live-in into the outlined region
    // end up in a structure whose pointer is passed as the last argument. This
    // piece of data is passed in the "shared" field of the task structure. So,
    // we know we have to pass shareds to the task if the number of arguments is
    // greater than OffloadingArraysToPrivatize.size() + 1. The 1 is for the
    // thread id. Further, for safety, we assert that the number of arguments of
    // StaleCI is exactly OffloadingArraysToPrivatize.size() + 2
    const unsigned int NumStaleCIArgs = StaleCI->arg_size();
    bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
    assert((!HasShareds ||
            NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
           "Wrong number of arguments for StaleCI when shareds are present");
    int SharedArgOperandNo =
        HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;

    StructType *TaskWithPrivatesTy =
        createTaskWithPrivatesTy(OMPIRBuilder&: *this, OffloadingArraysToPrivatize);
    StructType *PrivatesTy = nullptr;

    // Element 1 of TaskWithPrivatesTy is the privates struct (element 0 is
    // the kmp_task_t); it only exists when there are arrays to privatize.
    if (!OffloadingArraysToPrivatize.empty())
      PrivatesTy =
          static_cast<StructType *>(TaskWithPrivatesTy->getElementType(N: 1));

    Function *ProxyFn = emitTargetTaskProxyFunction(
        OMPBuilder&: *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
        NumOffloadingArrays: OffloadingArraysToPrivatize.size(), SharedArgsOperandNo: SharedArgOperandNo);

    LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
                      << "\n");

    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call.
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr =
        getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

    // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
    //
    // If `NeedsTargetTask` (i.e. `nowait` with a device ID), we call
    // @__kmpc_omp_target_task_alloc to provide the DeviceID to the deferred
    // task and also since @__kmpc_omp_target_task_alloc creates an
    // untied/async task.
    Function *TaskAllocFn =
        !NeedsTargetTask
            ? getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc)
            : getOrCreateRuntimeFunctionPtr(
                  FnID: OMPRTL___kmpc_omp_target_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // Tasksize refers to the size in bytes of kmp_task_t data structure
    // plus any other data to be passed to the target task, if any, which
    // is packed into a struct. kmp_task_t and the struct so created are
    // packed into a wrapper struct whose type is TaskWithPrivatesTy.
    Value *TaskSize = Builder.getInt64(
        C: M.getDataLayout().getTypeStoreSize(Ty: TaskWithPrivatesTy));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(C: 0);
    if (HasShareds) {
      auto *ArgStructAlloca =
          dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgOperandNo));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      auto *ArgStructType =
          dyn_cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
      assert(ArgStructType && "Unable to find struct type corresponding to "
                              "arguments for extracted function");
      SharedsSize =
          Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ArgStructType));
    }

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // A target task is not final and is untied.
    Value *Flags = Builder.getInt32(C: 0);

    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = nullptr;

    SmallVector<llvm::Value *> TaskAllocArgs = {
        /*loc_ref=*/Ident, /*gtid=*/ThreadID,
        /*flags=*/Flags,
        /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
        /*task_func=*/ProxyFn};

    // @__kmpc_omp_target_task_alloc takes an extra device_id argument.
    if (NeedsTargetTask) {
      assert(DeviceID && "Expected non-empty device ID.");
      TaskAllocArgs.push_back(Elt: DeviceID);
    }

    TaskData = createRuntimeFunctionCall(Callee: TaskAllocFn, Args: TaskAllocArgs);

    // Copy the shareds struct into the task's shareds area.
    Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(i: SharedArgOperandNo);
      Value *TaskShareds = loadSharedDataFromTaskDescriptor(
          OMPIRBuilder&: *this, Builder, TaskWithPrivates: TaskData, TaskWithPrivatesTy);
      Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
                           Size: SharedsSize);
    }
    // Copy each offloading array into the corresponding field of the task's
    // privates struct so the deferred task has its own stable copies.
    if (!OffloadingArraysToPrivatize.empty()) {
      Value *Privates =
          Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskData, Idx: 1);
      for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
        Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
        [[maybe_unused]] Type *ArrayType =
            getOffloadingArrayType(V: PtrToPrivatize);
        assert(ArrayType && "ArrayType cannot be nullptr");

        Type *ElementType = PrivatesTy->getElementType(N: i);
        assert(ElementType == ArrayType &&
               "ElementType should match ArrayType");
        (void)ArrayType;

        Value *Dst = Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i);
        Builder.CreateMemCpy(
            Dst, DstAlign: Alignment, Src: PtrToPrivatize, SrcAlign: Alignment,
            Size: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ElementType)));
      }
    }

    Value *DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies);

    // ---------------------------------------------------------------
    // V5.2 13.8 target construct
    // If the nowait clause is present, execution of the target task
    // may be deferred. If the nowait clause is not present, the target task is
    // an included task.
    // ---------------------------------------------------------------
    // The above means that the lack of a nowait on the target construct
    // translates to '#pragma omp task if(0)'
    if (!NeedsTargetTask) {
      if (DepArray) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
        createRuntimeFunctionCall(
            Callee: TaskWaitFn,
            Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
             /*ndeps=*/Builder.getInt32(C: Dependencies.size()),
             /*dep_list=*/DepArray,
             /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
             /*noalias_dep_list=*/
             ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
      }
      // Included task.
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
      createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
      CallInst *CI = createRuntimeFunctionCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
    } else if (DepArray) {
      // NeedsTargetTask - meaning the task may be deferred. Call
      // __kmpc_omp_task_with_deps if there are dependencies,
      // else call __kmpc_omp_task
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
      createRuntimeFunctionCall(
          Callee: TaskFn,
          Args: {Ident, ThreadID, TaskData, Builder.getInt32(C: Dependencies.size()),
           DepArray, ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
           ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
      createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
    }

    // The stale call and the fake thread-id instructions are dead now.
    StaleCI->eraseFromParent();
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };
  addOutlineInfo(OI: std::move(OI));

  LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()) << "\n");
  LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()->getParent()->getParent())
                    << "\n");
  return Builder.saveIP();
}
9124
9125Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
9126 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
9127 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
9128 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
9129 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
9130 if (Error Err =
9131 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
9132 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
9133 return Err;
9134 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
9135 return Error::success();
9136}
9137
// Emit the host-side launch sequence for a target region: either a direct
// kernel launch, a host fallback call, or a wrapping target task when
// `nowait`/`depend` clauses require one. Honors an optional `if` clause by
// emitting both code paths.
static void emitTargetCall(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    OpenMPIRBuilder::InsertPointTy AllocaIP,
    OpenMPIRBuilder::TargetDataInfo &Info,
    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
    const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
    Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
    SmallVectorImpl<Value *> &Args,
    OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
    const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    bool HasNoWait, Value *DynCGroupMem,
    OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
  // Generate a function call to the host fallback implementation of the target
  // region. This is called by the host when no offload entry was generated for
  // the target region and when the offloading call fails at runtime.
  auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
      -> OpenMPIRBuilder::InsertPointOrErrorTy {
    Builder.restoreIP(IP);
    OMPBuilder.createRuntimeFunctionCall(Callee: OutlinedFn, Args);
    return Builder.saveIP();
  };

  bool HasDependencies = Dependencies.size() > 0;
  bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

  // Filled in by EmitTargetCallThen before TaskBodyCB runs; captured by
  // reference in TaskBodyCB below.
  OpenMPIRBuilder::TargetKernelArgs KArgs;

  auto TaskBodyCB =
      [&](Value *DeviceID, Value *RTLoc,
          IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
    // Assume no error was returned because EmitTargetCallFallbackCB doesn't
    // produce any.
    llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      // emitKernelLaunch makes the necessary runtime call to offload the
      // kernel. We then outline all that code into a separate function
      // ('kernel_launch_function' in the pseudo code above). This function is
      // then called by the target task proxy function (see
      // '@.omp_target_task_proxy_func' in the pseudo code above)
      // "@.omp_target_task_proxy_func' is generated by
      // emitTargetTaskProxyFunction.
      if (OutlinedFnID && DeviceID)
        return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
                                           EmitTargetCallFallbackCB, Args&: KArgs,
                                           DeviceID, RTLoc, AllocaIP: TargetTaskAllocaIP);

      // We only need to do the outlining if `DeviceID` is set to avoid calling
      // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
      // generating the `else` branch of an `if` clause.
      //
      // When OutlinedFnID is set to nullptr, then it's not an offloading call.
      // In this case, we execute the host implementation directly.
      return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
    }());

    OMPBuilder.Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  auto &&EmitTargetCallElse =
      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
    // Assume no error was returned because EmitTargetCallFallbackCB doesn't
    // produce any.
    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      if (RequiresOuterTargetTask) {
        // Arguments that are intended to be directly forwarded to an
        // emitKernelLaunch call are passed as nullptr, since
        // OutlinedFnID=nullptr results in that call not being done.
        OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
        return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
                                         /*RTLoc=*/nullptr, AllocaIP,
                                         Dependencies, RTArgs: EmptyRTArgs, HasNoWait);
      }
      return EmitTargetCallFallbackCB(Builder.saveIP());
    }());

    Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  auto &&EmitTargetCallThen =
      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
    Info.HasNoWait = HasNoWait;
    // Build the offloading arrays and runtime arguments for the kernel.
    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
    if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
            AllocaIP, CodeGenIP: Builder.saveIP(), Info, RTArgs, CombinedInfo&: MapInfo, CustomMapperCB,
            /*IsNonContiguous=*/true,
            /*ForEndCall=*/false))
      return Err;

    // Per-dimension number of teams: runtime value if present, else default.
    SmallVector<Value *, 3> NumTeamsC;
    for (auto [DefaultVal, RuntimeVal] :
         zip_equal(t: DefaultAttrs.MaxTeams, u: RuntimeAttrs.MaxTeams))
      NumTeamsC.push_back(Elt: RuntimeVal ? RuntimeVal
                                      : Builder.getInt32(C: DefaultVal));

    // Calculate number of threads: 0 if no clauses specified, otherwise it is
    // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
    auto InitMaxThreadsClause = [&Builder](Value *Clause) {
      if (Clause)
        Clause = Builder.CreateIntCast(V: Clause, DestTy: Builder.getInt32Ty(),
                                       /*isSigned=*/false);
      return Clause;
    };
    auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
      if (Clause)
        Result =
            Result ? Builder.CreateSelect(C: Builder.CreateICmpULT(LHS: Result, RHS: Clause),
                                          True: Result, False: Clause)
                   : Clause;
    };

    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
    // the NUM_THREADS clause is overridden by THREAD_LIMIT.
    SmallVector<Value *, 3> NumThreadsC;
    Value *MaxThreadsClause =
        RuntimeAttrs.TeamsThreadLimit.size() == 1
            ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
            : nullptr;

    for (auto [TeamsVal, TargetVal] : zip_equal(
             t: RuntimeAttrs.TeamsThreadLimit, u: RuntimeAttrs.TargetThreadLimit)) {
      Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
      Value *NumThreads = InitMaxThreadsClause(TargetVal);

      CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
      CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

      NumThreadsC.push_back(Elt: NumThreads ? NumThreads : Builder.getInt32(C: 0));
    }

    unsigned NumTargetItems = Info.NumberOfPtrs;
    // Source location for the runtime ident.
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                               LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);

    // Trip count (zero-extended to i64), or 0 when unknown.
    Value *TripCount = RuntimeAttrs.LoopTripCount
                           ? Builder.CreateIntCast(V: RuntimeAttrs.LoopTripCount,
                                                   DestTy: Builder.getInt64Ty(),
                                                   /*isSigned=*/false)
                           : Builder.getInt64(C: 0);

    // Request zero groupprivate bytes by default.
    if (!DynCGroupMem)
      DynCGroupMem = Builder.getInt32(C: 0);

    KArgs = OpenMPIRBuilder::TargetKernelArgs(
        NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
        HasNoWait, DynCGroupMemFallback);

    // Assume no error was returned because TaskBodyCB and
    // EmitTargetCallFallbackCB don't produce any.
    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
      // The presence of certain clauses on the target directive require the
      // explicit generation of the target task.
      if (RequiresOuterTargetTask)
        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID: RuntimeAttrs.DeviceID,
                                         RTLoc, AllocaIP, Dependencies,
                                         RTArgs: KArgs.RTArgs, HasNoWait: Info.HasNoWait);

      return OMPBuilder.emitKernelLaunch(
          Loc: Builder, OutlinedFnID, EmitTargetCallFallbackCB, Args&: KArgs,
          DeviceID: RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
    }());

    Builder.restoreIP(IP: AfterIP);
    return Error::success();
  };

  // If we don't have an ID for the target region, it means an offload entry
  // wasn't created. In this case we just run the host fallback directly and
  // ignore any potential 'if' clauses.
  if (!OutlinedFnID) {
    cantFail(Err: EmitTargetCallElse(AllocaIP, Builder.saveIP()));
    return;
  }

  // If there's no 'if' clause, only generate the kernel launch code path.
  if (!IfCond) {
    cantFail(Err: EmitTargetCallThen(AllocaIP, Builder.saveIP()));
    return;
  }

  cantFail(Err: OMPBuilder.emitIfClause(Cond: IfCond, ThenGen: EmitTargetCallThen,
                                    ElseGen: EmitTargetCallElse, AllocaIP));
}
9328
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, TargetDataInfo &Info,
    TargetRegionEntryInfo &EntryInfo,
    const TargetKernelDefaultAttrs &DefaultAttrs,
    const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
    SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
    CustomMapperCallbackTy CustomMapperCB,
    const SmallVector<DependData> &Dependencies, bool HasNowait,
    Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {

  // Bail out (returning an uninitialized insert point) if the location is
  // invalid.
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  Function *OutlinedFn;
  Constant *OutlinedFnID = nullptr;
  // The target region is outlined into its own function. The LLVM IR for
  // the target region itself is generated using the callbacks CBFunc
  // and ArgAccessorFuncCB
  if (Error Err = emitTargetOutlinedFunction(
          OMPBuilder&: *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
          OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
    return Err;

  // If we are not on the target device, then we need to generate code
  // to make a remote call (offload) to the previously outlined function
  // that represents the target region. Do that now.
  if (!Config.isTargetDevice())
    emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
                   IfCond, OutlinedFn, OutlinedFnID, Args&: Inputs, GenMapInfoCB,
                   CustomMapperCB, Dependencies, HasNoWait: HasNowait, DynCGroupMem,
                   DynCGroupMemFallback);
  return Builder.saveIP();
}
9367
9368std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
9369 StringRef FirstSeparator,
9370 StringRef Separator) {
9371 SmallString<128> Buffer;
9372 llvm::raw_svector_ostream OS(Buffer);
9373 StringRef Sep = FirstSeparator;
9374 for (StringRef Part : Parts) {
9375 OS << Sep << Part;
9376 Sep = Separator;
9377 }
9378 return OS.str().str();
9379}
9380
9381std::string
9382OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
9383 return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
9384 Separator: Config.separator());
9385}
9386
// Return the module-level global used as an OpenMP internal variable with the
// given name, creating (and caching) it on first request. Subsequent requests
// with the same name must use the same type.
GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
    Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
  // Cache lookup; try_emplace inserts a null placeholder on a miss.
  auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
  if (Elem.second) {
    assert(Elem.second->getValueType() == Ty &&
           "OMP internal variable has different type than requested");
  } else {
    // TODO: investigate the appropriate linkage type used for the global
    // variable for possibly changing that to internal or private, or maybe
    // create different versions of the function for different OMP internal
    // variables.
    const DataLayout &DL = M.getDataLayout();
    // TODO: Investigate why AMDGPU expects AS 0 for globals even though the
    // default global AS is 1.
    // See double-target-call-with-declare-target.f90 and
    // declare-target-vars-in-target-region.f90 libomptarget
    // tests.
    unsigned AddressSpaceVal = AddressSpace ? *AddressSpace
                               : M.getTargetTriple().isAMDGPU()
                                   ? 0
                                   : DL.getDefaultGlobalsAddressSpace();
    // wasm32 has no common-linkage support, so fall back to internal there.
    auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
                       ? GlobalValue::InternalLinkage
                       : GlobalValue::CommonLinkage;
    // Zero-initialized, named after the cache key so repeated requests alias.
    auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
                                  Constant::getNullValue(Ty), Elem.first(),
                                  /*InsertBefore=*/nullptr,
                                  GlobalValue::NotThreadLocal, AddressSpaceVal);
    // Align to at least pointer alignment; common linkage merges definitions
    // across TUs, so use the stricter of the type/pointer ABI alignments.
    const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
    const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpaceVal);
    GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
    Elem.second = GV;
  }

  return Elem.second;
}
9423
9424Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
9425 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
9426 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
9427 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
9428}
9429
// Emit IR computing a size in bytes via the classic "GEP element 1 from a null
// pointer, then ptrtoint" idiom.
// NOTE(review): the GEP element type is BasePtr->getType(); with opaque
// pointers that is simply `ptr`, which would yield the size of a pointer
// rather than of any pointed-to object — confirm that is what callers expect.
Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
  LLVMContext &Ctx = Builder.getContext();
  Value *Null =
      Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
  // &null[1] == the allocation size of the GEP'd type.
  Value *SizeGep =
      Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
  Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
  return SizePtrToInt;
}
9439
9440GlobalVariable *
9441OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
9442 std::string VarName) {
9443 llvm::Constant *MaptypesArrayInit =
9444 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
9445 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
9446 M, MaptypesArrayInit->getType(),
9447 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
9448 VarName);
9449 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
9450 return MaptypesArrayGlobal;
9451}
9452
9453void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
9454 InsertPointTy AllocaIP,
9455 unsigned NumOperands,
9456 struct MapperAllocas &MapperAllocas) {
9457 if (!updateToLocation(Loc))
9458 return;
9459
9460 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
9461 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
9462 Builder.restoreIP(IP: AllocaIP);
9463 AllocaInst *ArgsBase = Builder.CreateAlloca(
9464 Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
9465 AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
9466 Name: ".offload_ptrs");
9467 AllocaInst *ArgSizes = Builder.CreateAlloca(
9468 Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
9469 updateToLocation(Loc);
9470 MapperAllocas.ArgsBase = ArgsBase;
9471 MapperAllocas.Args = Args;
9472 MapperAllocas.ArgSizes = ArgSizes;
9473}
9474
// Emit a call to the given offloading runtime function \p MapperFunc, passing
// pointers to the first elements of the baseptrs/ptrs/sizes arrays previously
// allocated by createMapperAllocas, plus the map-type/map-name arrays, the
// device ID, and the operand count.
void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
                                     Function *MapperFunc, Value *SrcLocInfo,
                                     Value *MaptypesArg, Value *MapnamesArg,
                                     struct MapperAllocas &MapperAllocas,
                                     int64_t DeviceID, unsigned NumOperands) {
  if (!updateToLocation(Loc))
    return;

  // Decay each [N x T] alloca to a pointer to its first element via GEP 0,0.
  auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
  auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
  Value *ArgsBaseGEP =
      Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
                                IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
  Value *ArgsGEP =
      Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
                                IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
  Value *ArgSizesGEP =
      Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
                                IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
  // Last argument (mapper array) is not used here; pass null.
  Value *NullPtr =
      Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
  createRuntimeFunctionCall(Callee: MapperFunc, Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
                                    Builder.getInt32(C: NumOperands),
                                    ArgsBaseGEP, ArgsGEP, ArgSizesGEP,
                                    MaptypesArg, MapnamesArg, NullPtr});
}
9501
/// Populate \p RTArgs with pointers to the first elements of the offloading
/// arrays previously recorded in \p Info (see emitOffloadingArrays).
///
/// If \p Info holds no pointers, every member of \p RTArgs is a null
/// constant. When \p ForEndCall is true (only legal when begin/end runtime
/// calls are separate) the end-of-region map-types array is used if one was
/// created. The map-names array is emitted only when debug info was
/// requested (Info.EmitDebug); the mappers array only when a user-defined
/// mapper exists (Info.HasMapper).
void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
                                                   TargetDataRTArgs &RTArgs,
                                                   TargetDataInfo &Info,
                                                   bool ForEndCall) {
  assert((!ForEndCall || Info.separateBeginEndCalls()) &&
         "expected region end call to runtime only when end call is separate");
  // With opaque pointers all of these are the same type; the aliases are
  // kept for readability of the GEPs below.
  auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
  auto VoidPtrTy = UnqualPtrTy;
  auto VoidPtrPtrTy = UnqualPtrTy;
  auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
  auto Int64PtrTy = UnqualPtrTy;

  if (!Info.NumberOfPtrs) {
    RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
    RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
    RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
    return;
  }

  // Decay each array to a pointer to its first element (&array[0][0]).
  RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
      Ptr: Info.RTArgs.BasePointersArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
      /*Idx0=*/0,
      /*Idx1=*/0);
  RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
      Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
      Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
                                              : Info.RTArgs.MapTypesArray,
      /*Idx0=*/0,
      /*Idx1=*/0);

  // Only emit the mapper information arrays if debug information is
  // requested.
  if (!Info.EmitDebug)
    RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
  else
    RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
        /*Idx0=*/0,
        /*Idx1=*/0);
  // If there is no user-defined mapper, set the mapper array to nullptr to
  // avoid an unnecessary data privatization
  if (!Info.HasMapper)
    RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
  else
    RTArgs.MappersArray =
        Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
}
9559
9560void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
9561 InsertPointTy CodeGenIP,
9562 MapInfosTy &CombinedInfo,
9563 TargetDataInfo &Info) {
9564 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
9565 CombinedInfo.NonContigInfo;
9566
9567 // Build an array of struct descriptor_dim and then assign it to
9568 // offload_args.
9569 //
9570 // struct descriptor_dim {
9571 // uint64_t offset;
9572 // uint64_t count;
9573 // uint64_t stride
9574 // };
9575 Type *Int64Ty = Builder.getInt64Ty();
9576 StructType *DimTy = StructType::create(
9577 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
9578 Name: "struct.descriptor_dim");
9579
9580 enum { OffsetFD = 0, CountFD, StrideFD };
9581 // We need two index variable here since the size of "Dims" is the same as
9582 // the size of Components, however, the size of offset, count, and stride is
9583 // equal to the size of base declaration that is non-contiguous.
9584 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
9585 // Skip emitting ir if dimension size is 1 since it cannot be
9586 // non-contiguous.
9587 if (NonContigInfo.Dims[I] == 1)
9588 continue;
9589 Builder.restoreIP(IP: AllocaIP);
9590 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
9591 AllocaInst *DimsAddr =
9592 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
9593 Builder.restoreIP(IP: CodeGenIP);
9594 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
9595 unsigned RevIdx = EE - II - 1;
9596 Value *DimsLVal = Builder.CreateInBoundsGEP(
9597 Ty: DimsAddr->getAllocatedType(), Ptr: DimsAddr,
9598 IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
9599 // Offset
9600 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
9601 Builder.CreateAlignedStore(
9602 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
9603 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
9604 // Count
9605 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
9606 Builder.CreateAlignedStore(
9607 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
9608 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9609 // Stride
9610 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
9611 Builder.CreateAlignedStore(
9612 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
9613 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
9614 }
9615 // args[I] = &dims
9616 Builder.restoreIP(IP: CodeGenIP);
9617 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
9618 V: DimsAddr, DestTy: Builder.getPtrTy());
9619 Value *P = Builder.CreateConstInBoundsGEP2_32(
9620 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
9621 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
9622 Builder.CreateAlignedStore(
9623 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
9624 ++L;
9625 }
9626}
9627
/// Emit the "init" (\p IsInit true) or "del" prologue/epilogue section of a
/// user-defined mapper function (see emitUserDefinedMapper).
///
/// Emits a guarded __tgt_push_mapper_component call that allocates
/// (respectively deletes) the whole array section at once:
///  - init: taken when (Size > 1 || Base != Begin) and the OMP_MAP_DELETE
///    bit of \p MapType is clear;
///  - del:  taken when Size > 1 and the OMP_MAP_DELETE bit is set.
/// On the taken path OMP_MAP_TO/OMP_MAP_FROM are cleared and
/// OMP_MAP_IMPLICIT is set so the runtime performs allocation/deletion only.
/// Control continues at \p ExitBB. \p Size is an element count (elements of
/// \p ElementSize bytes).
void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
    Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
    Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
    BasicBlock *ExitBB, bool IsInit) {
  StringRef Prefix = IsInit ? ".init" : ".del";

  // Evaluate if this is an array section.
  BasicBlock *BodyBB = BasicBlock::Create(
      Context&: M.getContext(), Name: createPlatformSpecificName(Parts: {"omp.array", Prefix}));
  // NOTE(review): the value name always says "arrayinit", even on the delete
  // path; this is cosmetic only.
  Value *IsArray =
      Builder.CreateICmpSGT(LHS: Size, RHS: Builder.getInt64(C: 1), Name: "omp.arrayinit.isarray");
  Value *DeleteBit = Builder.CreateAnd(
      LHS: MapType,
      RHS: Builder.getInt64(
          C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
  Value *DeleteCond;
  Value *Cond;
  if (IsInit) {
    // base != begin?
    Value *BaseIsBegin = Builder.CreateICmpNE(LHS: Base, RHS: Begin);
    Cond = Builder.CreateOr(LHS: IsArray, RHS: BaseIsBegin);
    DeleteCond = Builder.CreateIsNull(
        Arg: DeleteBit,
        Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
  } else {
    Cond = IsArray;
    DeleteCond = Builder.CreateIsNotNull(
        Arg: DeleteBit,
        Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
  }
  Cond = Builder.CreateAnd(LHS: Cond, RHS: DeleteCond);
  Builder.CreateCondBr(Cond, True: BodyBB, False: ExitBB);

  emitBlock(BB: BodyBB, CurFn: MapperFn);
  // Get the array size by multiplying element size and element number (i.e., \p
  // Size).
  Value *ArraySize = Builder.CreateNUWMul(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
  // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
  // memory allocation/deletion purpose only.
  Value *MapTypeArg = Builder.CreateAnd(
      LHS: MapType,
      RHS: Builder.getInt64(
          C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_TO |
              OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
  MapTypeArg = Builder.CreateOr(
      LHS: MapTypeArg,
      RHS: Builder.getInt64(
          C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
              OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));

  // Call the runtime API __tgt_push_mapper_component to fill up the runtime
  // data structure.
  Value *OffloadingArgs[] = {MapperHandle, Base, Begin,
                             ArraySize, MapTypeArg, MapName};
  createRuntimeFunctionCall(
      Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
      Args: OffloadingArgs);
}
9688
9689Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
9690 function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
9691 llvm::Value *BeginArg)>
9692 GenMapInfoCB,
9693 Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
9694 SmallVector<Type *> Params;
9695 Params.emplace_back(Args: Builder.getPtrTy());
9696 Params.emplace_back(Args: Builder.getPtrTy());
9697 Params.emplace_back(Args: Builder.getPtrTy());
9698 Params.emplace_back(Args: Builder.getInt64Ty());
9699 Params.emplace_back(Args: Builder.getInt64Ty());
9700 Params.emplace_back(Args: Builder.getPtrTy());
9701
9702 auto *FnTy =
9703 FunctionType::get(Result: Builder.getVoidTy(), Params, /* IsVarArg */ isVarArg: false);
9704
9705 SmallString<64> TyStr;
9706 raw_svector_ostream Out(TyStr);
9707 Function *MapperFn =
9708 Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
9709 MapperFn->addFnAttr(Kind: Attribute::NoInline);
9710 MapperFn->addFnAttr(Kind: Attribute::NoUnwind);
9711 MapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
9712 MapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
9713 MapperFn->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
9714 MapperFn->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
9715 MapperFn->addParamAttr(ArgNo: 4, Kind: Attribute::NoUndef);
9716 MapperFn->addParamAttr(ArgNo: 5, Kind: Attribute::NoUndef);
9717
9718 // Start the mapper function code generation.
9719 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: MapperFn);
9720 auto SavedIP = Builder.saveIP();
9721 Builder.SetInsertPoint(EntryBB);
9722
9723 Value *MapperHandle = MapperFn->getArg(i: 0);
9724 Value *BaseIn = MapperFn->getArg(i: 1);
9725 Value *BeginIn = MapperFn->getArg(i: 2);
9726 Value *Size = MapperFn->getArg(i: 3);
9727 Value *MapType = MapperFn->getArg(i: 4);
9728 Value *MapName = MapperFn->getArg(i: 5);
9729
9730 // Compute the starting and end addresses of array elements.
9731 // Prepare common arguments for array initiation and deletion.
9732 // Convert the size in bytes into the number of array elements.
9733 TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(Ty: ElemTy);
9734 Size = Builder.CreateExactUDiv(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
9735 Value *PtrBegin = BeginIn;
9736 Value *PtrEnd = Builder.CreateGEP(Ty: ElemTy, Ptr: PtrBegin, IdxList: Size);
9737
9738 // Emit array initiation if this is an array section and \p MapType indicates
9739 // that memory allocation is required.
9740 BasicBlock *HeadBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.head");
9741 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
9742 MapType, MapName, ElementSize, ExitBB: HeadBB,
9743 /*IsInit=*/true);
9744
9745 // Emit a for loop to iterate through SizeArg of elements and map all of them.
9746
9747 // Emit the loop header block.
9748 emitBlock(BB: HeadBB, CurFn: MapperFn);
9749 BasicBlock *BodyBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.body");
9750 BasicBlock *DoneBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.done");
9751 // Evaluate whether the initial condition is satisfied.
9752 Value *IsEmpty =
9753 Builder.CreateICmpEQ(LHS: PtrBegin, RHS: PtrEnd, Name: "omp.arraymap.isempty");
9754 Builder.CreateCondBr(Cond: IsEmpty, True: DoneBB, False: BodyBB);
9755
9756 // Emit the loop body block.
9757 emitBlock(BB: BodyBB, CurFn: MapperFn);
9758 BasicBlock *LastBB = BodyBB;
9759 PHINode *PtrPHI =
9760 Builder.CreatePHI(Ty: PtrBegin->getType(), NumReservedValues: 2, Name: "omp.arraymap.ptrcurrent");
9761 PtrPHI->addIncoming(V: PtrBegin, BB: HeadBB);
9762
9763 // Get map clause information. Fill up the arrays with all mapped variables.
9764 MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
9765 if (!Info)
9766 return Info.takeError();
9767
9768 // Call the runtime API __tgt_mapper_num_components to get the number of
9769 // pre-existing components.
9770 Value *OffloadingArgs[] = {MapperHandle};
9771 Value *PreviousSize = createRuntimeFunctionCall(
9772 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_mapper_num_components),
9773 Args: OffloadingArgs);
9774 Value *ShiftedPreviousSize =
9775 Builder.CreateShl(LHS: PreviousSize, RHS: Builder.getInt64(C: getFlagMemberOffset()));
9776
9777 // Fill up the runtime mapper handle for all components.
9778 for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
9779 Value *CurBaseArg = Info->BasePointers[I];
9780 Value *CurBeginArg = Info->Pointers[I];
9781 Value *CurSizeArg = Info->Sizes[I];
9782 Value *CurNameArg = Info->Names.size()
9783 ? Info->Names[I]
9784 : Constant::getNullValue(Ty: Builder.getPtrTy());
9785
9786 // Extract the MEMBER_OF field from the map type.
9787 Value *OriMapType = Builder.getInt64(
9788 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9789 Info->Types[I]));
9790 Value *MemberMapType =
9791 Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);
9792
9793 // Combine the map type inherited from user-defined mapper with that
9794 // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
9795 // bits of the \a MapType, which is the input argument of the mapper
9796 // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
9797 // bits of MemberMapType.
9798 // [OpenMP 5.0], 1.2.6. map-type decay.
9799 // | alloc | to | from | tofrom | release | delete
9800 // ----------------------------------------------------------
9801 // alloc | alloc | alloc | alloc | alloc | release | delete
9802 // to | alloc | to | alloc | to | release | delete
9803 // from | alloc | alloc | from | from | release | delete
9804 // tofrom | alloc | to | from | tofrom | release | delete
9805 Value *LeftToFrom = Builder.CreateAnd(
9806 LHS: MapType,
9807 RHS: Builder.getInt64(
9808 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9809 OpenMPOffloadMappingFlags::OMP_MAP_TO |
9810 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9811 BasicBlock *AllocBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc");
9812 BasicBlock *AllocElseBB =
9813 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc.else");
9814 BasicBlock *ToBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to");
9815 BasicBlock *ToElseBB =
9816 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to.else");
9817 BasicBlock *FromBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.from");
9818 BasicBlock *EndBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.end");
9819 Value *IsAlloc = Builder.CreateIsNull(Arg: LeftToFrom);
9820 Builder.CreateCondBr(Cond: IsAlloc, True: AllocBB, False: AllocElseBB);
9821 // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
9822 emitBlock(BB: AllocBB, CurFn: MapperFn);
9823 Value *AllocMapType = Builder.CreateAnd(
9824 LHS: MemberMapType,
9825 RHS: Builder.getInt64(
9826 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9827 OpenMPOffloadMappingFlags::OMP_MAP_TO |
9828 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9829 Builder.CreateBr(Dest: EndBB);
9830 emitBlock(BB: AllocElseBB, CurFn: MapperFn);
9831 Value *IsTo = Builder.CreateICmpEQ(
9832 LHS: LeftToFrom,
9833 RHS: Builder.getInt64(
9834 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9835 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
9836 Builder.CreateCondBr(Cond: IsTo, True: ToBB, False: ToElseBB);
9837 // In case of to, clear OMP_MAP_FROM.
9838 emitBlock(BB: ToBB, CurFn: MapperFn);
9839 Value *ToMapType = Builder.CreateAnd(
9840 LHS: MemberMapType,
9841 RHS: Builder.getInt64(
9842 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9843 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9844 Builder.CreateBr(Dest: EndBB);
9845 emitBlock(BB: ToElseBB, CurFn: MapperFn);
9846 Value *IsFrom = Builder.CreateICmpEQ(
9847 LHS: LeftToFrom,
9848 RHS: Builder.getInt64(
9849 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9850 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
9851 Builder.CreateCondBr(Cond: IsFrom, True: FromBB, False: EndBB);
9852 // In case of from, clear OMP_MAP_TO.
9853 emitBlock(BB: FromBB, CurFn: MapperFn);
9854 Value *FromMapType = Builder.CreateAnd(
9855 LHS: MemberMapType,
9856 RHS: Builder.getInt64(
9857 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
9858 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
9859 // In case of tofrom, do nothing.
9860 emitBlock(BB: EndBB, CurFn: MapperFn);
9861 LastBB = EndBB;
9862 PHINode *CurMapType =
9863 Builder.CreatePHI(Ty: Builder.getInt64Ty(), NumReservedValues: 4, Name: "omp.maptype");
9864 CurMapType->addIncoming(V: AllocMapType, BB: AllocBB);
9865 CurMapType->addIncoming(V: ToMapType, BB: ToBB);
9866 CurMapType->addIncoming(V: FromMapType, BB: FromBB);
9867 CurMapType->addIncoming(V: MemberMapType, BB: ToElseBB);
9868
9869 Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
9870 CurSizeArg, CurMapType, CurNameArg};
9871
9872 auto ChildMapperFn = CustomMapperCB(I);
9873 if (!ChildMapperFn)
9874 return ChildMapperFn.takeError();
9875 if (*ChildMapperFn) {
9876 // Call the corresponding mapper function.
9877 createRuntimeFunctionCall(Callee: *ChildMapperFn, Args: OffloadingArgs)
9878 ->setDoesNotThrow();
9879 } else {
9880 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
9881 // data structure.
9882 createRuntimeFunctionCall(
9883 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
9884 Args: OffloadingArgs);
9885 }
9886 }
9887
9888 // Update the pointer to point to the next element that needs to be mapped,
9889 // and check whether we have mapped all elements.
9890 Value *PtrNext = Builder.CreateConstGEP1_32(Ty: ElemTy, Ptr: PtrPHI, /*Idx0=*/1,
9891 Name: "omp.arraymap.next");
9892 PtrPHI->addIncoming(V: PtrNext, BB: LastBB);
9893 Value *IsDone = Builder.CreateICmpEQ(LHS: PtrNext, RHS: PtrEnd, Name: "omp.arraymap.isdone");
9894 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.exit");
9895 Builder.CreateCondBr(Cond: IsDone, True: ExitBB, False: BodyBB);
9896
9897 emitBlock(BB: ExitBB, CurFn: MapperFn);
9898 // Emit array deletion if this is an array section and \p MapType indicates
9899 // that deletion is required.
9900 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
9901 MapType, MapName, ElementSize, ExitBB: DoneBB,
9902 /*IsInit=*/false);
9903
9904 // Emit the function exit block.
9905 emitBlock(BB: DoneBB, CurFn: MapperFn, /*IsFinished=*/true);
9906
9907 Builder.CreateRetVoid();
9908 Builder.restoreIP(IP: SavedIP);
9909 return MapperFn;
9910}
9911
/// Allocate and populate the offloading argument arrays (base pointers,
/// pointers, sizes, map types, map names, mappers) described by
/// \p CombinedInfo, recording the results in \p Info.
///
/// Allocas are emitted at \p AllocaIP and the fill-in code at \p CodeGenIP.
/// Sizes that are compile-time constants are placed in a constant global;
/// only runtime-evaluated sizes are stored dynamically. \p CustomMapperCB is
/// queried per argument for a user-defined mapper function; \p DeviceAddrCB,
/// if provided, is invoked for arguments that require device-pointer
/// information. Returns any error produced by \p CustomMapperCB.
Error OpenMPIRBuilder::emitOffloadingArrays(
    InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
    TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
    bool IsNonContiguous,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB) {

  // Reset the array information.
  Info.clearArrayInfo();
  Info.NumberOfPtrs = CombinedInfo.BasePointers.size();

  if (Info.NumberOfPtrs == 0)
    return Error::success();

  Builder.restoreIP(IP: AllocaIP);
  // Detect if we have any capture size requiring runtime evaluation of the
  // size so that a constant array could be eventually used.
  ArrayType *PointerArrayType =
      ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);

  Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");

  Info.RTArgs.PointersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
  AllocaInst *MappersArray = Builder.CreateAlloca(
      Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
  Info.RTArgs.MappersArray = MappersArray;

  // If we don't have any VLA types or other types that require runtime
  // evaluation, we can use a constant array for the map sizes, otherwise we
  // need to fill up the arrays as we do for the pointers.
  Type *Int64Ty = Builder.getInt64Ty();
  SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
                                     ConstantInt::get(Ty: Int64Ty, V: 0));
  // RuntimeSizes marks the entries whose size must be stored at runtime;
  // all other entries get a compile-time constant in ConstSizes.
  SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
  for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
    if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
      if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
        // For non-contiguous entries the "size" slot carries the dimension
        // count instead of a byte size.
        if (IsNonContiguous &&
            static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                CombinedInfo.Types[I] &
                OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
          ConstSizes[I] =
              ConstantInt::get(Ty: Int64Ty, V: CombinedInfo.NonContigInfo.Dims[I]);
        else
          ConstSizes[I] = CI;
        continue;
      }
    }
    RuntimeSizes.set(I);
  }

  if (RuntimeSizes.all()) {
    // Every size is runtime-evaluated: a plain alloca is enough, the loop
    // below stores each value individually.
    ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
    Info.RTArgs.SizesArray = Builder.CreateAlloca(
        Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
    restoreIPandDebugLoc(Builder, IP: CodeGenIP);
  } else {
    // At least some sizes are constants: emit them as a constant global.
    auto *SizesArrayInit = ConstantArray::get(
        T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
    std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
    auto *SizesArrayGbl =
        new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
                           GlobalValue::PrivateLinkage, SizesArrayInit, Name);
    SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    if (!RuntimeSizes.any()) {
      // All sizes are constants: the global is used directly.
      Info.RTArgs.SizesArray = SizesArrayGbl;
    } else {
      // Mixed case: memcpy the constants into a stack buffer, then the loop
      // below overwrites the runtime-evaluated slots.
      unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
      Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
      ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
      AllocaInst *Buffer = Builder.CreateAlloca(
          Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
      Buffer->setAlignment(OffloadSizeAlign);
      restoreIPandDebugLoc(Builder, IP: CodeGenIP);
      Builder.CreateMemCpy(
          Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
          Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
          Size: Builder.getIntN(
              N: IndexSize,
              C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));

      Info.RTArgs.SizesArray = Buffer;
    }
    restoreIPandDebugLoc(Builder, IP: CodeGenIP);
  }

  // The map types are always constant so we don't need to generate code to
  // fill arrays. Instead, we create an array constant.
  SmallVector<uint64_t, 4> Mapping;
  for (auto mapFlag : CombinedInfo.Types)
    Mapping.push_back(
        Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            mapFlag));
  std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
  auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
  Info.RTArgs.MapTypesArray = MapTypesArrayGbl;

  // The information types are only built if provided.
  if (!CombinedInfo.Names.empty()) {
    auto *MapNamesArrayGbl = createOffloadMapnames(
        Names&: CombinedInfo.Names, VarName: createPlatformSpecificName(Parts: {"offload_mapnames"}));
    Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
    Info.EmitDebug = true;
  } else {
    Info.RTArgs.MapNamesArray =
        Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
    Info.EmitDebug = false;
  }

  // If there's a present map type modifier, it must not be applied to the end
  // of a region, so generate a separate map type array in that case.
  if (Info.separateBeginEndCalls()) {
    bool EndMapTypesDiffer = false;
    for (uint64_t &Type : Mapping) {
      if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                     OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
        Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
        EndMapTypesDiffer = true;
      }
    }
    if (EndMapTypesDiffer) {
      MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
      Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
    }
  }

  // Store the base pointer, pointer, (runtime) size and mapper for each
  // argument into its slot of the corresponding array.
  PointerType *PtrTy = Builder.getPtrTy();
  for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
    Value *BPVal = CombinedInfo.BasePointers[I];
    Value *BP = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
        Idx0: 0, Idx1: I);
    Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
                               Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));

    if (Info.requiresDevicePointerInfo()) {
      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(IP: AllocaIP);
        Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
        Builder.restoreIP(IP: CodeGenIP);
        if (DeviceAddrCB)
          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
        if (DeviceAddrCB)
          DeviceAddrCB(I, BP);
      }
    }

    Value *PVal = CombinedInfo.Pointers[I];
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
        Idx1: I);
    // TODO: Check alignment correct.
    Builder.CreateAlignedStore(Val: PVal, Ptr: P,
                               Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));

    if (RuntimeSizes.test(Idx: I)) {
      Value *S = Builder.CreateConstInBoundsGEP2_32(
          Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
          /*Idx0=*/0,
          /*Idx1=*/I);
      Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
                                                     DestTy: Int64Ty,
                                                     /*isSigned=*/true),
                                 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
    }
    // Fill up the mapper array.
    unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
    Value *MFunc = ConstantPointerNull::get(T: PtrTy);

    auto CustomMFunc = CustomMapperCB(I);
    if (!CustomMFunc)
      return CustomMFunc.takeError();
    if (*CustomMFunc)
      MFunc = Builder.CreatePointerCast(V: *CustomMFunc, DestTy: PtrTy);

    Value *MAddr = Builder.CreateInBoundsGEP(
        Ty: MappersArray->getAllocatedType(), Ptr: MappersArray,
        IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
    Builder.CreateAlignedStore(
        Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
  }

  if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
      Info.NumberOfPtrs == 0)
    return Error::success();
  emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
  return Error::success();
}
10106
10107void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
10108 BasicBlock *CurBB = Builder.GetInsertBlock();
10109
10110 if (!CurBB || CurBB->getTerminator()) {
10111 // If there is no insert point or the previous block is already
10112 // terminated, don't touch it.
10113 } else {
10114 // Otherwise, create a fall-through branch.
10115 Builder.CreateBr(Dest: Target);
10116 }
10117
10118 Builder.ClearInsertionPoint();
10119}
10120
10121void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
10122 bool IsFinished) {
10123 BasicBlock *CurBB = Builder.GetInsertBlock();
10124
10125 // Fall out of the current block (if necessary).
10126 emitBranch(Target: BB);
10127
10128 if (IsFinished && BB->use_empty()) {
10129 BB->eraseFromParent();
10130 return;
10131 }
10132
10133 // Place the block after the current block, if possible, or else at
10134 // the end of the function.
10135 if (CurBB && CurBB->getParent())
10136 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
10137 else
10138 CurFn->insert(Position: CurFn->end(), BB);
10139 Builder.SetInsertPoint(BB);
10140}
10141
/// Emit an if/else construct whose arms are generated by \p ThenGen and
/// \p ElseGen.
///
/// If \p Cond is a ConstantInt, only the live arm is emitted and no branch
/// is generated at all. Otherwise "omp_if.then"/"omp_if.else"/"omp_if.end"
/// blocks are created and both arms fall through into the continuation
/// block. Returns the first error produced by either callback.
Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
                                    BodyGenCallbackTy ElseGen,
                                    InsertPointTy AllocaIP) {
  // If the condition constant folds and can be elided, try to avoid emitting
  // the condition and the dead arm of the if/else.
  if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
    auto CondConstant = CI->getSExtValue();
    if (CondConstant)
      return ThenGen(AllocaIP, Builder.saveIP());

    return ElseGen(AllocaIP, Builder.saveIP());
  }

  Function *CurFn = Builder.GetInsertBlock()->getParent();

  // Otherwise, the condition did not fold, or we couldn't elide it.  Just
  // emit the conditional branch.
  BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
  BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
  BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
  Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
  // Emit the 'then' code.
  emitBlock(BB: ThenBlock, CurFn);
  if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
    return Err;
  emitBranch(Target: ContBlock);
  // Emit the 'else' code if present.
  // There is no need to emit line number for unconditional branch.
  emitBlock(BB: ElseBlock, CurFn);
  if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
    return Err;
  // There is no need to emit line number for unconditional branch.
  emitBranch(Target: ContBlock);
  // Emit the continuation block for code after the if.
  emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
  return Error::success();
}
10179
10180bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
10181 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
10182 assert(!(AO == AtomicOrdering::NotAtomic ||
10183 AO == llvm::AtomicOrdering::Unordered) &&
10184 "Unexpected Atomic Ordering.");
10185
10186 bool Flush = false;
10187 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
10188
10189 switch (AK) {
10190 case Read:
10191 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
10192 AO == AtomicOrdering::SequentiallyConsistent) {
10193 FlushAO = AtomicOrdering::Acquire;
10194 Flush = true;
10195 }
10196 break;
10197 case Write:
10198 case Compare:
10199 case Update:
10200 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
10201 AO == AtomicOrdering::SequentiallyConsistent) {
10202 FlushAO = AtomicOrdering::Release;
10203 Flush = true;
10204 }
10205 break;
10206 case Capture:
10207 switch (AO) {
10208 case AtomicOrdering::Acquire:
10209 FlushAO = AtomicOrdering::Acquire;
10210 Flush = true;
10211 break;
10212 case AtomicOrdering::Release:
10213 FlushAO = AtomicOrdering::Release;
10214 Flush = true;
10215 break;
10216 case AtomicOrdering::AcquireRelease:
10217 case AtomicOrdering::SequentiallyConsistent:
10218 FlushAO = AtomicOrdering::AcquireRelease;
10219 Flush = true;
10220 break;
10221 default:
10222 // do nothing - leave silently.
10223 break;
10224 }
10225 }
10226
10227 if (Flush) {
10228 // Currently Flush RT call still doesn't take memory_ordering, so for when
10229 // that happens, this tries to do the resolution of which atomic ordering
10230 // to use with but issue the flush call
10231 // TODO: pass `FlushAO` after memory ordering support is added
10232 (void)FlushAO;
10233 emitFlush(Loc);
10234 }
10235
10236 // for AO == AtomicOrdering::Monotonic and all other case combinations
10237 // do nothing
10238 return Flush;
10239}
10240
// Lower an OpenMP 'atomic read': atomically load *X.Var with ordering AO and
// store the loaded value into V.Var. Integers use a plain atomic load;
// structs go through the __atomic_load libcall machinery; float/pointer
// values are loaded as a same-width integer and cast back.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO, InsertPointTy AllocaIP) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    // Integer: a direct atomic load suffices.
    LoadInst *XLD =
        Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
    XLD->setAtomic(Ordering: AO);
    XRead = cast<Value>(Val: XLD);
  } else if (XElemTy->isStructTy()) {
    // FIXME: Add checks to ensure __atomic_load is emitted iff the
    // target does not support `atomicrmw` of the size of the struct
    // The plain load below exists only to obtain alignment/DataLayout info
    // for the libcall; it is erased once the libcall has been emitted.
    LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
    OldVal->setAtomic(Ordering: AO);
    const DataLayout &DL = OldVal->getModule()->getDataLayout();
    unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
    auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
    XRead = AtomicLoadRes.first;
    OldVal->eraseFromParent();
  } else {
    // We need to perform atomic op as integer
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    LoadInst *XLoad =
        Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
    XLoad->setAtomic(Ordering: AO);
    if (XElemTy->isFloatingPointTy()) {
      // Same-width reinterpretation back to the floating-point type.
      XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
    } else {
      // Pointer element type: convert the loaded integer back to a pointer.
      XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
    }
  }
  // A read with acquire (or stronger) semantics may require a flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
  Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
  return Builder.saveIP();
}
10292
// Lower an OpenMP 'atomic write': atomically store Expr to *X.Var with
// ordering AO. Integers store directly; structs use the __atomic_store
// libcall; float values are bitcast to a same-width integer first.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO, InsertPointTy AllocaIP) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    // Integer: a direct atomic store suffices.
    StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
    XSt->setAtomic(Ordering: AO);
  } else if (XElemTy->isStructTy()) {
    // The load below is a throwaway used only to obtain alignment and
    // DataLayout information for the libcall; it is erased afterwards.
    LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
    const DataLayout &DL = OldVal->getModule()->getDataLayout();
    unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
    OpenMPIRBuilder::AtomicInfo atomicInfo(
        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
        OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
    atomicInfo.EmitAtomicStoreLibcall(AO, Source: Expr);
    OldVal->eraseFromParent();
  } else {
    // We need to bitcast and perform atomic op as integers
    IntegerType *IntCastTy =
        IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
    Value *ExprCast =
        Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
    XSt->setAtomic(Ordering: AO);
  }

  // A write with release (or stronger) semantics may require a flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
  return Builder.saveIP();
}
10332
// Lower an OpenMP 'atomic update': atomically apply RMWOp/UpdateOp to *X.Var.
// The heavy lifting is delegated to emitAtomicUpdate; this wrapper validates
// the location, (in debug builds) sanity-checks the operands, and emits the
// trailing flush if the ordering requires one.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
  if (!updateToLocation(Loc))
    return Loc.IP;

  // NOTE: these asserts are wrapped in LLVM_DEBUG, so they are compiled out
  // (and thus never run) in builds without debug support.
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
      AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp, VolatileX: X.IsVolatile,
      IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
  if (!AtomicResult)
    return AtomicResult.takeError();
  // An update with release (or stronger) semantics may require a flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
  return Builder.saveIP();
}
10363
10364// FIXME: Duplicating AtomicExpand
10365Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
10366 AtomicRMWInst::BinOp RMWOp) {
10367 switch (RMWOp) {
10368 case AtomicRMWInst::Add:
10369 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
10370 case AtomicRMWInst::Sub:
10371 return Builder.CreateSub(LHS: Src1, RHS: Src2);
10372 case AtomicRMWInst::And:
10373 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
10374 case AtomicRMWInst::Nand:
10375 return Builder.CreateNeg(V: Builder.CreateAnd(LHS: Src1, RHS: Src2));
10376 case AtomicRMWInst::Or:
10377 return Builder.CreateOr(LHS: Src1, RHS: Src2);
10378 case AtomicRMWInst::Xor:
10379 return Builder.CreateXor(LHS: Src1, RHS: Src2);
10380 case AtomicRMWInst::Xchg:
10381 case AtomicRMWInst::FAdd:
10382 case AtomicRMWInst::FSub:
10383 case AtomicRMWInst::BAD_BINOP:
10384 case AtomicRMWInst::Max:
10385 case AtomicRMWInst::Min:
10386 case AtomicRMWInst::UMax:
10387 case AtomicRMWInst::UMin:
10388 case AtomicRMWInst::FMax:
10389 case AtomicRMWInst::FMin:
10390 case AtomicRMWInst::FMaximum:
10391 case AtomicRMWInst::FMinimum:
10392 case AtomicRMWInst::UIncWrap:
10393 case AtomicRMWInst::UDecWrap:
10394 case AtomicRMWInst::USubCond:
10395 case AtomicRMWInst::USubSat:
10396 llvm_unreachable("Unsupported atomic update operation");
10397 }
10398 llvm_unreachable("Unsupported atomic update operation");
10399}
10400
10401Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
10402 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
10403 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
10404 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
10405 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
10406 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
10407 // or a complex datatype.
10408 bool emitRMWOp = false;
10409 switch (RMWOp) {
10410 case AtomicRMWInst::Add:
10411 case AtomicRMWInst::And:
10412 case AtomicRMWInst::Nand:
10413 case AtomicRMWInst::Or:
10414 case AtomicRMWInst::Xor:
10415 case AtomicRMWInst::Xchg:
10416 emitRMWOp = XElemTy;
10417 break;
10418 case AtomicRMWInst::Sub:
10419 emitRMWOp = (IsXBinopExpr && XElemTy);
10420 break;
10421 default:
10422 emitRMWOp = false;
10423 }
10424 emitRMWOp &= XElemTy->isIntegerTy();
10425
10426 std::pair<Value *, Value *> Res;
10427 if (emitRMWOp) {
10428 AtomicRMWInst *RMWInst =
10429 Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
10430 if (T.isAMDGPU()) {
10431 if (IsIgnoreDenormalMode)
10432 RMWInst->setMetadata(Kind: "amdgpu.ignore.denormal.mode",
10433 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10434 if (!IsFineGrainedMemory)
10435 RMWInst->setMetadata(Kind: "amdgpu.no.fine.grained.memory",
10436 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10437 if (!IsRemoteMemory)
10438 RMWInst->setMetadata(Kind: "amdgpu.no.remote.memory",
10439 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
10440 }
10441 Res.first = RMWInst;
10442 // not needed except in case of postfix captures. Generate anyway for
10443 // consistency with the else part. Will be removed with any DCE pass.
10444 // AtomicRMWInst::Xchg does not have a coressponding instruction.
10445 if (RMWOp == AtomicRMWInst::Xchg)
10446 Res.second = Res.first;
10447 else
10448 Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
10449 } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
10450 XElemTy->isStructTy()) {
10451 LoadInst *OldVal =
10452 Builder.CreateLoad(Ty: XElemTy, Ptr: X, Name: X->getName() + ".atomic.load");
10453 OldVal->setAtomic(Ordering: AO);
10454 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
10455 unsigned LoadSize =
10456 LoadDL.getTypeStoreSize(Ty: OldVal->getPointerOperand()->getType());
10457
10458 OpenMPIRBuilder::AtomicInfo atomicInfo(
10459 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10460 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
10461 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
10462 BasicBlock *CurBB = Builder.GetInsertBlock();
10463 Instruction *CurBBTI = CurBB->getTerminator();
10464 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10465 BasicBlock *ExitBB =
10466 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
10467 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
10468 BBName: X->getName() + ".atomic.cont");
10469 ContBB->getTerminator()->eraseFromParent();
10470 Builder.restoreIP(IP: AllocaIP);
10471 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
10472 NewAtomicAddr->setName(X->getName() + "x.new.val");
10473 Builder.SetInsertPoint(ContBB);
10474 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
10475 PHI->addIncoming(V: AtomicLoadRes.first, BB: CurBB);
10476 Value *OldExprVal = PHI;
10477 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
10478 if (!CBResult)
10479 return CBResult.takeError();
10480 Value *Upd = *CBResult;
10481 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
10482 AtomicOrdering Failure =
10483 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10484 auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
10485 ExpectedVal: AtomicLoadRes.second, DesiredVal: NewAtomicAddr, Success: AO, Failure);
10486 LoadInst *PHILoad = Builder.CreateLoad(Ty: XElemTy, Ptr: Result.first);
10487 PHI->addIncoming(V: PHILoad, BB: Builder.GetInsertBlock());
10488 Builder.CreateCondBr(Cond: Result.second, True: ExitBB, False: ContBB);
10489 OldVal->eraseFromParent();
10490 Res.first = OldExprVal;
10491 Res.second = Upd;
10492
10493 if (UnreachableInst *ExitTI =
10494 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
10495 CurBBTI->eraseFromParent();
10496 Builder.SetInsertPoint(ExitBB);
10497 } else {
10498 Builder.SetInsertPoint(ExitTI);
10499 }
10500 } else {
10501 IntegerType *IntCastTy =
10502 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10503 LoadInst *OldVal =
10504 Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
10505 OldVal->setAtomic(Ordering: AO);
10506 // CurBB
10507 // | /---\
10508 // ContBB |
10509 // | \---/
10510 // ExitBB
10511 BasicBlock *CurBB = Builder.GetInsertBlock();
10512 Instruction *CurBBTI = CurBB->getTerminator();
10513 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
10514 BasicBlock *ExitBB =
10515 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
10516 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
10517 BBName: X->getName() + ".atomic.cont");
10518 ContBB->getTerminator()->eraseFromParent();
10519 Builder.restoreIP(IP: AllocaIP);
10520 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
10521 NewAtomicAddr->setName(X->getName() + "x.new.val");
10522 Builder.SetInsertPoint(ContBB);
10523 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
10524 PHI->addIncoming(V: OldVal, BB: CurBB);
10525 bool IsIntTy = XElemTy->isIntegerTy();
10526 Value *OldExprVal = PHI;
10527 if (!IsIntTy) {
10528 if (XElemTy->isFloatingPointTy()) {
10529 OldExprVal = Builder.CreateBitCast(V: PHI, DestTy: XElemTy,
10530 Name: X->getName() + ".atomic.fltCast");
10531 } else {
10532 OldExprVal = Builder.CreateIntToPtr(V: PHI, DestTy: XElemTy,
10533 Name: X->getName() + ".atomic.ptrCast");
10534 }
10535 }
10536
10537 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
10538 if (!CBResult)
10539 return CBResult.takeError();
10540 Value *Upd = *CBResult;
10541 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
10542 LoadInst *DesiredVal = Builder.CreateLoad(Ty: IntCastTy, Ptr: NewAtomicAddr);
10543 AtomicOrdering Failure =
10544 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10545 AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
10546 Ptr: X, Cmp: PHI, New: DesiredVal, Align: llvm::MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
10547 Result->setVolatile(VolatileX);
10548 Value *PreviousVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
10549 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
10550 PHI->addIncoming(V: PreviousVal, BB: Builder.GetInsertBlock());
10551 Builder.CreateCondBr(Cond: SuccessFailureVal, True: ExitBB, False: ContBB);
10552
10553 Res.first = OldExprVal;
10554 Res.second = Upd;
10555
10556 // set Insertion point in exit block
10557 if (UnreachableInst *ExitTI =
10558 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
10559 CurBBTI->eraseFromParent();
10560 Builder.SetInsertPoint(ExitBB);
10561 } else {
10562 Builder.SetInsertPoint(ExitTI);
10563 }
10564 }
10565
10566 return Res;
10567}
10568
// Lower an OpenMP 'atomic capture': atomically update *X.Var and store either
// the old value (postfix, 'v = x; x op= expr') or the new value (prefix,
// 'x op= expr; v = x') into V.Var.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr,
    bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // NOTE: these asserts are wrapped in LLVM_DEBUG, so they are compiled out
  // (and thus never run) in builds without debug support.
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
  // 'x' is simply atomically rewritten with 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
  Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
      AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp: AtomicOp, UpdateOp, VolatileX: X.IsVolatile,
      IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
  if (!AtomicResult)
    return AtomicResult.takeError();
  // emitAtomicUpdate returns {old value, updated value}; pick per capture
  // flavor.
  Value *CapturedVal =
      (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
  Builder.CreateStore(Val: CapturedVal, Ptr: V.Var, isVolatile: V.IsVolatile);

  // A capture may need acquire, release, or both depending on the ordering.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Capture);
  return Builder.saveIP();
}
10605
10606OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
10607 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
10608 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
10609 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
10610 bool IsFailOnly) {
10611
10612 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
10613 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
10614 IsPostfixUpdate, IsFailOnly, Failure);
10615}
10616
// Lower an OpenMP 'atomic compare': either a compare-exchange ('==' form,
// lowered to cmpxchg) or a min/max conditional update, with optional capture
// of the old/new value into V.Var and of the success flag into R.Var.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly, AtomicOrdering Failure) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // compare capture
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    // '==' form: a single cmpxchg. Non-integer operands are reinterpreted as
    // same-width integers because cmpxchg requires integer/pointer operands.
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      IntegerType *IntCastTy =
          IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
      Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
      Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: EBCast, New: DBCast, Align: MaybeAlign(),
                                           SuccessOrdering: AO, FailureOrdering: Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: E, New: D, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
    }

    if (V.Var) {
      // Extract the value *X held before the cmpxchg.
      Value *OldValue = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(V: OldValue, DestTy: X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        // Postfix capture always stores the old value.
        Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // CurBB----
          //   |     |
          //   v     |
          // ContBB  |
          //   |     |
          //   v     |
          // ExitBB <-
          //
          // where ContBB only contains the store of old value to 'v'.
          // Split the block so the store to 'v' only executes when the
          // exchange failed.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminator();
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(Val: OldValue, Ptr: V.Var);
          Builder.CreateBr(Dest: ExitBB);

          if (UnreachableInst *ExitTI =
                  dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            // NOTE(review): on this path the dyn_cast above failed, so
            // ExitTI is null here and SetInsertPoint would dereference it.
            // Verify whether this branch is reachable / intended.
            Builder.SetInsertPoint(ExitTI);
          }
        } else {
          // Prefix capture without fail-only: 'v' receives D's comparand E on
          // success, else the old value.
          Value *CapturedValue =
              Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
          Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
        }
      }
    }
    // The comparison result has to be stored.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
      // Widen the i1 success flag to R's integer type, honoring signedness.
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(V: SuccessFailureVal, DestTy: R.ElemTy)
                              : Builder.CreateZExt(V: SuccessFailureVal, DestTy: R.ElemTy)
;
      Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // Reverse the ordop as the OpenMP forms are different from LLVM forms.
    // Let's take max as example.
    // OpenMP form:
    // x = x > expr ? expr : x;
    // LLVM form:
    // *ptr = *ptr > val ? *ptr : val;
    // We need to transform to LLVM form.
    // x = x <= expr ? x : expr;
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    // Min/max map directly onto atomicrmw; it returns the old value.
    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(Op: NewOp, Ptr: X.Var, Val: E, Align: MaybeAlign(), Ordering: AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        CapturedValue = OldValue;
      } else {
        // Prefix capture: recompute the post-update value by redoing the
        // comparison non-atomically on the returned old value.
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, LHS: OldValue, RHS: E);
        CapturedValue = Builder.CreateSelect(C: NonAtomicCmp, True: E, False: OldValue);
      }
      Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
    }
  }

  // A compare with release (or stronger) semantics may require a flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Compare);

  return Builder.saveIP();
}
10793
// Lower an OpenMP 'teams' construct: generate the region body between fresh
// alloca/body/exit blocks, record it for outlining, and (on the host) wire
// the outlined function to __kmpc_fork_teams, pushing num_teams/thread_limit
// bounds first when any of those clauses is present.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
                             Value *NumTeamsUpper, Value *ThreadLimit,
                             Value *IfExpr) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();

  // Outer allocation basicblock is the entry block of the current function.
  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
    // Step out of the entry block so outer allocas stay separate from the
    // teams region.
    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.entry");
    Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
  }

  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %teams.exit
  //   teams.exit:
  //     ; instructions after teams
  // }
  //
  // def outlined_fn() {
  //   teams.alloca:
  //     br label %teams.body
  //   teams.body:
  //     ; instructions within teams body
  // }
  // ```
  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.exit");
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "teams.alloca");

  bool SubClausesPresent =
      (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
  // Push num_teams
  if (!Config.isTargetDevice() && SubClausesPresent) {
    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
           "if lowerbound is non-null, then upperbound must also be non-null "
           "for bounds on num_teams");

    // 0 tells the runtime to pick the bound itself.
    if (NumTeamsUpper == nullptr)
      NumTeamsUpper = Builder.getInt32(C: 0);

    if (NumTeamsLower == nullptr)
      NumTeamsLower = NumTeamsUpper;

    if (IfExpr) {
      assert(IfExpr->getType()->isIntegerTy() &&
             "argument to if clause must be an integer value");

      // upper = ifexpr ? upper : 1
      if (IfExpr->getType() != Int1)
        IfExpr = Builder.CreateICmpNE(LHS: IfExpr,
                                      RHS: ConstantInt::get(Ty: IfExpr->getType(), V: 0));
      NumTeamsUpper = Builder.CreateSelect(
          C: IfExpr, True: NumTeamsUpper, False: Builder.getInt32(C: 1), Name: "numTeamsUpper");

      // lower = ifexpr ? lower : 1
      NumTeamsLower = Builder.CreateSelect(
          C: IfExpr, True: NumTeamsLower, False: Builder.getInt32(C: 1), Name: "numTeamsLower");
    }

    // 0 means no thread_limit was requested.
    if (ThreadLimit == nullptr)
      ThreadLimit = Builder.getInt32(C: 0);

    // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
    // truncate or sign extend the passed values to match the int32 parameters.
    Value *NumTeamsLowerInt32 =
        Builder.CreateSExtOrTrunc(V: NumTeamsLower, DestTy: Builder.getInt32Ty());
    Value *NumTeamsUpperInt32 =
        Builder.CreateSExtOrTrunc(V: NumTeamsUpper, DestTy: Builder.getInt32Ty());
    Value *ThreadLimitInt32 =
        Builder.CreateSExtOrTrunc(V: ThreadLimit, DestTy: Builder.getInt32Ty());

    Value *ThreadNum = getOrCreateThreadID(Ident);

    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_teams_51),
        Args: {Ident, ThreadNum, NumTeamsLowerInt32, NumTeamsUpperInt32,
         ThreadLimitInt32});
  }
  // Generate the body of teams.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
    return Err;

  OutlineInfo OI;
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;
  OI.OuterAllocaBB = &OuterAllocaBB;

  // Insert fake values for global tid and bound tid.
  SmallVector<Instruction *, 8> ToBeDeleted;
  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "gid", AsPtr: true));
  OI.ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "tid", AsPtr: true));

  // Runs after outlining: replaces the placeholder call to the outlined
  // function with the __kmpc_fork_teams runtime call.
  auto HostPostOutlineCB = [this, Ident,
                            ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call instruction
    // for runtime call with the outlined function.

    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
    ToBeDeleted.push_back(Elt: StaleCI);

    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
           "Outlined function must have two or three arguments only");

    bool HasShared = OutlinedFn.arg_size() == 3;

    OutlinedFn.getArg(i: 0)->setName("global.tid.ptr");
    OutlinedFn.getArg(i: 1)->setName("bound.tid.ptr");
    if (HasShared)
      OutlinedFn.getArg(i: 2)->setName("data");

    // Call to the runtime function for teams in the current function.
    assert(StaleCI && "Error while outlining - no CallInst user found for the "
                      "outlined function.");
    Builder.SetInsertPoint(StaleCI);
    // arg_size() - 2 skips the fake gid/tid arguments excluded above.
    SmallVector<Value *> Args = {
        Ident, Builder.getInt32(C: StaleCI->arg_size() - 2), &OutlinedFn};
    if (HasShared)
      Args.push_back(Elt: StaleCI->getArgOperand(i: 2));
    createRuntimeFunctionCall(
        Callee: getOrCreateRuntimeFunctionPtr(
            FnID: omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
        Args);

    // Delete in reverse creation order so uses go before defs.
    for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
      I->eraseFromParent();
  };

  if (!Config.isTargetDevice())
    OI.PostOutlineCB = HostPostOutlineCB;

  addOutlineInfo(OI: std::move(OI));

  Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());

  return Builder.saveIP();
}
10950
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
                                  InsertPointTy OuterAllocaIP,
                                  BodyGenCallbackTy BodyGenCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();

  // If the builder currently points into the outer alloca block, split off a
  // fresh block so the distribute region does not share it with the allocas.
  if (OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB =
        splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.entry");
    Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
  }
  // Carve the region into alloca -> body -> exit blocks. Each split happens
  // at the current insertion point, so the blocks are created back-to-front.
  BasicBlock *ExitBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.exit");
  BasicBlock *BodyBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.alloca");

  // Generate the body of distribute clause
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
    return Err;

  // When using target we use different runtime functions which require a
  // callback. Record the region for later outlining; the host path needs no
  // outlining here.
  if (Config.isTargetDevice()) {
    OutlineInfo OI;
    OI.OuterAllocaBB = OuterAllocaIP.getBlock();
    OI.EntryBB = AllocaBB;
    OI.ExitBB = ExitBB;

    addOutlineInfo(OI: std::move(OI));
  }
  // Continue emitting code after the distribute region.
  Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());

  return Builder.saveIP();
}
10992
10993GlobalVariable *
10994OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
10995 std::string VarName) {
10996 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
10997 T: llvm::ArrayType::get(ElementType: llvm::PointerType::getUnqual(C&: M.getContext()),
10998 NumElements: Names.size()),
10999 V: Names);
11000 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
11001 M, MapNamesArrayInit->getType(),
11002 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
11003 VarName);
11004 return MapNamesArrayGlobal;
11005}
11006
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
  // Pointers to runtime data live in the configured default target address
  // space; pointers to runtime functions live in the program address space.
  unsigned DefaultTargetAS = Config.getDefaultTargetAS();
  unsigned ProgramAS = M.getDataLayout().getProgramAddressSpace();
// Plain typedef-like types are assigned directly.
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
// Array types also get an opaque pointer type in the default target AS.
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::get(Ctx, DefaultTargetAS);
// Function types get a pointer type in the program AS.
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::get(Ctx, ProgramAS);
// Struct types are looked up by name first so that a type already created by
// e.g. a linked-in runtime module is reused instead of duplicated.
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::get(Ctx, DefaultTargetAS);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
11029
11030void OpenMPIRBuilder::OutlineInfo::collectBlocks(
11031 SmallPtrSetImpl<BasicBlock *> &BlockSet,
11032 SmallVectorImpl<BasicBlock *> &BlockVector) {
11033 SmallVector<BasicBlock *, 32> Worklist;
11034 BlockSet.insert(Ptr: EntryBB);
11035 BlockSet.insert(Ptr: ExitBB);
11036
11037 Worklist.push_back(Elt: EntryBB);
11038 while (!Worklist.empty()) {
11039 BasicBlock *BB = Worklist.pop_back_val();
11040 BlockVector.push_back(Elt: BB);
11041 for (BasicBlock *SuccBB : successors(BB))
11042 if (BlockSet.insert(Ptr: SuccBB).second)
11043 Worklist.push_back(Elt: SuccBB);
11044 }
11045}
11046
11047void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
11048 uint64_t Size, int32_t Flags,
11049 GlobalValue::LinkageTypes,
11050 StringRef Name) {
11051 if (!Config.isGPU()) {
11052 llvm::offloading::emitOffloadingEntry(
11053 M, Kind: object::OffloadKind::OFK_OpenMP, Addr: ID,
11054 Name: Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
11055 return;
11056 }
11057 // TODO: Add support for global variables on the device after declare target
11058 // support.
11059 Function *Fn = dyn_cast<Function>(Val: Addr);
11060 if (!Fn)
11061 return;
11062
11063 // Add a function attribute for the kernel.
11064 Fn->addFnAttr(Kind: "kernel");
11065 if (T.isAMDGCN())
11066 Fn->addFnAttr(Kind: "uniform-work-group-size", Val: "true");
11067 Fn->addFnAttr(Kind: Attribute::MustProgress);
11068}
11069
// We only generate metadata for functions that contain target regions.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  // Entries indexed by creation order; filled in by the two emitter callbacks
  // below so the final emission loop is deterministic.
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(C: ConstantInt::get(Ty: Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(Context&: C, Str: V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        // identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y: EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
      };

  OffloadInfoManager.actOnTargetRegionEntriesInfo(Action: TargetRegionMetadataEmitter);

  // Create function that emits metadata for each device global variable entry;
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        // Global variables have no source location; use a zeroed EntryInfo.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y&: varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
      };

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      Action: DeviceGlobalVarMetadataEmitter);

  // Walk the entries in creation order and emit the actual offload entries,
  // reporting via ErrorFn the ones that lack a required address or ID.
  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                Val: E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(Name: FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(ID: CE->getID(), Addr: CE->getAddress(),
                         /*Size=*/0, Flags: CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   Val: E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        // Under unified shared memory no device copy is registered.
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declaret target link address is set.");
        // Link entries are only emitted on the host side.
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable:
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_INDIRECT_ERROR, E.second);
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      if (auto *GV = dyn_cast<GlobalValue>(Val: CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            (Flags !=
                 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect &&
             Flags != OffloadEntriesInfoManager::
                          OMPTargetGlobalVarEntryIndirectVTable))
          continue;

      // Indirect globals need to use a special name that doesn't match the name
      // of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
          Flags ==
              OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
        createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
                           Flags, CE->getLinkage(), Name: CE->getVarName());
      else
        createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
                           Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }

  // Emit requires directive globals to a special entry so the runtime can
  // register them when the device image is loaded.
  // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
  // entries should be redesigned to better suit this use-case.
  if (Config.hasRequiresFlags() && !Config.isTargetDevice())
    offloading::emitOffloadingEntry(
        M, Kind: object::OffloadKind::OFK_OpenMP,
        Addr: Constant::getNullValue(Ty: PointerType::getUnqual(C&: M.getContext())),
        Name: ".requires", /*Size=*/0,
        Flags: OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
        Data: Config.getRequiresFlags());
}
11246
11247void TargetRegionEntryInfo::getTargetRegionEntryFnName(
11248 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
11249 unsigned FileID, unsigned Line, unsigned Count) {
11250 raw_svector_ostream OS(Name);
11251 OS << KernelNamePrefix << llvm::format(Fmt: "%x", Vals: DeviceID)
11252 << llvm::format(Fmt: "_%x_", Vals: FileID) << ParentName << "_l" << Line;
11253 if (Count)
11254 OS << "_" << Count;
11255}
11256
11257void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
11258 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
11259 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
11260 TargetRegionEntryInfo::getTargetRegionEntryFnName(
11261 Name, ParentName: EntryInfo.ParentName, DeviceID: EntryInfo.DeviceID, FileID: EntryInfo.FileID,
11262 Line: EntryInfo.Line, Count: NewCount);
11263}
11264
TargetRegionEntryInfo
OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
                                          vfs::FileSystem &VFS,
                                          StringRef ParentName) {
  // Fallback device ID used when the file cannot be stat'ed.
  sys::fs::UniqueID ID(0xdeadf17e, 0);
  // The callback yields a (file name, line number) tuple.
  auto FileIDInfo = CallBack();
  uint64_t FileID = 0;
  if (ErrorOr<vfs::Status> Status = VFS.status(Path: std::get<0>(t&: FileIDInfo))) {
    ID = Status->getUniqueID();
    FileID = Status->getUniqueID().getFile();
  } else {
    // If the inode ID could not be determined, create a hash value of the
    // current file name and use that as an ID.
    FileID = hash_value(arg: std::get<0>(t&: FileIDInfo));
  }

  return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
                               std::get<1>(t&: FileIDInfo));
}
11284
11285unsigned OpenMPIRBuilder::getFlagMemberOffset() {
11286 unsigned Offset = 0;
11287 for (uint64_t Remain =
11288 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
11289 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
11290 !(Remain & 1); Remain = Remain >> 1)
11291 Offset++;
11292 return Offset;
11293}
11294
11295omp::OpenMPOffloadMappingFlags
11296OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
11297 // Rotate by getFlagMemberOffset() bits.
11298 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
11299 << getFlagMemberOffset());
11300}
11301
void OpenMPIRBuilder::setCorrectMemberOfFlag(
    omp::OpenMPOffloadMappingFlags &Flags,
    omp::OpenMPOffloadMappingFlags MemberOfFlag) {
  // If the entry is PTR_AND_OBJ but has not been marked with the special
  // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
  // marked as MEMBER_OF.
  // Note: the second operand casts the *result of the comparison* between the
  // masked MEMBER_OF field and the all-ones MEMBER_OF placeholder.
  if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
      static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
          omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
    return;

  // Entries with ATTACH are not members-of anything. They are handled
  // separately by the runtime after other maps have been handled.
  if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH))
    return;

  // Reset the placeholder value to prepare the flag for the assignment of the
  // proper MEMBER_OF value.
  Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
  Flags |= MemberOfFlag;
}
11326
/// Returns (creating if necessary) the "_decl_tgt_ref_ptr" indirection
/// pointer for a declare-target variable, or nullptr when no indirection is
/// required (SIMD-only mode, or a to/enter clause without unified shared
/// memory).
Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
  // TODO: convert this to utilise the IRBuilder Config rather than
  // a passed down argument.
  if (OpenMPSIMD)
    return nullptr;

  // A ref pointer is needed for "link" clauses and, under unified shared
  // memory, also for "to"/"enter" clauses.
  if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
      ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
        CaptureClause ==
            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
       Config.hasRequiresUnifiedSharedMemory())) {
    // Build the ref-pointer name; non-visible symbols get the file ID mixed
    // in to avoid collisions across translation units.
    SmallString<64> PtrName;
    {
      raw_svector_ostream OS(PtrName);
      OS << MangledName;
      if (!IsExternallyVisible)
        OS << format(Fmt: "_%x", Vals: EntryInfo.FileID);
      OS << "_decl_tgt_ref_ptr";
    }

    Value *Ptr = M.getNamedValue(Name: PtrName);

    // Create the ref pointer on first use and register it with the offload
    // entry machinery.
    if (!Ptr) {
      GlobalValue *GlobalValue = M.getNamedValue(Name: MangledName);
      Ptr = getOrCreateInternalVariable(Ty: LlvmPtrTy, Name: PtrName);

      auto *GV = cast<GlobalVariable>(Val: Ptr);
      GV->setLinkage(GlobalValue::WeakAnyLinkage);

      // Only the host initializes the ref pointer (to the original global or
      // via the caller-provided initializer).
      if (!Config.isTargetDevice()) {
        if (GlobalInitializer)
          GV->setInitializer(GlobalInitializer());
        else
          GV->setInitializer(GlobalValue);
      }

      registerTargetGlobalVariable(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          GlobalInitializer, VariableLinkage, LlvmPtrTy, Addr: cast<Constant>(Val: Ptr));
    }

    return cast<Constant>(Val: Ptr);
  }

  return nullptr;
}
11382
/// Registers a declare-target global variable with the offload entries info
/// manager, computing the entry kind, size and linkage according to the
/// capture clause and the unified-shared-memory configuration.
void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
    Constant *Addr) {
  // Only device_type(any) entries are registered, and only when there is a
  // target triple or we are compiling for the device.
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  // Case 1: to/enter without unified shared memory - register the variable
  // itself directly.
  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    GlobalValue *LlvmVal = M.getNamedValue(Name: VarName);

    // Declarations have no storage; report size 0 for them.
    if (!IsDeclaration)
      VarSize = divideCeil(
          Numerator: M.getDataLayout().getTypeSizeInBits(Ty: LlvmVal->getValueType()), Denominator: 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName(Parts: {VarName, "ref"});

      // Create the constant internal "<name>_ref" holding the address once.
      if (!M.getNamedValue(Name: RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Ty: Addr->getType(), Name: RefName);
        auto *GvAddrRef = cast<GlobalVariable>(Val: AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(x: GvAddrRef);
      }
    }
  } else {
    // Case 2: link clause, or to/enter under unified shared memory - register
    // the indirection ("ref") pointer instead of the variable itself.
    if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    if (Config.isTargetDevice()) {
      // On the device the entry carries only the name, no address.
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      // On the host, (re)create the ref pointer and register its address.
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    // The entry describes a pointer-sized indirection slot.
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}
11461
/// Loads all the offload entries information from the host IR
/// metadata.
void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
  // If we are in target mode, load the metadata from the host IR. This code has
  // to match the metadata creation in createOffloadEntriesAndInfoMetadata().

  NamedMDNode *MD = M.getNamedMetadata(Name: ompOffloadInfoName);
  if (!MD)
    return;

  for (MDNode *MN : MD->operands()) {
    // Helper: read operand Idx of the current node as an integer constant.
    auto &&GetMDInt = [MN](unsigned Idx) {
      auto *V = cast<ConstantAsMetadata>(Val: MN->getOperand(I: Idx));
      return cast<ConstantInt>(Val: V->getValue())->getZExtValue();
    };

    // Helper: read operand Idx of the current node as a string.
    auto &&GetMDString = [MN](unsigned Idx) {
      auto *V = cast<MDString>(Val: MN->getOperand(I: Idx));
      return V->getString();
    };

    // Operand 0 is the entry kind; the remaining operand layout matches the
    // per-kind encoding documented in createOffloadEntriesAndInfoMetadata().
    switch (GetMDInt(0)) {
    default:
      llvm_unreachable("Unexpected metadata!");
      break;
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoTargetRegion: {
      TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
                                      /*DeviceID=*/GetMDInt(1),
                                      /*FileID=*/GetMDInt(2),
                                      /*Line=*/GetMDInt(4),
                                      /*Count=*/GetMDInt(5));
      OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
                                                         /*Order=*/GetMDInt(6));
      break;
    }
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoDeviceGlobalVar:
      OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
          /*MangledName=*/Name: GetMDString(1),
          Flags: static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              /*Flags=*/GetMDInt(2)),
          /*Order=*/GetMDInt(3));
      break;
    }
  }
}
11509
/// Loads the offload entries metadata from the host bitcode file at
/// \p HostFilePath (no-op if the path is empty). Fatal-errors if the file
/// cannot be read or parsed.
void OpenMPIRBuilder::loadOffloadInfoMetadata(vfs::FileSystem &VFS,
                                              StringRef HostFilePath) {
  if (HostFilePath.empty())
    return;

  auto Buf = VFS.getBufferForFile(Name: HostFilePath);
  if (std::error_code Err = Buf.getError()) {
    report_fatal_error(reason: ("error opening host file from host file path inside of "
                        "OpenMPIRBuilder: " +
                        Err.message())
                           .c_str());
  }

  // Parse the host module into a throwaway context; only its offload
  // metadata is consumed.
  LLVMContext Ctx;
  auto M = expectedToErrorOrAndEmitErrors(
      Ctx, Val: parseBitcodeFile(Buffer: Buf.get()->getMemBufferRef(), Context&: Ctx));
  if (std::error_code Err = M.getError()) {
    report_fatal_error(
        reason: ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
            .c_str());
  }

  loadOffloadInfoMetadata(M&: *M.get());
}
11534
11535//===----------------------------------------------------------------------===//
11536// OffloadEntriesInfoManager
11537//===----------------------------------------------------------------------===//
11538
11539bool OffloadEntriesInfoManager::empty() const {
11540 return OffloadEntriesTargetRegion.empty() &&
11541 OffloadEntriesDeviceGlobalVar.empty();
11542}
11543
11544unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
11545 const TargetRegionEntryInfo &EntryInfo) const {
11546 auto It = OffloadEntriesTargetRegionCount.find(
11547 x: getTargetRegionEntryCountKey(EntryInfo));
11548 if (It == OffloadEntriesTargetRegionCount.end())
11549 return 0;
11550 return It->second;
11551}
11552
11553void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
11554 const TargetRegionEntryInfo &EntryInfo) {
11555 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
11556 EntryInfo.Count + 1;
11557}
11558
11559/// Initialize target region entry.
11560void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
11561 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
11562 OffloadEntriesTargetRegion[EntryInfo] =
11563 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
11564 OMPTargetRegionEntryTargetRegion);
11565 ++OffloadingEntriesNum;
11566}
11567
/// Registers a target region entry, filling in its address, ID and flags.
/// On the device the entry must already exist (initialized from host
/// metadata); on the host a fresh entry is created.
void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
    OMPTargetRegionEntryKind Flags) {
  assert(EntryInfo.Count == 0 && "expected default EntryInfo");

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  // If we are emitting code for a target, the entry is already initialized,
  // only has to be registered.
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasTargetRegionEntryInfo(EntryInfo)) {
      return;
    }
    auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
    Entry.setAddress(Addr);
    Entry.setID(ID);
    Entry.setFlags(Flags);
  } else {
    // Avoid double registration of plain target regions for the same
    // location; other kinds must not be registered twice.
    if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
        hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
      return;
    assert(!hasTargetRegionEntryInfo(EntryInfo) &&
           "Target region entry already registered!");
    OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
    OffloadEntriesTargetRegion[EntryInfo] = Entry;
    ++OffloadingEntriesNum;
  }
  // Bump the per-location counter so the next region on this line gets a
  // distinct Count.
  incrementTargetRegionEntryInfoCount(EntryInfo);
}
11599
11600bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
11601 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
11602
11603 // Update the EntryInfo with the next available count for this location.
11604 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
11605
11606 auto It = OffloadEntriesTargetRegion.find(x: EntryInfo);
11607 if (It == OffloadEntriesTargetRegion.end()) {
11608 return false;
11609 }
11610 // Fail if this entry is already registered.
11611 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
11612 return false;
11613 return true;
11614}
11615
11616void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
11617 const OffloadTargetRegionEntryInfoActTy &Action) {
11618 // Scan all target region entries and perform the provided action.
11619 for (const auto &It : OffloadEntriesTargetRegion) {
11620 Action(It.first, It.second);
11621 }
11622}
11623
11624void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
11625 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
11626 OffloadEntriesDeviceGlobalVar.try_emplace(Key: Name, Args&: Order, Args&: Flags);
11627 ++OffloadingEntriesNum;
11628}
11629
11630void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
11631 StringRef VarName, Constant *Addr, int64_t VarSize,
11632 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
11633 if (OMPBuilder->Config.isTargetDevice()) {
11634 // This could happen if the device compilation is invoked standalone.
11635 if (!hasDeviceGlobalVarEntryInfo(VarName))
11636 return;
11637 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11638 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
11639 if (Entry.getVarSize() == 0) {
11640 Entry.setVarSize(VarSize);
11641 Entry.setLinkage(Linkage);
11642 }
11643 return;
11644 }
11645 Entry.setVarSize(VarSize);
11646 Entry.setLinkage(Linkage);
11647 Entry.setAddress(Addr);
11648 } else {
11649 if (hasDeviceGlobalVarEntryInfo(VarName)) {
11650 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
11651 assert(Entry.isValid() && Entry.getFlags() == Flags &&
11652 "Entry not initialized!");
11653 if (Entry.getVarSize() == 0) {
11654 Entry.setVarSize(VarSize);
11655 Entry.setLinkage(Linkage);
11656 }
11657 return;
11658 }
11659 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
11660 Flags ==
11661 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
11662 OffloadEntriesDeviceGlobalVar.try_emplace(Key: VarName, Args&: OffloadingEntriesNum,
11663 Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage,
11664 Args: VarName.str());
11665 else
11666 OffloadEntriesDeviceGlobalVar.try_emplace(
11667 Key: VarName, Args&: OffloadingEntriesNum, Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage, Args: "");
11668 ++OffloadingEntriesNum;
11669 }
11670}
11671
11672void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
11673 const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
11674 // Scan all target region entries and perform the provided action.
11675 for (const auto &E : OffloadEntriesDeviceGlobalVar)
11676 Action(E.getKey(), E.getValue());
11677}
11678
11679//===----------------------------------------------------------------------===//
11680// CanonicalLoopInfo
11681//===----------------------------------------------------------------------===//
11682
11683void CanonicalLoopInfo::collectControlBlocks(
11684 SmallVectorImpl<BasicBlock *> &BBs) {
11685 // We only count those BBs as control block for which we do not need to
11686 // reverse the CFG, i.e. not the loop body which can contain arbitrary control
11687 // flow. For consistency, this also means we do not add the Body block, which
11688 // is just the entry to the body code.
11689 BBs.reserve(N: BBs.size() + 6);
11690 BBs.append(IL: {getPreheader(), Header, Cond, Latch, Exit, getAfter()});
11691}
11692
11693BasicBlock *CanonicalLoopInfo::getPreheader() const {
11694 assert(isValid() && "Requires a valid canonical loop");
11695 for (BasicBlock *Pred : predecessors(BB: Header)) {
11696 if (Pred != Latch)
11697 return Pred;
11698 }
11699 llvm_unreachable("Missing preheader");
11700}
11701
11702void CanonicalLoopInfo::setTripCount(Value *TripCount) {
11703 assert(isValid() && "Requires a valid canonical loop");
11704
11705 Instruction *CmpI = &getCond()->front();
11706 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
11707 CmpI->setOperand(i: 1, Val: TripCount);
11708
11709#ifndef NDEBUG
11710 assertOK();
11711#endif
11712}
11713
11714void CanonicalLoopInfo::mapIndVar(
11715 llvm::function_ref<Value *(Instruction *)> Updater) {
11716 assert(isValid() && "Requires a valid canonical loop");
11717
11718 Instruction *OldIV = getIndVar();
11719
11720 // Record all uses excluding those introduced by the updater. Uses by the
11721 // CanonicalLoopInfo itself to keep track of the number of iterations are
11722 // excluded.
11723 SmallVector<Use *> ReplacableUses;
11724 for (Use &U : OldIV->uses()) {
11725 auto *User = dyn_cast<Instruction>(Val: U.getUser());
11726 if (!User)
11727 continue;
11728 if (User->getParent() == getCond())
11729 continue;
11730 if (User->getParent() == getLatch())
11731 continue;
11732 ReplacableUses.push_back(Elt: &U);
11733 }
11734
11735 // Run the updater that may introduce new uses
11736 Value *NewIV = Updater(OldIV);
11737
11738 // Replace the old uses with the value returned by the updater.
11739 for (Use *U : ReplacableUses)
11740 U->set(NewIV);
11741
11742#ifndef NDEBUG
11743 assertOK();
11744#endif
11745}
11746
11747void CanonicalLoopInfo::assertOK() const {
11748#ifndef NDEBUG
11749 // No constraints if this object currently does not describe a loop.
11750 if (!isValid())
11751 return;
11752
11753 BasicBlock *Preheader = getPreheader();
11754 BasicBlock *Body = getBody();
11755 BasicBlock *After = getAfter();
11756
11757 // Verify standard control-flow we use for OpenMP loops.
11758 assert(Preheader);
11759 assert(isa<BranchInst>(Preheader->getTerminator()) &&
11760 "Preheader must terminate with unconditional branch");
11761 assert(Preheader->getSingleSuccessor() == Header &&
11762 "Preheader must jump to header");
11763
11764 assert(Header);
11765 assert(isa<BranchInst>(Header->getTerminator()) &&
11766 "Header must terminate with unconditional branch");
11767 assert(Header->getSingleSuccessor() == Cond &&
11768 "Header must jump to exiting block");
11769
11770 assert(Cond);
11771 assert(Cond->getSinglePredecessor() == Header &&
11772 "Exiting block only reachable from header");
11773
11774 assert(isa<BranchInst>(Cond->getTerminator()) &&
11775 "Exiting block must terminate with conditional branch");
11776 assert(size(successors(Cond)) == 2 &&
11777 "Exiting block must have two successors");
11778 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
11779 "Exiting block's first successor jump to the body");
11780 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
11781 "Exiting block's second successor must exit the loop");
11782
11783 assert(Body);
11784 assert(Body->getSinglePredecessor() == Cond &&
11785 "Body only reachable from exiting block");
11786 assert(!isa<PHINode>(Body->front()));
11787
11788 assert(Latch);
11789 assert(isa<BranchInst>(Latch->getTerminator()) &&
11790 "Latch must terminate with unconditional branch");
11791 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
11792 // TODO: To support simple redirecting of the end of the body code that has
11793 // multiple; introduce another auxiliary basic block like preheader and after.
11794 assert(Latch->getSinglePredecessor() != nullptr);
11795 assert(!isa<PHINode>(Latch->front()));
11796
11797 assert(Exit);
11798 assert(isa<BranchInst>(Exit->getTerminator()) &&
11799 "Exit block must terminate with unconditional branch");
11800 assert(Exit->getSingleSuccessor() == After &&
11801 "Exit block must jump to after block");
11802
11803 assert(After);
11804 assert(After->getSinglePredecessor() == Exit &&
11805 "After block only reachable from exit block");
11806 assert(After->empty() || !isa<PHINode>(After->front()));
11807
11808 Instruction *IndVar = getIndVar();
11809 assert(IndVar && "Canonical induction variable not found?");
11810 assert(isa<IntegerType>(IndVar->getType()) &&
11811 "Induction variable must be an integer");
11812 assert(cast<PHINode>(IndVar)->getParent() == Header &&
11813 "Induction variable must be a PHI in the loop header");
11814 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
11815 assert(
11816 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
11817 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
11818
11819 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
11820 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
11821 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
11822 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
11823 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
11824 ->isOne());
11825
11826 Value *TripCount = getTripCount();
11827 assert(TripCount && "Loop trip count not found?");
11828 assert(IndVar->getType() == TripCount->getType() &&
11829 "Trip count and induction variable must have the same type");
11830
11831 auto *CmpI = cast<CmpInst>(&Cond->front());
11832 assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
11833 "Exit condition must be a signed less-than comparison");
11834 assert(CmpI->getOperand(0) == IndVar &&
11835 "Exit condition must compare the induction variable");
11836 assert(CmpI->getOperand(1) == TripCount &&
11837 "Exit condition must compare with the trip count");
11838#endif
11839}
11840
11841void CanonicalLoopInfo::invalidate() {
11842 Header = nullptr;
11843 Cond = nullptr;
11844 Latch = nullptr;
11845 Exit = nullptr;
11846}
11847