1//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9///
10/// This file implements the OpenMPIRBuilder class, which is used as a
11/// convenient way to create LLVM instructions for OpenMP directives.
12///
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/SmallSet.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/StringRef.h"
21#include "llvm/Analysis/AssumptionCache.h"
22#include "llvm/Analysis/CodeMetrics.h"
23#include "llvm/Analysis/LoopInfo.h"
24#include "llvm/Analysis/OptimizationRemarkEmitter.h"
25#include "llvm/Analysis/PostDominators.h"
26#include "llvm/Analysis/ScalarEvolution.h"
27#include "llvm/Analysis/TargetLibraryInfo.h"
28#include "llvm/Bitcode/BitcodeReader.h"
29#include "llvm/Frontend/Offloading/Utility.h"
30#include "llvm/Frontend/OpenMP/OMPGridValues.h"
31#include "llvm/IR/Attributes.h"
32#include "llvm/IR/BasicBlock.h"
33#include "llvm/IR/CFG.h"
34#include "llvm/IR/CallingConv.h"
35#include "llvm/IR/Constant.h"
36#include "llvm/IR/Constants.h"
37#include "llvm/IR/DIBuilder.h"
38#include "llvm/IR/DebugInfoMetadata.h"
39#include "llvm/IR/DerivedTypes.h"
40#include "llvm/IR/Function.h"
41#include "llvm/IR/GlobalVariable.h"
42#include "llvm/IR/IRBuilder.h"
43#include "llvm/IR/InstIterator.h"
44#include "llvm/IR/IntrinsicInst.h"
45#include "llvm/IR/LLVMContext.h"
46#include "llvm/IR/MDBuilder.h"
47#include "llvm/IR/Metadata.h"
48#include "llvm/IR/PassInstrumentation.h"
49#include "llvm/IR/PassManager.h"
50#include "llvm/IR/ReplaceConstant.h"
51#include "llvm/IR/Value.h"
52#include "llvm/MC/TargetRegistry.h"
53#include "llvm/Support/CommandLine.h"
54#include "llvm/Support/Error.h"
55#include "llvm/Support/ErrorHandling.h"
56#include "llvm/Support/FileSystem.h"
57#include "llvm/Support/NVVMAttributes.h"
58#include "llvm/Support/VirtualFileSystem.h"
59#include "llvm/Target/TargetMachine.h"
60#include "llvm/Target/TargetOptions.h"
61#include "llvm/Transforms/Utils/BasicBlockUtils.h"
62#include "llvm/Transforms/Utils/Cloning.h"
63#include "llvm/Transforms/Utils/CodeExtractor.h"
64#include "llvm/Transforms/Utils/LoopPeel.h"
65#include "llvm/Transforms/Utils/UnrollLoop.h"
66
67#include <cstdint>
68#include <optional>
69
70#define DEBUG_TYPE "openmp-ir-builder"
71
72using namespace llvm;
73using namespace omp;
74
75static cl::opt<bool>
76 OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
77 cl::desc("Use optimistic attributes describing "
78 "'as-if' properties of runtime calls."),
79 cl::init(Val: false));
80
81static cl::opt<double> UnrollThresholdFactor(
82 "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
83 cl::desc("Factor for the unroll threshold to account for code "
84 "simplifications still taking place"),
85 cl::init(Val: 1.5));
86
87static cl::opt<bool> UseDefaultMaxThreads(
88 "openmp-ir-builder-use-default-max-threads", cl::Hidden,
89 cl::desc("Use a default max threads if none is provided."), cl::init(Val: true));
90
91#ifndef NDEBUG
92/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
93/// at position IP1 may change the meaning of IP2 or vice-versa. This is because
94/// an InsertPoint stores the instruction before something is inserted. For
95/// instance, if both point to the same instruction, two IRBuilders alternating
96/// creating instruction will cause the instructions to be interleaved.
97static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
98 IRBuilder<>::InsertPoint IP2) {
99 if (!IP1.isSet() || !IP2.isSet())
100 return false;
101 return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
102}
103
104static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
105 // Valid ordered/unordered and base algorithm combinations.
106 switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
107 case OMPScheduleType::UnorderedStaticChunked:
108 case OMPScheduleType::UnorderedStatic:
109 case OMPScheduleType::UnorderedDynamicChunked:
110 case OMPScheduleType::UnorderedGuidedChunked:
111 case OMPScheduleType::UnorderedRuntime:
112 case OMPScheduleType::UnorderedAuto:
113 case OMPScheduleType::UnorderedTrapezoidal:
114 case OMPScheduleType::UnorderedGreedy:
115 case OMPScheduleType::UnorderedBalanced:
116 case OMPScheduleType::UnorderedGuidedIterativeChunked:
117 case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
118 case OMPScheduleType::UnorderedSteal:
119 case OMPScheduleType::UnorderedStaticBalancedChunked:
120 case OMPScheduleType::UnorderedGuidedSimd:
121 case OMPScheduleType::UnorderedRuntimeSimd:
122 case OMPScheduleType::OrderedStaticChunked:
123 case OMPScheduleType::OrderedStatic:
124 case OMPScheduleType::OrderedDynamicChunked:
125 case OMPScheduleType::OrderedGuidedChunked:
126 case OMPScheduleType::OrderedRuntime:
127 case OMPScheduleType::OrderedAuto:
128 case OMPScheduleType::OrderdTrapezoidal:
129 case OMPScheduleType::NomergeUnorderedStaticChunked:
130 case OMPScheduleType::NomergeUnorderedStatic:
131 case OMPScheduleType::NomergeUnorderedDynamicChunked:
132 case OMPScheduleType::NomergeUnorderedGuidedChunked:
133 case OMPScheduleType::NomergeUnorderedRuntime:
134 case OMPScheduleType::NomergeUnorderedAuto:
135 case OMPScheduleType::NomergeUnorderedTrapezoidal:
136 case OMPScheduleType::NomergeUnorderedGreedy:
137 case OMPScheduleType::NomergeUnorderedBalanced:
138 case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
139 case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
140 case OMPScheduleType::NomergeUnorderedSteal:
141 case OMPScheduleType::NomergeOrderedStaticChunked:
142 case OMPScheduleType::NomergeOrderedStatic:
143 case OMPScheduleType::NomergeOrderedDynamicChunked:
144 case OMPScheduleType::NomergeOrderedGuidedChunked:
145 case OMPScheduleType::NomergeOrderedRuntime:
146 case OMPScheduleType::NomergeOrderedAuto:
147 case OMPScheduleType::NomergeOrderedTrapezoidal:
148 case OMPScheduleType::OrderedDistributeChunked:
149 case OMPScheduleType::OrderedDistribute:
150 break;
151 default:
152 return false;
153 }
154
155 // Must not set both monotonicity modifiers at the same time.
156 OMPScheduleType MonotonicityFlags =
157 SchedType & OMPScheduleType::MonotonicityMask;
158 if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
159 return false;
160
161 return true;
162}
163#endif
164
165/// This is wrapper over IRBuilderBase::restoreIP that also restores the current
166/// debug location to the last instruction in the specified basic block if the
167/// insert point points to the end of the block.
168static void restoreIPandDebugLoc(llvm::IRBuilderBase &Builder,
169 llvm::IRBuilderBase::InsertPoint IP) {
170 Builder.restoreIP(IP);
171 llvm::BasicBlock *BB = Builder.GetInsertBlock();
172 llvm::BasicBlock::iterator I = Builder.GetInsertPoint();
173 if (!BB->empty() && I == BB->end())
174 Builder.SetCurrentDebugLocation(BB->back().getStableDebugLoc());
175}
176
177static bool hasGridValue(const Triple &T) {
178 return T.isAMDGPU() || T.isNVPTX() || T.isSPIRV();
179}
180
181static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
182 if (T.isAMDGPU()) {
183 StringRef Features =
184 Kernel->getFnAttribute(Kind: "target-features").getValueAsString();
185 if (Features.count(Str: "+wavefrontsize64"))
186 return omp::getAMDGPUGridValues<64>();
187 return omp::getAMDGPUGridValues<32>();
188 }
189 if (T.isNVPTX())
190 return omp::NVPTXGridValues;
191 if (T.isSPIRV())
192 return omp::SPIRVGridValues;
193 llvm_unreachable("No grid value available for this architecture!");
194}
195
196/// Determine which scheduling algorithm to use, determined from schedule clause
197/// arguments.
198static OMPScheduleType
199getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
200 bool HasSimdModifier, bool HasDistScheduleChunks) {
201 // Currently, the default schedule it static.
202 switch (ClauseKind) {
203 case OMP_SCHEDULE_Default:
204 case OMP_SCHEDULE_Static:
205 return HasChunks ? OMPScheduleType::BaseStaticChunked
206 : OMPScheduleType::BaseStatic;
207 case OMP_SCHEDULE_Dynamic:
208 return OMPScheduleType::BaseDynamicChunked;
209 case OMP_SCHEDULE_Guided:
210 return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
211 : OMPScheduleType::BaseGuidedChunked;
212 case OMP_SCHEDULE_Auto:
213 return llvm::omp::OMPScheduleType::BaseAuto;
214 case OMP_SCHEDULE_Runtime:
215 return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
216 : OMPScheduleType::BaseRuntime;
217 case OMP_SCHEDULE_Distribute:
218 return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
219 : OMPScheduleType::BaseDistribute;
220 }
221 llvm_unreachable("unhandled schedule clause argument");
222}
223
224/// Adds ordering modifier flags to schedule type.
225static OMPScheduleType
226getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
227 bool HasOrderedClause) {
228 assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
229 OMPScheduleType::None &&
230 "Must not have ordering nor monotonicity flags already set");
231
232 OMPScheduleType OrderingModifier = HasOrderedClause
233 ? OMPScheduleType::ModifierOrdered
234 : OMPScheduleType::ModifierUnordered;
235 OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
236
237 // Unsupported combinations
238 if (OrderingScheduleType ==
239 (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
240 return OMPScheduleType::OrderedGuidedChunked;
241 else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
242 OMPScheduleType::ModifierOrdered))
243 return OMPScheduleType::OrderedRuntime;
244
245 return OrderingScheduleType;
246}
247
248/// Adds monotonicity modifier flags to schedule type.
249static OMPScheduleType
250getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
251 bool HasSimdModifier, bool HasMonotonic,
252 bool HasNonmonotonic, bool HasOrderedClause) {
253 assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
254 OMPScheduleType::None &&
255 "Must not have monotonicity flags already set");
256 assert((!HasMonotonic || !HasNonmonotonic) &&
257 "Monotonic and Nonmonotonic are contradicting each other");
258
259 if (HasMonotonic) {
260 return ScheduleType | OMPScheduleType::ModifierMonotonic;
261 } else if (HasNonmonotonic) {
262 return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
263 } else {
264 // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
265 // If the static schedule kind is specified or if the ordered clause is
266 // specified, and if the nonmonotonic modifier is not specified, the
267 // effect is as if the monotonic modifier is specified. Otherwise, unless
268 // the monotonic modifier is specified, the effect is as if the
269 // nonmonotonic modifier is specified.
270 OMPScheduleType BaseScheduleType =
271 ScheduleType & ~OMPScheduleType::ModifierMask;
272 if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
273 (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
274 HasOrderedClause) {
275 // The monotonic is used by default in openmp runtime library, so no need
276 // to set it.
277 return ScheduleType;
278 } else {
279 return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
280 }
281 }
282}
283
284/// Determine the schedule type using schedule and ordering clause arguments.
285static OMPScheduleType
286computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
287 bool HasSimdModifier, bool HasMonotonicModifier,
288 bool HasNonmonotonicModifier, bool HasOrderedClause,
289 bool HasDistScheduleChunks) {
290 OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
291 ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
292 OMPScheduleType OrderedSchedule =
293 getOpenMPOrderingScheduleType(BaseScheduleType: BaseSchedule, HasOrderedClause);
294 OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
295 ScheduleType: OrderedSchedule, HasSimdModifier, HasMonotonic: HasMonotonicModifier,
296 HasNonmonotonic: HasNonmonotonicModifier, HasOrderedClause);
297
298 assert(isValidWorkshareLoopScheduleType(Result));
299 return Result;
300}
301
302/// Given a function, if it represents the entry point of a target kernel, this
303/// returns the execution mode flags associated with that kernel.
304static std::optional<omp::OMPTgtExecModeFlags>
305getTargetKernelExecMode(Function &Kernel) {
306 CallInst *TargetInitCall = nullptr;
307 for (Instruction &Inst : Kernel.getEntryBlock()) {
308 if (auto *Call = dyn_cast<CallInst>(Val: &Inst)) {
309 if (Call->getCalledFunction()->getName() == "__kmpc_target_init") {
310 TargetInitCall = Call;
311 break;
312 }
313 }
314 }
315
316 if (!TargetInitCall)
317 return std::nullopt;
318
319 // Get the kernel mode information from the global variable associated to the
320 // first argument to the call to __kmpc_target_init. Refer to
321 // createTargetInit() to see how this is initialized.
322 Value *InitOperand = TargetInitCall->getArgOperand(i: 0);
323 GlobalVariable *KernelEnv = nullptr;
324 if (auto *Cast = dyn_cast<ConstantExpr>(Val: InitOperand))
325 KernelEnv = cast<GlobalVariable>(Val: Cast->getOperand(i_nocapture: 0));
326 else
327 KernelEnv = cast<GlobalVariable>(Val: InitOperand);
328 auto *KernelEnvInit = cast<ConstantStruct>(Val: KernelEnv->getInitializer());
329 auto *ConfigEnv = cast<ConstantStruct>(Val: KernelEnvInit->getOperand(i_nocapture: 0));
330 auto *KernelMode = cast<ConstantInt>(Val: ConfigEnv->getOperand(i_nocapture: 2));
331 return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
332}
333
334static bool isGenericKernel(Function &Fn) {
335 std::optional<omp::OMPTgtExecModeFlags> ExecMode =
336 getTargetKernelExecMode(Kernel&: Fn);
337 return !ExecMode || (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC);
338}
339
340/// Make \p Source branch to \p Target.
341///
342/// Handles two situations:
343/// * \p Source already has an unconditional branch.
344/// * \p Source is a degenerate block (no terminator because the BB is
345/// the current head of the IR construction).
346static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
347 if (Instruction *Term = Source->getTerminatorOrNull()) {
348 auto *Br = cast<UncondBrInst>(Val: Term);
349 BasicBlock *Succ = Br->getSuccessor();
350 Succ->removePredecessor(Pred: Source, /*KeepOneInputPHIs=*/true);
351 Br->setSuccessor(Target);
352 return;
353 }
354
355 auto *NewBr = UncondBrInst::Create(Target, InsertBefore: Source);
356 NewBr->setDebugLoc(DL);
357}
358
359void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
360 bool CreateBranch, DebugLoc DL) {
361 assert(New->getFirstInsertionPt() == New->begin() &&
362 "Target BB must not have PHI nodes");
363
364 // Move instructions to new block.
365 BasicBlock *Old = IP.getBlock();
366 // If the `Old` block is empty then there are no instructions to move. But in
367 // the new debug scheme, it could have trailing debug records which will be
368 // moved to `New` in `spliceDebugInfoEmptyBlock`. We dont want that for 2
369 // reasons:
370 // 1. If `New` is also empty, `BasicBlock::splice` crashes.
371 // 2. Even if `New` is not empty, the rationale to move those records to `New`
372 // (in `spliceDebugInfoEmptyBlock`) does not apply here. That function
373 // assumes that `Old` is optimized out and is going away. This is not the case
374 // here. The `Old` block is still being used e.g. a branch instruction is
375 // added to it later in this function.
376 // So we call `BasicBlock::splice` only when `Old` is not empty.
377 if (!Old->empty())
378 New->splice(ToIt: New->begin(), FromBB: Old, FromBeginIt: IP.getPoint(), FromEndIt: Old->end());
379
380 if (CreateBranch) {
381 auto *NewBr = UncondBrInst::Create(Target: New, InsertBefore: Old);
382 NewBr->setDebugLoc(DL);
383 }
384}
385
386void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
387 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
388 BasicBlock *Old = Builder.GetInsertBlock();
389
390 spliceBB(IP: Builder.saveIP(), New, CreateBranch, DL: DebugLoc);
391 if (CreateBranch)
392 Builder.SetInsertPoint(Old->getTerminator());
393 else
394 Builder.SetInsertPoint(Old);
395
396 // SetInsertPoint also updates the Builder's debug location, but we want to
397 // keep the one the Builder was configured to use.
398 Builder.SetCurrentDebugLocation(DebugLoc);
399}
400
401BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
402 DebugLoc DL, llvm::Twine Name) {
403 BasicBlock *Old = IP.getBlock();
404 BasicBlock *New = BasicBlock::Create(
405 Context&: Old->getContext(), Name: Name.isTriviallyEmpty() ? Old->getName() : Name,
406 Parent: Old->getParent(), InsertBefore: Old->getNextNode());
407 spliceBB(IP, New, CreateBranch, DL);
408 New->replaceSuccessorsPhiUsesWith(Old, New);
409 return New;
410}
411
412BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
413 llvm::Twine Name) {
414 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
415 BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
416 if (CreateBranch)
417 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
418 else
419 Builder.SetInsertPoint(Builder.GetInsertBlock());
420 // SetInsertPoint also updates the Builder's debug location, but we want to
421 // keep the one the Builder was configured to use.
422 Builder.SetCurrentDebugLocation(DebugLoc);
423 return New;
424}
425
426BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
427 llvm::Twine Name) {
428 DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
429 BasicBlock *New = splitBB(IP: Builder.saveIP(), CreateBranch, DL: DebugLoc, Name);
430 if (CreateBranch)
431 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
432 else
433 Builder.SetInsertPoint(Builder.GetInsertBlock());
434 // SetInsertPoint also updates the Builder's debug location, but we want to
435 // keep the one the Builder was configured to use.
436 Builder.SetCurrentDebugLocation(DebugLoc);
437 return New;
438}
439
440BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
441 llvm::Twine Suffix) {
442 BasicBlock *Old = Builder.GetInsertBlock();
443 return splitBB(Builder, CreateBranch, Name: Old->getName() + Suffix);
444}
445
446// This function creates a fake integer value and a fake use for the integer
447// value. It returns the fake value created. This is useful in modeling the
448// extra arguments to the outlined functions.
449Value *createFakeIntVal(IRBuilderBase &Builder,
450 OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
451 llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
452 OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
453 const Twine &Name = "", bool AsPtr = true,
454 bool Is64Bit = false) {
455 Builder.restoreIP(IP: OuterAllocaIP);
456 IntegerType *IntTy = Is64Bit ? Builder.getInt64Ty() : Builder.getInt32Ty();
457 Instruction *FakeVal;
458 AllocaInst *FakeValAddr =
459 Builder.CreateAlloca(Ty: IntTy, ArraySize: nullptr, Name: Name + ".addr");
460 ToBeDeleted.push_back(Elt: FakeValAddr);
461
462 if (AsPtr) {
463 FakeVal = FakeValAddr;
464 } else {
465 FakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeValAddr, Name: Name + ".val");
466 ToBeDeleted.push_back(Elt: FakeVal);
467 }
468
469 // Generate a fake use of this value
470 Builder.restoreIP(IP: InnerAllocaIP);
471 Instruction *UseFakeVal;
472 if (AsPtr) {
473 UseFakeVal = Builder.CreateLoad(Ty: IntTy, Ptr: FakeVal, Name: Name + ".use");
474 } else {
475 UseFakeVal = cast<BinaryOperator>(Val: Builder.CreateAdd(
476 LHS: FakeVal, RHS: Is64Bit ? Builder.getInt64(C: 10) : Builder.getInt32(C: 10)));
477 }
478 ToBeDeleted.push_back(Elt: UseFakeVal);
479 return FakeVal;
480}
481
482//===----------------------------------------------------------------------===//
483// OpenMPIRBuilderConfig
484//===----------------------------------------------------------------------===//
485
486namespace {
487LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
488/// Values for bit flags for marking which requires clauses have been used.
489enum OpenMPOffloadingRequiresDirFlags {
490 /// flag undefined.
491 OMP_REQ_UNDEFINED = 0x000,
492 /// no requires directive present.
493 OMP_REQ_NONE = 0x001,
494 /// reverse_offload clause.
495 OMP_REQ_REVERSE_OFFLOAD = 0x002,
496 /// unified_address clause.
497 OMP_REQ_UNIFIED_ADDRESS = 0x004,
498 /// unified_shared_memory clause.
499 OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
500 /// dynamic_allocators clause.
501 OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
502 LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
503};
504
505class OMPCodeExtractor : public CodeExtractor {
506public:
507 OMPCodeExtractor(OpenMPIRBuilder &OMPBuilder, ArrayRef<BasicBlock *> BBs,
508 DominatorTree *DT = nullptr, bool AggregateArgs = false,
509 BlockFrequencyInfo *BFI = nullptr,
510 BranchProbabilityInfo *BPI = nullptr,
511 AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
512 bool AllowAlloca = false,
513 BasicBlock *AllocationBlock = nullptr,
514 ArrayRef<BasicBlock *> DeallocationBlocks = {},
515 std::string Suffix = "", bool ArgsInZeroAddressSpace = false)
516 : CodeExtractor(BBs, DT, AggregateArgs, BFI, BPI, AC, AllowVarArgs,
517 AllowAlloca, AllocationBlock, DeallocationBlocks, Suffix,
518 ArgsInZeroAddressSpace),
519 OMPBuilder(OMPBuilder) {}
520
521 virtual ~OMPCodeExtractor() = default;
522
523protected:
524 OpenMPIRBuilder &OMPBuilder;
525};
526
527class DeviceSharedMemCodeExtractor : public OMPCodeExtractor {
528public:
529 using OMPCodeExtractor::OMPCodeExtractor;
530 virtual ~DeviceSharedMemCodeExtractor() = default;
531
532protected:
533 virtual Instruction *
534 allocateVar(IRBuilder<>::InsertPoint AllocaIP, Type *VarType,
535 const Twine &Name = Twine(""),
536 AddrSpaceCastInst **CastedAlloc = nullptr) override {
537 return OMPBuilder.createOMPAllocShared(Loc: AllocaIP, VarType, Name);
538 }
539
540 virtual Instruction *deallocateVar(IRBuilder<>::InsertPoint DeallocIP,
541 Value *Var, Type *VarType) override {
542 return OMPBuilder.createOMPFreeShared(Loc: DeallocIP, Addr: Var, VarType);
543 }
544};
545
546/// Helper storing information about regions to outline using device shared
547/// memory for intermediate allocations.
548struct DeviceSharedMemOutlineInfo : public OpenMPIRBuilder::OutlineInfo {
549 OpenMPIRBuilder &OMPBuilder;
550
551 DeviceSharedMemOutlineInfo(OpenMPIRBuilder &OMPBuilder)
552 : OMPBuilder(OMPBuilder) {}
553 virtual ~DeviceSharedMemOutlineInfo() = default;
554
555 virtual std::unique_ptr<CodeExtractor>
556 createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
557 bool ArgsInZeroAddressSpace,
558 Twine Suffix = Twine("")) override;
559};
560
561} // anonymous namespace
562
563OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
564 : RequiresFlags(OMP_REQ_UNDEFINED) {}
565
566OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
567 bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
568 bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
569 bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
570 : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
571 OpenMPOffloadMandatory(OpenMPOffloadMandatory),
572 RequiresFlags(OMP_REQ_UNDEFINED) {
573 if (HasRequiresReverseOffload)
574 RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
575 if (HasRequiresUnifiedAddress)
576 RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
577 if (HasRequiresUnifiedSharedMemory)
578 RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
579 if (HasRequiresDynamicAllocators)
580 RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
581}
582
583bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
584 return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
585}
586
587bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
588 return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
589}
590
591bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
592 return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
593}
594
595bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
596 return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
597}
598
599int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
600 return hasRequiresFlags() ? RequiresFlags
601 : static_cast<int64_t>(OMP_REQ_NONE);
602}
603
604void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
605 if (Value)
606 RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
607 else
608 RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
609}
610
611void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
612 if (Value)
613 RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
614 else
615 RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
616}
617
618void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
619 if (Value)
620 RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
621 else
622 RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
623}
624
625void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
626 if (Value)
627 RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
628 else
629 RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
630}
631
632//===----------------------------------------------------------------------===//
633// OpenMPIRBuilder
634//===----------------------------------------------------------------------===//
635
636void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
637 IRBuilderBase &Builder,
638 SmallVector<Value *> &ArgsVector) {
639 Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
640 Value *PointerNum = Builder.getInt32(C: KernelArgs.NumTargetItems);
641 auto Int32Ty = Type::getInt32Ty(C&: Builder.getContext());
642 constexpr size_t MaxDim = 3;
643 Value *ZeroArray = Constant::getNullValue(Ty: ArrayType::get(ElementType: Int32Ty, NumElements: MaxDim));
644
645 Value *HasNoWaitFlag = Builder.getInt64(C: KernelArgs.HasNoWait);
646
647 Value *DynCGroupMemFallbackFlag =
648 Builder.getInt64(C: static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
649 DynCGroupMemFallbackFlag = Builder.CreateShl(LHS: DynCGroupMemFallbackFlag, RHS: 2);
650
651 Value *StrictFlag = Builder.getInt64(C: KernelArgs.StrictBlocksAndThreads);
652 StrictFlag = Builder.CreateShl(LHS: StrictFlag, RHS: 6);
653
654 Value *Flags = Builder.CreateOr(LHS: HasNoWaitFlag, RHS: DynCGroupMemFallbackFlag);
655 Flags = Builder.CreateOr(LHS: Flags, RHS: StrictFlag);
656
657 assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
658
659 Value *NumTeams3D =
660 Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumTeams[0], Idxs: {0});
661 Value *NumThreads3D =
662 Builder.CreateInsertValue(Agg: ZeroArray, Val: KernelArgs.NumThreads[0], Idxs: {0});
663 for (unsigned I :
664 seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumTeams.size(), b: MaxDim)))
665 NumTeams3D =
666 Builder.CreateInsertValue(Agg: NumTeams3D, Val: KernelArgs.NumTeams[I], Idxs: {I});
667 for (unsigned I :
668 seq<unsigned>(Begin: 1, End: std::min(a: KernelArgs.NumThreads.size(), b: MaxDim)))
669 NumThreads3D =
670 Builder.CreateInsertValue(Agg: NumThreads3D, Val: KernelArgs.NumThreads[I], Idxs: {I});
671
672 ArgsVector = {Version,
673 PointerNum,
674 KernelArgs.RTArgs.BasePointersArray,
675 KernelArgs.RTArgs.PointersArray,
676 KernelArgs.RTArgs.SizesArray,
677 KernelArgs.RTArgs.MapTypesArray,
678 KernelArgs.RTArgs.MapNamesArray,
679 KernelArgs.RTArgs.MappersArray,
680 KernelArgs.NumIterations,
681 Flags,
682 NumTeams3D,
683 NumThreads3D,
684 KernelArgs.DynCGroupMem};
685}
686
687void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
688 LLVMContext &Ctx = Fn.getContext();
689
690 // Get the function's current attributes.
691 auto Attrs = Fn.getAttributes();
692 auto FnAttrs = Attrs.getFnAttrs();
693 auto RetAttrs = Attrs.getRetAttrs();
694 SmallVector<AttributeSet, 4> ArgAttrs;
695 for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
696 ArgAttrs.emplace_back(Args: Attrs.getParamAttrs(ArgNo));
697
698 // Add AS to FnAS while taking special care with integer extensions.
699 auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
700 bool Param = true) -> void {
701 bool HasSignExt = AS.hasAttribute(Kind: Attribute::SExt);
702 bool HasZeroExt = AS.hasAttribute(Kind: Attribute::ZExt);
703 if (HasSignExt || HasZeroExt) {
704 assert(AS.getNumAttributes() == 1 &&
705 "Currently not handling extension attr combined with others.");
706 if (Param) {
707 if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, Signed: HasSignExt))
708 FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
709 } else if (auto AK =
710 TargetLibraryInfo::getExtAttrForI32Return(T, Signed: HasSignExt))
711 FnAS = FnAS.addAttribute(C&: Ctx, Kind: AK);
712 } else {
713 FnAS = FnAS.addAttributes(C&: Ctx, AS);
714 }
715 };
716
717#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
718#include "llvm/Frontend/OpenMP/OMPKinds.def"
719
720 // Add attributes to the function declaration.
721 switch (FnID) {
722#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets) \
723 case Enum: \
724 FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet); \
725 addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false); \
726 for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo) \
727 addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]); \
728 Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs)); \
729 break;
730#include "llvm/Frontend/OpenMP/OMPKinds.def"
731 default:
732 // Attributes are optional.
733 break;
734 }
735}
736
737FunctionCallee
738OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
739 FunctionType *FnTy = nullptr;
740 Function *Fn = nullptr;
741
742 // Try to find the declation in the module first.
743 switch (FnID) {
744#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...) \
745 case Enum: \
746 FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__}, \
747 IsVarArg); \
748 Fn = M.getFunction(Str); \
749 break;
750#include "llvm/Frontend/OpenMP/OMPKinds.def"
751 }
752
753 if (!Fn) {
754 // Create a new declaration if we need one.
755 switch (FnID) {
756#define OMP_RTL(Enum, Str, ...) \
757 case Enum: \
758 Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M); \
759 break;
760#include "llvm/Frontend/OpenMP/OMPKinds.def"
761 }
762 Fn->setCallingConv(Config.getRuntimeCC());
763 // Add information if the runtime function takes a callback function
764 if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
765 if (!Fn->hasMetadata(KindID: LLVMContext::MD_callback)) {
766 LLVMContext &Ctx = Fn->getContext();
767 MDBuilder MDB(Ctx);
768 // Annotate the callback behavior of the runtime function:
769 // - The callback callee is argument number 2 (microtask).
770 // - The first two arguments of the callback callee are unknown (-1).
771 // - All variadic arguments to the runtime function are passed to the
772 // callback callee.
773 Fn->addMetadata(
774 KindID: LLVMContext::MD_callback,
775 MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
776 CalleeArgNo: 2, Arguments: {-1, -1}, /* VarArgsArePassed */ true)}));
777 }
778 }
779
780 LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
781 << " with type " << *Fn->getFunctionType() << "\n");
782 addAttributes(FnID, Fn&: *Fn);
783
784 } else {
785 LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
786 << " with type " << *Fn->getFunctionType() << "\n");
787 }
788
789 assert(Fn && "Failed to create OpenMP runtime function");
790
791 return {FnTy, Fn};
792}
793
794Expected<BasicBlock *>
795OpenMPIRBuilder::FinalizationInfo::getFiniBB(IRBuilderBase &Builder) {
796 if (!FiniBB) {
797 Function *ParentFunc = Builder.GetInsertBlock()->getParent();
798 IRBuilderBase::InsertPointGuard Guard(Builder);
799 FiniBB = BasicBlock::Create(Context&: Builder.getContext(), Name: ".fini", Parent: ParentFunc);
800 Builder.SetInsertPoint(FiniBB);
801 // FiniCB adds the branch to the exit stub.
802 if (Error Err = FiniCB(Builder.saveIP()))
803 return Err;
804 }
805 return FiniBB;
806}
807
808Error OpenMPIRBuilder::FinalizationInfo::mergeFiniBB(IRBuilderBase &Builder,
809 BasicBlock *OtherFiniBB) {
810 // Simple case: FiniBB does not exist yet: re-use OtherFiniBB.
811 if (!FiniBB) {
812 FiniBB = OtherFiniBB;
813
814 Builder.SetInsertPoint(FiniBB->getFirstNonPHIIt());
815 if (Error Err = FiniCB(Builder.saveIP()))
816 return Err;
817
818 return Error::success();
819 }
820
821 // Move instructions from FiniBB to the start of OtherFiniBB.
822 auto EndIt = FiniBB->end();
823 if (FiniBB->size() >= 1)
824 if (auto Prev = std::prev(x: EndIt); Prev->isTerminator())
825 EndIt = Prev;
826 OtherFiniBB->splice(ToIt: OtherFiniBB->getFirstNonPHIIt(), FromBB: FiniBB, FromBeginIt: FiniBB->begin(),
827 FromEndIt: EndIt);
828
829 FiniBB->replaceAllUsesWith(V: OtherFiniBB);
830 FiniBB->eraseFromParent();
831 FiniBB = OtherFiniBB;
832 return Error::success();
833}
834
835Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
836 FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
837 auto *Fn = dyn_cast<llvm::Function>(Val: RTLFn.getCallee());
838 assert(Fn && "Failed to create OpenMP runtime function pointer");
839 return Fn;
840}
841
842CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
843 ArrayRef<Value *> Args,
844 StringRef Name) {
845 CallInst *Call = Builder.CreateCall(Callee, Args, Name);
846 Call->setCallingConv(Config.getRuntimeCC());
847 return Call;
848}
849
850void OpenMPIRBuilder::initialize() { initializeTypes(M); }
851
852static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
853 Function *Function) {
854 BasicBlock &EntryBlock = Function->getEntryBlock();
855 BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();
856
857 // Loop over blocks looking for constant allocas, skipping the entry block
858 // as any allocas there are already in the desired location.
859 for (auto Block = std::next(x: Function->begin(), n: 1); Block != Function->end();
860 Block++) {
861 for (auto Inst = Block->getReverseIterator()->begin();
862 Inst != Block->getReverseIterator()->end();) {
863 if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Val&: Inst)) {
864 Inst++;
865 if (!isa<ConstantData>(Val: AllocaInst->getArraySize()))
866 continue;
867 AllocaInst->moveBeforePreserving(MovePos: MoveLocInst);
868 } else {
869 Inst++;
870 }
871 }
872 }
873}
874
875static void hoistNonEntryAllocasToEntryBlock(llvm::BasicBlock &Block) {
876 llvm::SmallVector<llvm::Instruction *> AllocasToMove;
877
878 auto ShouldHoistAlloca = [](const llvm::AllocaInst &AllocaInst) {
879 // TODO: For now, we support simple static allocations, we might need to
880 // move non-static ones as well. However, this will need further analysis to
881 // move the lenght arguments as well.
882 return !AllocaInst.isArrayAllocation();
883 };
884
885 for (llvm::Instruction &Inst : Block)
886 if (auto *AllocaInst = llvm::dyn_cast<llvm::AllocaInst>(Val: &Inst))
887 if (ShouldHoistAlloca(*AllocaInst))
888 AllocasToMove.push_back(Elt: AllocaInst);
889
890 auto InsertPoint =
891 Block.getParent()->getEntryBlock().getTerminator()->getIterator();
892
893 for (llvm::Instruction *AllocaInst : AllocasToMove)
894 AllocaInst->moveBefore(InsertPos: InsertPoint);
895}
896
897static void hoistNonEntryAllocasToEntryBlock(llvm::Function *Func) {
898 PostDominatorTree PostDomTree(*Func);
899 for (llvm::BasicBlock &BB : *Func)
900 if (PostDomTree.properlyDominates(A: &BB, B: &Func->getEntryBlock()))
901 hoistNonEntryAllocasToEntryBlock(Block&: BB);
902}
903
904void OpenMPIRBuilder::finalize(Function *Fn) {
905 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
906 SmallVector<BasicBlock *, 32> Blocks;
907 SmallVector<std::unique_ptr<OutlineInfo>, 16> DeferredOutlines;
908 for (std::unique_ptr<OutlineInfo> &OI : OutlineInfos) {
909 // Skip functions that have not finalized yet; may happen with nested
910 // function generation.
911 if (Fn && OI->getFunction() != Fn) {
912 DeferredOutlines.push_back(Elt: std::move(OI));
913 continue;
914 }
915
916 ParallelRegionBlockSet.clear();
917 Blocks.clear();
918 OI->collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
919
920 Function *OuterFn = OI->getFunction();
921 CodeExtractorAnalysisCache CEAC(*OuterFn);
922 // If we generate code for the target device, we need to allocate
923 // struct for aggregate params in the device default alloca address space.
924 // OpenMP runtime requires that the params of the extracted functions are
925 // passed as zero address space pointers. This flag ensures that
926 // CodeExtractor generates correct code for extracted functions
927 // which are used by OpenMP runtime.
928 bool ArgsInZeroAddressSpace = Config.isTargetDevice();
929 std::unique_ptr<CodeExtractor> Extractor =
930 OI->createCodeExtractor(Blocks, ArgsInZeroAddressSpace, Suffix: ".omp_par");
931
932 LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
933 LLVM_DEBUG(dbgs() << "Entry " << OI->EntryBB->getName()
934 << " Exit: " << OI->ExitBB->getName() << "\n");
935 assert(Extractor->isEligible() &&
936 "Expected OpenMP outlining to be possible!");
937
938 for (auto *V : OI->ExcludeArgsFromAggregate)
939 Extractor->excludeArgFromAggregate(Arg: V);
940
941 Function *OutlinedFn =
942 Extractor->extractCodeRegion(CEAC, Inputs&: OI->Inputs, Outputs&: OI->Outputs);
943
944 // Forward target-cpu, target-features attributes to the outlined function.
945 auto TargetCpuAttr = OuterFn->getFnAttribute(Kind: "target-cpu");
946 if (TargetCpuAttr.isStringAttribute())
947 OutlinedFn->addFnAttr(Attr: TargetCpuAttr);
948
949 auto TargetFeaturesAttr = OuterFn->getFnAttribute(Kind: "target-features");
950 if (TargetFeaturesAttr.isStringAttribute())
951 OutlinedFn->addFnAttr(Attr: TargetFeaturesAttr);
952
953 LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
954 LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
955 assert(OutlinedFn->getReturnType()->isVoidTy() &&
956 "OpenMP outlined functions should not return a value!");
957
958 // For compability with the clang CG we move the outlined function after the
959 // one with the parallel region.
960 OutlinedFn->removeFromParent();
961 M.getFunctionList().insertAfter(where: OuterFn->getIterator(), New: OutlinedFn);
962
963 // Remove the artificial entry introduced by the extractor right away, we
964 // made our own entry block after all.
965 {
966 BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
967 assert(ArtificialEntry.getUniqueSuccessor() == OI->EntryBB);
968 assert(OI->EntryBB->getUniquePredecessor() == &ArtificialEntry);
969 // Move instructions from the to-be-deleted ArtificialEntry to the entry
970 // basic block of the parallel region. CodeExtractor generates
971 // instructions to unwrap the aggregate argument and may sink
972 // allocas/bitcasts for values that are solely used in the outlined region
973 // and do not escape.
974 assert(!ArtificialEntry.empty() &&
975 "Expected instructions to add in the outlined region entry");
976 for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
977 End = ArtificialEntry.rend();
978 It != End;) {
979 Instruction &I = *It;
980 It++;
981
982 if (I.isTerminator()) {
983 // Absorb any debug value that terminator may have
984 if (Instruction *TI = OI->EntryBB->getTerminatorOrNull())
985 TI->adoptDbgRecords(BB: &ArtificialEntry, It: I.getIterator(), InsertAtHead: false);
986 continue;
987 }
988
989 I.moveBeforePreserving(BB&: *OI->EntryBB,
990 I: OI->EntryBB->getFirstInsertionPt());
991 }
992
993 OI->EntryBB->moveBefore(MovePos: &ArtificialEntry);
994 ArtificialEntry.eraseFromParent();
995 }
996 assert(&OutlinedFn->getEntryBlock() == OI->EntryBB);
997 assert(OutlinedFn && OutlinedFn->hasNUses(1));
998
999 // Run a user callback, e.g. to add attributes.
1000 if (OI->PostOutlineCB)
1001 OI->PostOutlineCB(*OutlinedFn);
1002
1003 if (OI->FixUpNonEntryAllocas)
1004 hoistNonEntryAllocasToEntryBlock(Func: OutlinedFn);
1005 }
1006
1007 // Remove work items that have been completed.
1008 OutlineInfos = std::move(DeferredOutlines);
1009
1010 // The createTarget functions embeds user written code into
1011 // the target region which may inject allocas which need to
1012 // be moved to the entry block of our target or risk malformed
1013 // optimisations by later passes, this is only relevant for
1014 // the device pass which appears to be a little more delicate
1015 // when it comes to optimisations (however, we do not block on
1016 // that here, it's up to the inserter to the list to do so).
1017 // This notbaly has to occur after the OutlinedInfo candidates
1018 // have been extracted so we have an end product that will not
1019 // be implicitly adversely affected by any raises unless
1020 // intentionally appended to the list.
1021 // NOTE: This only does so for ConstantData, it could be extended
1022 // to ConstantExpr's with further effort, however, they should
1023 // largely be folded when they get here. Extending it to runtime
1024 // defined/read+writeable allocation sizes would be non-trivial
1025 // (need to factor in movement of any stores to variables the
1026 // allocation size depends on, as well as the usual loads,
1027 // otherwise it'll yield the wrong result after movement) and
1028 // likely be more suitable as an LLVM optimisation pass.
1029 for (Function *F : ConstantAllocaRaiseCandidates)
1030 raiseUserConstantDataAllocasToEntryBlock(Builder, Function: F);
1031
1032 EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
1033 [](EmitMetadataErrorKind Kind,
1034 const TargetRegionEntryInfo &EntryInfo) -> void {
1035 errs() << "Error of kind: " << Kind
1036 << " when emitting offload entries and metadata during "
1037 "OMPIRBuilder finalization \n";
1038 };
1039
1040 if (!OffloadInfoManager.empty())
1041 createOffloadEntriesAndInfoMetadata(ErrorReportFunction&: ErrorReportFn);
1042
1043 if (Config.EmitLLVMUsedMetaInfo.value_or(u: false)) {
1044 std::vector<WeakTrackingVH> LLVMCompilerUsed = {
1045 M.getGlobalVariable(Name: "__openmp_nvptx_data_transfer_temporary_storage")};
1046 emitUsed(Name: "llvm.compiler.used", List: LLVMCompilerUsed);
1047 }
1048
1049 IsFinalized = true;
1050}
1051
1052bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }
1053
1054OpenMPIRBuilder::~OpenMPIRBuilder() {
1055 assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
1056}
1057
1058GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
1059 IntegerType *I32Ty = Type::getInt32Ty(C&: M.getContext());
1060 auto *GV =
1061 new GlobalVariable(M, I32Ty,
1062 /* isConstant = */ true, GlobalValue::WeakODRLinkage,
1063 ConstantInt::get(Ty: I32Ty, V: Value), Name);
1064 GV->setVisibility(GlobalValue::HiddenVisibility);
1065
1066 return GV;
1067}
1068
1069void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
1070 if (List.empty())
1071 return;
1072
1073 // Convert List to what ConstantArray needs.
1074 SmallVector<Constant *, 8> UsedArray;
1075 UsedArray.resize(N: List.size());
1076 for (unsigned I = 0, E = List.size(); I != E; ++I)
1077 UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
1078 C: cast<Constant>(Val: &*List[I]), Ty: Builder.getPtrTy());
1079
1080 if (UsedArray.empty())
1081 return;
1082 ArrayType *ATy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: UsedArray.size());
1083
1084 auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
1085 ConstantArray::get(T: ATy, V: UsedArray), Name);
1086
1087 GV->setSection("llvm.metadata");
1088}
1089
1090GlobalVariable *
1091OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
1092 OMPTgtExecModeFlags Mode) {
1093 auto *Int8Ty = Builder.getInt8Ty();
1094 auto *GVMode = new GlobalVariable(
1095 M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
1096 ConstantInt::get(Ty: Int8Ty, V: Mode), Twine(KernelName, "_exec_mode"));
1097 GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
1098 return GVMode;
1099}
1100
1101Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
1102 uint32_t SrcLocStrSize,
1103 IdentFlag LocFlags,
1104 unsigned Reserve2Flags) {
1105 // Enable "C-mode".
1106 LocFlags |= OMP_IDENT_FLAG_KMPC;
1107
1108 Constant *&Ident =
1109 IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
1110 if (!Ident) {
1111 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
1112 Constant *IdentData[] = {I32Null,
1113 ConstantInt::get(Ty: Int32, V: uint32_t(LocFlags)),
1114 ConstantInt::get(Ty: Int32, V: Reserve2Flags),
1115 ConstantInt::get(Ty: Int32, V: SrcLocStrSize), SrcLocStr};
1116
1117 size_t SrcLocStrArgIdx = 4;
1118 if (OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx)
1119 ->getPointerAddressSpace() !=
1120 IdentData[SrcLocStrArgIdx]->getType()->getPointerAddressSpace())
1121 IdentData[SrcLocStrArgIdx] = ConstantExpr::getAddrSpaceCast(
1122 C: SrcLocStr, Ty: OpenMPIRBuilder::Ident->getElementType(N: SrcLocStrArgIdx));
1123 Constant *Initializer =
1124 ConstantStruct::get(T: OpenMPIRBuilder::Ident, V: IdentData);
1125
1126 // Look for existing encoding of the location + flags, not needed but
1127 // minimizes the difference to the existing solution while we transition.
1128 for (GlobalVariable &GV : M.globals())
1129 if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
1130 if (GV.getInitializer() == Initializer)
1131 Ident = &GV;
1132
1133 if (!Ident) {
1134 auto *GV = new GlobalVariable(
1135 M, OpenMPIRBuilder::Ident,
1136 /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
1137 nullptr, GlobalValue::NotThreadLocal,
1138 M.getDataLayout().getDefaultGlobalsAddressSpace());
1139 GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
1140 GV->setAlignment(Align(8));
1141 Ident = GV;
1142 }
1143 }
1144
1145 return ConstantExpr::getPointerBitCastOrAddrSpaceCast(C: Ident, Ty: IdentPtr);
1146}
1147
1148Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
1149 uint32_t &SrcLocStrSize) {
1150 SrcLocStrSize = LocStr.size();
1151 Constant *&SrcLocStr = SrcLocStrMap[LocStr];
1152 if (!SrcLocStr) {
1153 Constant *Initializer =
1154 ConstantDataArray::getString(Context&: M.getContext(), Initializer: LocStr);
1155
1156 // Look for existing encoding of the location, not needed but minimizes the
1157 // difference to the existing solution while we transition.
1158 for (GlobalVariable &GV : M.globals())
1159 if (GV.isConstant() && GV.hasInitializer() &&
1160 GV.getInitializer() == Initializer)
1161 return SrcLocStr = ConstantExpr::getPointerCast(C: &GV, Ty: Int8Ptr);
1162
1163 SrcLocStr = Builder.CreateGlobalString(
1164 Str: LocStr, /*Name=*/"", AddressSpace: M.getDataLayout().getDefaultGlobalsAddressSpace(),
1165 M: &M);
1166 }
1167 return SrcLocStr;
1168}
1169
1170Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
1171 StringRef FileName,
1172 unsigned Line, unsigned Column,
1173 uint32_t &SrcLocStrSize) {
1174 SmallString<128> Buffer;
1175 Buffer.push_back(Elt: ';');
1176 Buffer.append(RHS: FileName);
1177 Buffer.push_back(Elt: ';');
1178 Buffer.append(RHS: FunctionName);
1179 Buffer.push_back(Elt: ';');
1180 Buffer.append(RHS: std::to_string(val: Line));
1181 Buffer.push_back(Elt: ';');
1182 Buffer.append(RHS: std::to_string(val: Column));
1183 Buffer.push_back(Elt: ';');
1184 Buffer.push_back(Elt: ';');
1185 return getOrCreateSrcLocStr(LocStr: Buffer.str(), SrcLocStrSize);
1186}
1187
1188Constant *
1189OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
1190 StringRef UnknownLoc = ";unknown;unknown;0;0;;";
1191 return getOrCreateSrcLocStr(LocStr: UnknownLoc, SrcLocStrSize);
1192}
1193
1194Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
1195 uint32_t &SrcLocStrSize,
1196 Function *F) {
1197 DILocation *DIL = DL.get();
1198 if (!DIL)
1199 return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1200 StringRef FileName =
1201 !DIL->getFilename().empty() ? DIL->getFilename() : M.getName();
1202 StringRef Function = DIL->getScope()->getSubprogram()->getName();
1203 if (Function.empty() && F)
1204 Function = F->getName();
1205 return getOrCreateSrcLocStr(FunctionName: Function, FileName, Line: DIL->getLine(),
1206 Column: DIL->getColumn(), SrcLocStrSize);
1207}
1208
1209Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
1210 uint32_t &SrcLocStrSize) {
1211 return getOrCreateSrcLocStr(DL: Loc.DL, SrcLocStrSize,
1212 F: Loc.IP.getBlock()->getParent());
1213}
1214
1215Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
1216 return createRuntimeFunctionCall(
1217 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num), Args: Ident,
1218 Name: "omp_global_thread_num");
1219}
1220
1221OpenMPIRBuilder::InsertPointOrErrorTy
1222OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
1223 bool ForceSimpleCall, bool CheckCancelFlag) {
1224 if (!updateToLocation(Loc))
1225 return Loc.IP;
1226
1227 // Build call __kmpc_cancel_barrier(loc, thread_id) or
1228 // __kmpc_barrier(loc, thread_id);
1229
1230 IdentFlag BarrierLocFlags;
1231 switch (Kind) {
1232 case OMPD_for:
1233 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
1234 break;
1235 case OMPD_sections:
1236 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
1237 break;
1238 case OMPD_single:
1239 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
1240 break;
1241 case OMPD_barrier:
1242 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
1243 break;
1244 default:
1245 BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
1246 break;
1247 }
1248
1249 uint32_t SrcLocStrSize;
1250 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1251 Value *Args[] = {
1252 getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: BarrierLocFlags),
1253 getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize))};
1254
1255 // If we are in a cancellable parallel region, barriers are cancellation
1256 // points.
1257 // TODO: Check why we would force simple calls or to ignore the cancel flag.
1258 bool UseCancelBarrier =
1259 !ForceSimpleCall && isLastFinalizationInfoCancellable(DK: OMPD_parallel);
1260
1261 Value *Result = createRuntimeFunctionCall(
1262 Callee: getOrCreateRuntimeFunctionPtr(FnID: UseCancelBarrier
1263 ? OMPRTL___kmpc_cancel_barrier
1264 : OMPRTL___kmpc_barrier),
1265 Args);
1266
1267 if (UseCancelBarrier && CheckCancelFlag)
1268 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective: OMPD_parallel))
1269 return Err;
1270
1271 return Builder.saveIP();
1272}
1273
1274OpenMPIRBuilder::InsertPointOrErrorTy
1275OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
1276 Value *IfCondition,
1277 omp::Directive CanceledDirective) {
1278 if (!updateToLocation(Loc))
1279 return Loc.IP;
1280
1281 // LLVM utilities like blocks with terminators.
1282 auto *UI = Builder.CreateUnreachable();
1283
1284 Instruction *ThenTI = UI, *ElseTI = nullptr;
1285 if (IfCondition) {
1286 SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: UI, ThenTerm: &ThenTI, ElseTerm: &ElseTI);
1287
1288 // Even if the if condition evaluates to false, this should count as a
1289 // cancellation point
1290 Builder.SetInsertPoint(ElseTI);
1291 auto ElseIP = Builder.saveIP();
1292
1293 InsertPointOrErrorTy IPOrErr = createCancellationPoint(
1294 Loc: LocationDescription{ElseIP, Loc.DL}, CanceledDirective);
1295 if (!IPOrErr)
1296 return IPOrErr;
1297 }
1298
1299 Builder.SetInsertPoint(ThenTI);
1300
1301 Value *CancelKind = nullptr;
1302 switch (CanceledDirective) {
1303#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
1304 case DirectiveEnum: \
1305 CancelKind = Builder.getInt32(Value); \
1306 break;
1307#include "llvm/Frontend/OpenMP/OMPKinds.def"
1308 default:
1309 llvm_unreachable("Unknown cancel kind!");
1310 }
1311
1312 uint32_t SrcLocStrSize;
1313 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1314 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1315 Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
1316 Value *Result = createRuntimeFunctionCall(
1317 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancel), Args);
1318
1319 // The actual cancel logic is shared with others, e.g., cancel_barriers.
1320 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
1321 return Err;
1322
1323 // Update the insertion point and remove the terminator we introduced.
1324 Builder.SetInsertPoint(UI->getParent());
1325 UI->eraseFromParent();
1326
1327 return Builder.saveIP();
1328}
1329
1330OpenMPIRBuilder::InsertPointOrErrorTy
1331OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
1332 omp::Directive CanceledDirective) {
1333 if (!updateToLocation(Loc))
1334 return Loc.IP;
1335
1336 // LLVM utilities like blocks with terminators.
1337 auto *UI = Builder.CreateUnreachable();
1338 Builder.SetInsertPoint(UI);
1339
1340 Value *CancelKind = nullptr;
1341 switch (CanceledDirective) {
1342#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
1343 case DirectiveEnum: \
1344 CancelKind = Builder.getInt32(Value); \
1345 break;
1346#include "llvm/Frontend/OpenMP/OMPKinds.def"
1347 default:
1348 llvm_unreachable("Unknown cancel kind!");
1349 }
1350
1351 uint32_t SrcLocStrSize;
1352 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1353 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1354 Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
1355 Value *Result = createRuntimeFunctionCall(
1356 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_cancellationpoint), Args);
1357
1358 // The actual cancel logic is shared with others, e.g., cancel_barriers.
1359 if (Error Err = emitCancelationCheckImpl(CancelFlag: Result, CanceledDirective))
1360 return Err;
1361
1362 // Update the insertion point and remove the terminator we introduced.
1363 Builder.SetInsertPoint(UI->getParent());
1364 UI->eraseFromParent();
1365
1366 return Builder.saveIP();
1367}
1368
1369OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1370 const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
1371 Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
1372 Value *HostPtr, ArrayRef<Value *> KernelArgs) {
1373 if (!updateToLocation(Loc))
1374 return Loc.IP;
1375
1376 Builder.restoreIP(IP: AllocaIP);
1377 auto *KernelArgsPtr =
1378 Builder.CreateAlloca(Ty: OpenMPIRBuilder::KernelArgs, ArraySize: nullptr, Name: "kernel_args");
1379 updateToLocation(Loc);
1380
1381 for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
1382 llvm::Value *Arg =
1383 Builder.CreateStructGEP(Ty: OpenMPIRBuilder::KernelArgs, Ptr: KernelArgsPtr, Idx: I);
1384 Builder.CreateAlignedStore(
1385 Val: KernelArgs[I], Ptr: Arg,
1386 Align: M.getDataLayout().getPrefTypeAlign(Ty: KernelArgs[I]->getType()));
1387 }
1388
1389 SmallVector<Value *> OffloadingArgs{Ident, DeviceID, NumTeams,
1390 NumThreads, HostPtr, KernelArgsPtr};
1391
1392 Return = createRuntimeFunctionCall(
1393 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_target_kernel),
1394 Args: OffloadingArgs);
1395
1396 return Builder.saveIP();
1397}
1398
1399OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
1400 const LocationDescription &Loc, Value *OutlinedFnID,
1401 EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
1402 Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1403
1404 if (!updateToLocation(Loc))
1405 return Loc.IP;
1406
1407 // On top of the arrays that were filled up, the target offloading call
1408 // takes as arguments the device id as well as the host pointer. The host
1409 // pointer is used by the runtime library to identify the current target
1410 // region, so it only has to be unique and not necessarily point to
1411 // anything. It could be the pointer to the outlined function that
1412 // implements the target region, but we aren't using that so that the
1413 // compiler doesn't need to keep that, and could therefore inline the host
1414 // function if proven worthwhile during optimization.
1415
1416 // From this point on, we need to have an ID of the target region defined.
1417 assert(OutlinedFnID && "Invalid outlined function ID!");
1418 (void)OutlinedFnID;
1419
1420 // Return value of the runtime offloading call.
1421 Value *Return = nullptr;
1422
1423 // Arguments for the target kernel.
1424 SmallVector<Value *> ArgsVector;
1425 getKernelArgsVector(KernelArgs&: Args, Builder, ArgsVector);
1426
1427 // The target region is an outlined function launched by the runtime
1428 // via calls to __tgt_target_kernel().
1429 //
1430 // Note that on the host and CPU targets, the runtime implementation of
1431 // these calls simply call the outlined function without forking threads.
1432 // The outlined functions themselves have runtime calls to
1433 // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
1434 // the compiler in emitTeamsCall() and emitParallelCall().
1435 //
1436 // In contrast, on the NVPTX target, the implementation of
1437 // __tgt_target_teams() launches a GPU kernel with the requested number
1438 // of teams and threads so no additional calls to the runtime are required.
1439 // Check the error code and execute the host version if required.
1440 Builder.restoreIP(IP: emitTargetKernel(
1441 Loc: Builder, AllocaIP, Return, Ident: RTLoc, DeviceID, NumTeams: Args.NumTeams.front(),
1442 NumThreads: Args.NumThreads.front(), HostPtr: OutlinedFnID, KernelArgs: ArgsVector));
1443
1444 BasicBlock *OffloadFailedBlock =
1445 BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.failed");
1446 BasicBlock *OffloadContBlock =
1447 BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
1448 Value *Failed = Builder.CreateIsNotNull(Arg: Return);
1449 Builder.CreateCondBr(Cond: Failed, True: OffloadFailedBlock, False: OffloadContBlock);
1450
1451 auto CurFn = Builder.GetInsertBlock()->getParent();
1452 emitBlock(BB: OffloadFailedBlock, CurFn);
1453 InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
1454 if (!AfterIP)
1455 return AfterIP.takeError();
1456 Builder.restoreIP(IP: *AfterIP);
1457 emitBranch(Target: OffloadContBlock);
1458 emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
1459 return Builder.saveIP();
1460}
1461
1462Error OpenMPIRBuilder::emitCancelationCheckImpl(
1463 Value *CancelFlag, omp::Directive CanceledDirective) {
1464 assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
1465 "Unexpected cancellation!");
1466
1467 // For a cancel barrier we create two new blocks.
1468 BasicBlock *BB = Builder.GetInsertBlock();
1469 BasicBlock *NonCancellationBlock;
1470 if (Builder.GetInsertPoint() == BB->end()) {
1471 // TODO: This branch will not be needed once we moved to the
1472 // OpenMPIRBuilder codegen completely.
1473 NonCancellationBlock = BasicBlock::Create(
1474 Context&: BB->getContext(), Name: BB->getName() + ".cont", Parent: BB->getParent());
1475 } else {
1476 NonCancellationBlock = SplitBlock(Old: BB, SplitPt: &*Builder.GetInsertPoint());
1477 BB->getTerminator()->eraseFromParent();
1478 Builder.SetInsertPoint(BB);
1479 }
1480 BasicBlock *CancellationBlock = BasicBlock::Create(
1481 Context&: BB->getContext(), Name: BB->getName() + ".cncl", Parent: BB->getParent());
1482
1483 // Jump to them based on the return value.
1484 Value *Cmp = Builder.CreateIsNull(Arg: CancelFlag);
1485 Builder.CreateCondBr(Cond: Cmp, True: NonCancellationBlock, False: CancellationBlock,
1486 /* TODO weight */ BranchWeights: nullptr, Unpredictable: nullptr);
1487
1488 // From the cancellation block we finalize all variables and go to the
1489 // post finalization block that is known to the FiniCB callback.
1490 auto &FI = FinalizationStack.back();
1491 Expected<BasicBlock *> FiniBBOrErr = FI.getFiniBB(Builder);
1492 if (!FiniBBOrErr)
1493 return FiniBBOrErr.takeError();
1494 Builder.SetInsertPoint(CancellationBlock);
1495 Builder.CreateBr(Dest: *FiniBBOrErr);
1496
1497 // The continuation block is where code generation continues.
1498 Builder.SetInsertPoint(TheBB: NonCancellationBlock, IP: NonCancellationBlock->begin());
1499 return Error::success();
1500}
1501
1502/// Create wrapper function used to gather the outlined function's argument
1503/// structure from a shared buffer and to forward them to it when running in
1504/// Generic mode.
1505///
1506/// The outlined function is expected to receive 2 integer arguments followed by
1507/// an optional pointer argument to an argument structure holding the rest.
1508static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
1509 Function &OutlinedFn) {
1510 size_t NumArgs = OutlinedFn.arg_size();
1511 assert((NumArgs == 2 || NumArgs == 3) &&
1512 "expected a 2-3 argument parallel outlined function");
1513 bool UseArgStruct = NumArgs == 3;
1514
1515 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1516 IRBuilder<>::InsertPointGuard IPG(Builder);
1517 auto *FnTy = FunctionType::get(Result: Builder.getVoidTy(),
1518 Params: {Builder.getInt16Ty(), Builder.getInt32Ty()},
1519 /*isVarArg=*/false);
1520 auto *WrapperFn =
1521 Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage,
1522 N: OutlinedFn.getName() + ".wrapper", M&: OMPIRBuilder->M);
1523
1524 WrapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
1525 WrapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::ZExt);
1526 WrapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
1527
1528 BasicBlock *EntryBB =
1529 BasicBlock::Create(Context&: OMPIRBuilder->M.getContext(), Name: "entry", Parent: WrapperFn);
1530 Builder.SetInsertPoint(EntryBB);
1531
1532 // Allocation.
1533 Value *AddrAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(),
1534 /*ArraySize=*/nullptr, Name: "addr");
1535 AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1536 V: AddrAlloca, DestTy: Builder.getPtrTy(/*AddrSpace=*/0),
1537 Name: AddrAlloca->getName() + ".ascast");
1538
1539 Value *ZeroAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(),
1540 /*ArraySize=*/nullptr, Name: "zero");
1541 ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1542 V: ZeroAlloca, DestTy: Builder.getPtrTy(/*AddrSpace=*/0),
1543 Name: ZeroAlloca->getName() + ".ascast");
1544
1545 Value *ArgsAlloca = nullptr;
1546 if (UseArgStruct) {
1547 ArgsAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(),
1548 /*ArraySize=*/nullptr, Name: "global_args");
1549 ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1550 V: ArgsAlloca, DestTy: Builder.getPtrTy(/*AddrSpace=*/0),
1551 Name: ArgsAlloca->getName() + ".ascast");
1552 }
1553
1554 // Initialization.
1555 Builder.CreateStore(Val: WrapperFn->getArg(i: 1), Ptr: AddrAlloca);
1556 Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: ZeroAlloca);
1557 if (UseArgStruct) {
1558 Builder.CreateCall(
1559 Callee: OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
1560 FnID: llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1561 Args: {ArgsAlloca});
1562 }
1563
1564 SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
1565
1566 // Load structArg from global_args.
1567 if (UseArgStruct) {
1568 Value *StructArg = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ArgsAlloca);
1569 StructArg = Builder.CreateInBoundsGEP(Ty: Builder.getPtrTy(), Ptr: StructArg,
1570 IdxList: {Builder.getInt64(C: 0)});
1571 StructArg = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: StructArg, Name: "structArg");
1572 Args.push_back(Elt: StructArg);
1573 }
1574
1575 // Call the outlined function holding the parallel body.
1576 Builder.CreateCall(Callee: &OutlinedFn, Args);
1577 Builder.CreateRetVoid();
1578
1579 return WrapperFn;
1580}
1581
1582// Callback used to create OpenMP runtime calls to support
1583// omp parallel clause for the device.
1584// We need to use this callback to replace call to the OutlinedFn in OuterFn
1585// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_60)
1586static void targetParallelCallback(
1587 OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1588 BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1589 Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1590 Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1591 assert(OutlinedFn.arg_size() >= 2 &&
1592 "Expected at least tid and bounded tid as arguments");
1593 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1594
1595 // Add some known attributes.
1596 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1597 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
1598 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
1599 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
1600 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
1601 OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);
1602
1603 CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
1604 assert(CI && "Expected call instruction to outlined function");
1605 CI->getParent()->setName("omp_parallel");
1606
1607 Builder.SetInsertPoint(CI);
1608 Type *PtrTy = OMPIRBuilder->VoidPtr;
1609
1610 // Add alloca for kernel args
1611 OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1612 Builder.SetInsertPoint(TheBB: OuterAllocaBB, IP: OuterAllocaBB->getFirstInsertionPt());
1613 AllocaInst *ArgsAlloca =
1614 Builder.CreateAlloca(Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars));
1615 Value *Args = ArgsAlloca;
1616 // Add address space cast if array for storing arguments is not allocated
1617 // in address space 0
1618 if (ArgsAlloca->getAddressSpace())
1619 Args = Builder.CreatePointerCast(V: ArgsAlloca, DestTy: PtrTy);
1620 Builder.restoreIP(IP: CurrentIP);
1621
1622 // Store captured vars which are used by kmpc_parallel_60
1623 for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1624 Value *V = *(CI->arg_begin() + 2 + Idx);
1625 Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1626 Ty: ArrayType::get(ElementType: PtrTy, NumElements: NumCapturedVars), Ptr: Args, Idx0: 0, Idx1: Idx);
1627 Builder.CreateStore(Val: V, Ptr: StoreAddress);
1628 }
1629
1630 Value *Cond =
1631 IfCondition ? Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32)
1632 : Builder.getInt32(C: 1);
1633 Value *NumThreadsArg =
1634 NumThreads ? Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: OMPIRBuilder->Int32)
1635 : Builder.getInt32(C: -1);
1636
1637 // If this is not a Generic kernel, we can skip generating the wrapper.
1638 Value *WrapperFn;
1639 if (isGenericKernel(Fn&: *OuterFn))
1640 WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
1641 else
1642 WrapperFn = Constant::getNullValue(Ty: PtrTy);
1643
1644 // Build kmpc_parallel_60 call
1645 Value *Parallel60CallArgs[] = {
1646 /* identifier*/ Ident,
1647 /* global thread num*/ ThreadID,
1648 /* if expression */ Cond,
1649 /* number of threads */ NumThreadsArg,
1650 /* Proc bind */ Builder.getInt32(C: -1),
1651 /* outlined function */ &OutlinedFn,
1652 /* wrapper function */ WrapperFn,
1653 /* arguments of the outlined funciton*/ Args,
1654 /* number of arguments */ Builder.getInt64(C: NumCapturedVars),
1655 /* strict for number of threads */ Builder.getInt32(C: 0)};
1656
1657 FunctionCallee RTLFn =
1658 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_parallel_60);
1659
1660 OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: Parallel60CallArgs);
1661
1662 LLVM_DEBUG(dbgs() << "With kmpc_parallel_60 placed: "
1663 << *Builder.GetInsertBlock()->getParent() << "\n");
1664
1665 // Initialize the local TID stack location with the argument value.
1666 Builder.SetInsertPoint(PrivTID);
1667 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1668 Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
1669 Ptr: PrivTIDAddr);
1670
1671 // Remove redundant call to the outlined function.
1672 CI->eraseFromParent();
1673
1674 for (Instruction *I : ToBeDeleted) {
1675 I->eraseFromParent();
1676 }
1677}
1678
1679// Callback used to create OpenMP runtime calls to support
1680// omp parallel clause for the host.
1681// We need to use this callback to replace call to the OutlinedFn in OuterFn
1682// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
1683static void
1684hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1685 Function *OuterFn, Value *Ident, Value *IfCondition,
1686 Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1687 const SmallVector<Instruction *, 4> &ToBeDeleted) {
1688 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1689 FunctionCallee RTLFn;
1690 if (IfCondition) {
1691 RTLFn =
1692 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call_if);
1693 } else {
1694 RTLFn =
1695 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_fork_call);
1696 }
1697 if (auto *F = dyn_cast<Function>(Val: RTLFn.getCallee())) {
1698 if (!F->hasMetadata(KindID: LLVMContext::MD_callback)) {
1699 LLVMContext &Ctx = F->getContext();
1700 MDBuilder MDB(Ctx);
1701 // Annotate the callback behavior of the __kmpc_fork_call:
1702 // - The callback callee is argument number 2 (microtask).
1703 // - The first two arguments of the callback callee are unknown (-1).
1704 // - All variadic arguments to the __kmpc_fork_call are passed to the
1705 // callback callee.
1706 F->addMetadata(KindID: LLVMContext::MD_callback,
1707 MD&: *MDNode::get(Context&: Ctx, MDs: {MDB.createCallbackEncoding(
1708 CalleeArgNo: 2, Arguments: {-1, -1},
1709 /* VarArgsArePassed */ true)}));
1710 }
1711 }
1712 // Add some known attributes.
1713 OutlinedFn.addParamAttr(ArgNo: 0, Kind: Attribute::NoAlias);
1714 OutlinedFn.addParamAttr(ArgNo: 1, Kind: Attribute::NoAlias);
1715 OutlinedFn.addFnAttr(Kind: Attribute::NoUnwind);
1716
1717 assert(OutlinedFn.arg_size() >= 2 &&
1718 "Expected at least tid and bounded tid as arguments");
1719 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1720
1721 CallInst *CI = cast<CallInst>(Val: OutlinedFn.user_back());
1722 CI->getParent()->setName("omp_parallel");
1723 Builder.SetInsertPoint(CI);
1724
1725 // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1726 Value *ForkCallArgs[] = {Ident, Builder.getInt32(C: NumCapturedVars),
1727 &OutlinedFn};
1728
1729 SmallVector<Value *, 16> RealArgs;
1730 RealArgs.append(in_start: std::begin(arr&: ForkCallArgs), in_end: std::end(arr&: ForkCallArgs));
1731 if (IfCondition) {
1732 Value *Cond = Builder.CreateSExtOrTrunc(V: IfCondition, DestTy: OMPIRBuilder->Int32);
1733 RealArgs.push_back(Elt: Cond);
1734 }
1735 RealArgs.append(in_start: CI->arg_begin() + /* tid & bound tid */ 2, in_end: CI->arg_end());
1736
1737 // __kmpc_fork_call_if always expects a void ptr as the last argument
1738 // If there are no arguments, pass a null pointer.
1739 auto PtrTy = OMPIRBuilder->VoidPtr;
1740 if (IfCondition && NumCapturedVars == 0) {
1741 Value *NullPtrValue = Constant::getNullValue(Ty: PtrTy);
1742 RealArgs.push_back(Elt: NullPtrValue);
1743 }
1744
1745 OMPIRBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
1746
1747 LLVM_DEBUG(dbgs() << "With fork_call placed: "
1748 << *Builder.GetInsertBlock()->getParent() << "\n");
1749
1750 // Initialize the local TID stack location with the argument value.
1751 Builder.SetInsertPoint(PrivTID);
1752 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1753 Builder.CreateStore(Val: Builder.CreateLoad(Ty: OMPIRBuilder->Int32, Ptr: OutlinedAI),
1754 Ptr: PrivTIDAddr);
1755
1756 // Remove redundant call to the outlined function.
1757 CI->eraseFromParent();
1758
1759 for (Instruction *I : ToBeDeleted) {
1760 I->eraseFromParent();
1761 }
1762}
1763
1764OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
1765 const LocationDescription &Loc, InsertPointTy OuterAllocIP,
1766 ArrayRef<BasicBlock *> OuterDeallocBlocks, BodyGenCallbackTy BodyGenCB,
1767 PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, Value *IfCondition,
1768 Value *NumThreads, omp::ProcBindKind ProcBind, bool IsCancellable) {
1769 assert(!isConflictIP(Loc.IP, OuterAllocIP) && "IPs must not be ambiguous");
1770
1771 if (!updateToLocation(Loc))
1772 return Loc.IP;
1773
1774 uint32_t SrcLocStrSize;
1775 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1776 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1777 const bool NeedThreadID = NumThreads || Config.isTargetDevice() ||
1778 (ProcBind != OMP_PROC_BIND_default);
1779 Value *ThreadID = NeedThreadID ? getOrCreateThreadID(Ident) : nullptr;
1780 // If we generate code for the target device, we need to allocate
1781 // struct for aggregate params in the device default alloca address space.
1782 // OpenMP runtime requires that the params of the extracted functions are
1783 // passed as zero address space pointers. This flag ensures that extracted
1784 // function arguments are declared in zero address space
1785 bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1786
1787 // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1788 // only if we compile for host side.
1789 if (NumThreads && !Config.isTargetDevice()) {
1790 Value *Args[] = {
1791 Ident, ThreadID,
1792 Builder.CreateIntCast(V: NumThreads, DestTy: Int32, /*isSigned*/ false)};
1793 createRuntimeFunctionCall(
1794 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_threads), Args);
1795 }
1796
1797 if (ProcBind != OMP_PROC_BIND_default) {
1798 // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1799 Value *Args[] = {
1800 Ident, ThreadID,
1801 ConstantInt::get(Ty: Int32, V: unsigned(ProcBind), /*isSigned=*/IsSigned: true)};
1802 createRuntimeFunctionCall(
1803 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_proc_bind), Args);
1804 }
1805
1806 BasicBlock *InsertBB = Builder.GetInsertBlock();
1807 Function *OuterFn = InsertBB->getParent();
1808
1809 // Save the outer alloca block because the insertion iterator may get
1810 // invalidated and we still need this later.
1811 BasicBlock *OuterAllocaBlock = OuterAllocIP.getBlock();
1812
1813 // Vector to remember instructions we used only during the modeling but which
1814 // we want to delete at the end.
1815 SmallVector<Instruction *, 4> ToBeDeleted;
1816
1817 // Change the location to the outer alloca insertion point to create and
1818 // initialize the allocas we pass into the parallel region.
1819 InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1820 Builder.restoreIP(IP: NewOuter);
1821 AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr");
1822 AllocaInst *ZeroAddrAlloca =
1823 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "zero.addr");
1824 Instruction *TIDAddr = TIDAddrAlloca;
1825 Instruction *ZeroAddr = ZeroAddrAlloca;
1826 if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1827 // Add additional casts to enforce pointers in zero address space
1828 TIDAddr = new AddrSpaceCastInst(
1829 TIDAddrAlloca, PointerType ::get(C&: M.getContext(), AddressSpace: 0), "tid.addr.ascast");
1830 TIDAddr->insertAfter(InsertPos: TIDAddrAlloca->getIterator());
1831 ToBeDeleted.push_back(Elt: TIDAddr);
1832 ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1833 PointerType ::get(C&: M.getContext(), AddressSpace: 0),
1834 "zero.addr.ascast");
1835 ZeroAddr->insertAfter(InsertPos: ZeroAddrAlloca->getIterator());
1836 ToBeDeleted.push_back(Elt: ZeroAddr);
1837 }
1838
1839 // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1840 // associated arguments in the outlined function, so we delete them later.
1841 ToBeDeleted.push_back(Elt: TIDAddrAlloca);
1842 ToBeDeleted.push_back(Elt: ZeroAddrAlloca);
1843
1844 // Create an artificial insertion point that will also ensure the blocks we
1845 // are about to split are not degenerated.
1846 auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1847
1848 BasicBlock *EntryBB = UI->getParent();
1849 BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(I: UI, BBName: "omp.par.entry");
1850 BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(I: UI, BBName: "omp.par.region");
1851 BasicBlock *PRegPreFiniBB =
1852 PRegBodyBB->splitBasicBlock(I: UI, BBName: "omp.par.pre_finalize");
1853 BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(I: UI, BBName: "omp.par.exit");
1854
1855 auto FiniCBWrapper = [&](InsertPointTy IP) {
1856 // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1857 // target to the region exit block.
1858 if (IP.getBlock()->end() == IP.getPoint()) {
1859 IRBuilder<>::InsertPointGuard IPG(Builder);
1860 Builder.restoreIP(IP);
1861 Instruction *I = Builder.CreateBr(Dest: PRegExitBB);
1862 IP = InsertPointTy(I->getParent(), I->getIterator());
1863 }
1864 assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1865 IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1866 "Unexpected insertion point for finalization call!");
1867 return FiniCB(IP);
1868 };
1869
1870 FinalizationStack.push_back(Elt: {FiniCBWrapper, OMPD_parallel, IsCancellable});
1871
1872 // Generate the privatization allocas in the block that will become the entry
1873 // of the outlined function.
1874 Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1875 InsertPointTy InnerAllocaIP = Builder.saveIP();
1876
1877 AllocaInst *PrivTIDAddr =
1878 Builder.CreateAlloca(Ty: Int32, ArraySize: nullptr, Name: "tid.addr.local");
1879 Instruction *PrivTID = Builder.CreateLoad(Ty: Int32, Ptr: PrivTIDAddr, Name: "tid");
1880
1881 // Add some fake uses for OpenMP provided arguments.
1882 ToBeDeleted.push_back(Elt: Builder.CreateLoad(Ty: Int32, Ptr: TIDAddr, Name: "tid.addr.use"));
1883 Instruction *ZeroAddrUse =
1884 Builder.CreateLoad(Ty: Int32, Ptr: ZeroAddr, Name: "zero.addr.use");
1885 ToBeDeleted.push_back(Elt: ZeroAddrUse);
1886
1887 // EntryBB
1888 // |
1889 // V
1890 // PRegionEntryBB <- Privatization allocas are placed here.
1891 // |
1892 // V
1893 // PRegionBodyBB <- BodeGen is invoked here.
1894 // |
1895 // V
1896 // PRegPreFiniBB <- The block we will start finalization from.
1897 // |
1898 // V
1899 // PRegionExitBB <- A common exit to simplify block collection.
1900 //
1901
1902 LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1903
1904 // Let the caller create the body.
1905 assert(BodyGenCB && "Expected body generation callback!");
1906 InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1907 if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP, PRegExitBB))
1908 return Err;
1909
1910 LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
1911
1912 // If OuterFn is a Generic kernel, we need to use device shared memory to
1913 // allocate argument structures. Otherwise, we use stack allocations as usual.
1914 bool UsesDeviceSharedMemory =
1915 Config.isTargetDevice() && isGenericKernel(Fn&: *OuterFn);
1916 std::unique_ptr<OutlineInfo> OI =
1917 UsesDeviceSharedMemory
1918 ? std::make_unique<DeviceSharedMemOutlineInfo>(args&: *this)
1919 : std::make_unique<OutlineInfo>();
1920
1921 if (Config.isTargetDevice()) {
1922 // Generate OpenMP target specific runtime call
1923 OI->PostOutlineCB = [=, ToBeDeletedVec =
1924 std::move(ToBeDeleted)](Function &OutlinedFn) {
1925 targetParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, OuterAllocaBB: OuterAllocaBlock, Ident,
1926 IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1927 ThreadID, ToBeDeleted: ToBeDeletedVec);
1928 };
1929 } else {
1930 // Generate OpenMP host runtime call
1931 OI->PostOutlineCB = [=, ToBeDeletedVec =
1932 std::move(ToBeDeleted)](Function &OutlinedFn) {
1933 hostParallelCallback(OMPIRBuilder: this, OutlinedFn, OuterFn, Ident, IfCondition,
1934 PrivTID, PrivTIDAddr, ToBeDeleted: ToBeDeletedVec);
1935 };
1936 }
1937
1938 OI->FixUpNonEntryAllocas = true;
1939 OI->OuterAllocBB = OuterAllocaBlock;
1940 OI->EntryBB = PRegEntryBB;
1941 OI->ExitBB = PRegExitBB;
1942 OI->OuterDeallocBBs.reserve(N: OuterDeallocBlocks.size());
1943 copy(Range&: OuterDeallocBlocks, Out: OI->OuterDeallocBBs.end());
1944
1945 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1946 SmallVector<BasicBlock *, 32> Blocks;
1947 OI->collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
1948
1949 CodeExtractorAnalysisCache CEAC(*OuterFn);
1950 CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1951 /* AggregateArgs */ false,
1952 /* BlockFrequencyInfo */ nullptr,
1953 /* BranchProbabilityInfo */ nullptr,
1954 /* AssumptionCache */ nullptr,
1955 /* AllowVarArgs */ true,
1956 /* AllowAlloca */ true,
1957 /* AllocationBlock */ OuterAllocaBlock,
1958 /* DeallocationBlocks */ {},
1959 /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1960
1961 // Find inputs to, outputs from the code region.
1962 BasicBlock *CommonExit = nullptr;
1963 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1964 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
1965
1966 Extractor.findInputsOutputs(Inputs, Outputs, Allocas: SinkingCands,
1967 /*CollectGlobalInputs=*/true);
1968
1969 Inputs.remove_if(P: [&](Value *I) {
1970 if (auto *GV = dyn_cast_if_present<GlobalVariable>(Val: I))
1971 return GV->getValueType() == OpenMPIRBuilder::Ident;
1972
1973 return false;
1974 });
1975
1976 LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1977
1978 FunctionCallee TIDRTLFn =
1979 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_global_thread_num);
1980
1981 auto PrivHelper = [&](Value &V) -> Error {
1982 if (&V == TIDAddr || &V == ZeroAddr) {
1983 OI->ExcludeArgsFromAggregate.push_back(Elt: &V);
1984 return Error::success();
1985 }
1986
1987 SetVector<Use *> Uses;
1988 for (Use &U : V.uses())
1989 if (auto *UserI = dyn_cast<Instruction>(Val: U.getUser()))
1990 if (ParallelRegionBlockSet.count(Ptr: UserI->getParent()))
1991 Uses.insert(X: &U);
1992
1993 // __kmpc_fork_call expects extra arguments as pointers. If the input
1994 // already has a pointer type, everything is fine. Otherwise, store the
1995 // value onto stack and load it back inside the to-be-outlined region. This
1996 // will ensure only the pointer will be passed to the function.
1997 // FIXME: if there are more than 15 trailing arguments, they must be
1998 // additionally packed in a struct.
1999 Value *Inner = &V;
2000 if (!V.getType()->isPointerTy()) {
2001 IRBuilder<>::InsertPointGuard Guard(Builder);
2002 LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
2003
2004 Builder.restoreIP(IP: OuterAllocIP);
2005 Value *Ptr;
2006 if (UsesDeviceSharedMemory) {
2007 // Use device shared memory instead, if needed.
2008 Ptr = createOMPAllocShared(Loc: OuterAllocIP, VarType: V.getType(),
2009 Name: V.getName() + ".reloaded");
2010 for (BasicBlock *DeallocBlock : OuterDeallocBlocks)
2011 createOMPFreeShared(
2012 Loc: InsertPointTy(DeallocBlock, DeallocBlock->getFirstInsertionPt()),
2013 Addr: Ptr, VarType: V.getType());
2014 } else {
2015 Ptr = Builder.CreateAlloca(Ty: V.getType(), ArraySize: nullptr,
2016 Name: V.getName() + ".reloaded");
2017 }
2018
2019 // Store to stack at end of the block that currently branches to the entry
2020 // block of the to-be-outlined region.
2021 Builder.SetInsertPoint(TheBB: InsertBB,
2022 IP: InsertBB->getTerminator()->getIterator());
2023 Builder.CreateStore(Val: &V, Ptr);
2024
2025 // Load back next to allocations in the to-be-outlined region.
2026 Builder.restoreIP(IP: InnerAllocaIP);
2027 Inner = Builder.CreateLoad(Ty: V.getType(), Ptr);
2028 }
2029
2030 Value *ReplacementValue = nullptr;
2031 CallInst *CI = dyn_cast<CallInst>(Val: &V);
2032 if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
2033 ReplacementValue = PrivTID;
2034 } else {
2035 InsertPointOrErrorTy AfterIP =
2036 PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
2037 if (!AfterIP)
2038 return AfterIP.takeError();
2039 Builder.restoreIP(IP: *AfterIP);
2040 InnerAllocaIP = {
2041 InnerAllocaIP.getBlock(),
2042 InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
2043
2044 assert(ReplacementValue &&
2045 "Expected copy/create callback to set replacement value!");
2046 if (ReplacementValue == &V)
2047 return Error::success();
2048 }
2049
2050 for (Use *UPtr : Uses)
2051 UPtr->set(ReplacementValue);
2052
2053 return Error::success();
2054 };
2055
2056 // Reset the inner alloca insertion as it will be used for loading the values
2057 // wrapped into pointers before passing them into the to-be-outlined region.
2058 // Configure it to insert immediately after the fake use of zero address so
2059 // that they are available in the generated body and so that the
2060 // OpenMP-related values (thread ID and zero address pointers) remain leading
2061 // in the argument list.
2062 InnerAllocaIP = IRBuilder<>::InsertPoint(
2063 ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
2064
2065 // Reset the outer alloca insertion point to the entry of the relevant block
2066 // in case it was invalidated.
2067 OuterAllocIP = IRBuilder<>::InsertPoint(
2068 OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
2069
2070 for (Value *Input : Inputs) {
2071 LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
2072 if (Error Err = PrivHelper(*Input))
2073 return Err;
2074 }
2075 LLVM_DEBUG({
2076 for (Value *Output : Outputs)
2077 LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
2078 });
2079 assert(Outputs.empty() &&
2080 "OpenMP outlining should not produce live-out values!");
2081
2082 LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
2083 LLVM_DEBUG({
2084 for (auto *BB : Blocks)
2085 dbgs() << " PBR: " << BB->getName() << "\n";
2086 });
2087
2088 // Adjust the finalization stack, verify the adjustment, and call the
2089 // finalize function a last time to finalize values between the pre-fini
2090 // block and the exit block if we left the parallel "the normal way".
2091 auto FiniInfo = FinalizationStack.pop_back_val();
2092 (void)FiniInfo;
2093 assert(FiniInfo.DK == OMPD_parallel &&
2094 "Unexpected finalization stack state!");
2095
2096 Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
2097
2098 InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
2099 Expected<BasicBlock *> FiniBBOrErr = FiniInfo.getFiniBB(Builder);
2100 if (!FiniBBOrErr)
2101 return FiniBBOrErr.takeError();
2102 {
2103 IRBuilderBase::InsertPointGuard Guard(Builder);
2104 Builder.restoreIP(IP: PreFiniIP);
2105 Builder.CreateBr(Dest: *FiniBBOrErr);
2106 // There's currently a branch to omp.par.exit. Delete it. We will get there
2107 // via the fini block
2108 if (Instruction *Term = Builder.GetInsertBlock()->getTerminator())
2109 Term->eraseFromParent();
2110 }
2111
2112 // Register the outlined info.
2113 addOutlineInfo(OI: std::move(OI));
2114
2115 InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
2116 UI->eraseFromParent();
2117
2118 return AfterIP;
2119}
2120
2121void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
2122 // Build call void __kmpc_flush(ident_t *loc)
2123 uint32_t SrcLocStrSize;
2124 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2125 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
2126
2127 createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_flush),
2128 Args);
2129}
2130
2131void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
2132 if (!updateToLocation(Loc))
2133 return;
2134 emitFlush(Loc);
2135}
2136
2137void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
2138 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
2139 // global_tid);
2140 uint32_t SrcLocStrSize;
2141 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2142 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2143 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
2144
2145 // Ignore return result until untied tasks are supported.
2146 createRuntimeFunctionCall(
2147 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskwait), Args);
2148}
2149
2150void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
2151 if (!updateToLocation(Loc))
2152 return;
2153 emitTaskwaitImpl(Loc);
2154}
2155
2156void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
2157 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
2158 uint32_t SrcLocStrSize;
2159 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2160 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2161 Constant *I32Null = ConstantInt::getNullValue(Ty: Int32);
2162 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
2163
2164 createRuntimeFunctionCall(
2165 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_taskyield), Args);
2166}
2167
2168void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
2169 if (!updateToLocation(Loc))
2170 return;
2171 emitTaskyieldImpl(Loc);
2172}
2173
2174void OpenMPIRBuilder::emitTaskDependency(IRBuilderBase &Builder, Value *Entry,
2175 const DependData &Dep) {
2176 // Store the pointer to the variable
2177 Value *Addr = Builder.CreateStructGEP(
2178 Ty: DependInfo, Ptr: Entry,
2179 Idx: static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
2180 Value *DepValPtr = Builder.CreatePtrToInt(V: Dep.DepVal, DestTy: SizeTy);
2181 Builder.CreateStore(Val: DepValPtr, Ptr: Addr);
2182 // Store the size of the variable
2183 Value *Size = Builder.CreateStructGEP(
2184 Ty: DependInfo, Ptr: Entry, Idx: static_cast<unsigned int>(RTLDependInfoFields::Len));
2185 Builder.CreateStore(
2186 Val: ConstantInt::get(Ty: SizeTy,
2187 V: M.getDataLayout().getTypeStoreSize(Ty: Dep.DepValueType)),
2188 Ptr: Size);
2189 // Store the dependency kind
2190 Value *Flags = Builder.CreateStructGEP(
2191 Ty: DependInfo, Ptr: Entry, Idx: static_cast<unsigned int>(RTLDependInfoFields::Flags));
2192 Builder.CreateStore(Val: ConstantInt::get(Ty: Builder.getInt8Ty(),
2193 V: static_cast<unsigned int>(Dep.DepKind)),
2194 Ptr: Flags);
2195}
2196
2197// Processes the dependencies in Dependencies and does the following
2198// - Allocates space on the stack of an array of DependInfo objects
2199// - Populates each DependInfo object with relevant information of
2200// the corresponding dependence.
2201// - All code is inserted in the entry block of the current function.
2202static Value *emitTaskDependencies(
2203 OpenMPIRBuilder &OMPBuilder,
2204 const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
2205 // Early return if we have no dependencies to process
2206 if (Dependencies.empty())
2207 return nullptr;
2208
2209 // Given a vector of DependData objects, in this function we create an
2210 // array on the stack that holds kmp_depend_info objects corresponding
2211 // to each dependency. This is then passed to the OpenMP runtime.
2212 // For example, if there are 'n' dependencies then the following psedo
2213 // code is generated. Assume the first dependence is on a variable 'a'
2214 //
2215 // \code{c}
2216 // DepArray = alloc(n x sizeof(kmp_depend_info);
2217 // idx = 0;
2218 // DepArray[idx].base_addr = ptrtoint(&a);
2219 // DepArray[idx].len = 8;
2220 // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
2221 // ++idx;
2222 // DepArray[idx].base_addr = ...;
2223 // \endcode
2224
2225 IRBuilderBase &Builder = OMPBuilder.Builder;
2226 Type *DependInfo = OMPBuilder.DependInfo;
2227
2228 Value *DepArray = nullptr;
2229 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
2230 Builder.SetInsertPoint(
2231 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
2232
2233 Type *DepArrayTy = ArrayType::get(ElementType: DependInfo, NumElements: Dependencies.size());
2234 DepArray = Builder.CreateAlloca(Ty: DepArrayTy, ArraySize: nullptr, Name: ".dep.arr.addr");
2235
2236 Builder.restoreIP(IP: OldIP);
2237
2238 for (const auto &[DepIdx, Dep] : enumerate(First: Dependencies)) {
2239 Value *Base =
2240 Builder.CreateConstInBoundsGEP2_64(Ty: DepArrayTy, Ptr: DepArray, Idx0: 0, Idx1: DepIdx);
2241 OMPBuilder.emitTaskDependency(Builder, Entry: Base, Dep);
2242 }
2243 return DepArray;
2244}
2245
2246/// Create the task duplication function passed to kmpc_taskloop.
2247Expected<Value *> OpenMPIRBuilder::createTaskDuplicationFunction(
2248 Type *PrivatesTy, int32_t PrivatesIndex, TaskDupCallbackTy DupCB) {
2249 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2250 if (!DupCB)
2251 return Constant::getNullValue(
2252 Ty: PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace));
2253
2254 // From OpenMP Runtime p_task_dup_t:
2255 // Routine optionally generated by the compiler for setting the lastprivate
2256 // flag and calling needed constructors for private/firstprivate objects (used
2257 // to form taskloop tasks from pattern task) Parameters: dest task, src task,
2258 // lastprivate flag.
2259 // typedef void (*p_task_dup_t)(kmp_task_t *, kmp_task_t *, kmp_int32);
2260
2261 auto *VoidPtrTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2262
2263 FunctionType *DupFuncTy = FunctionType::get(
2264 Result: Builder.getVoidTy(), Params: {VoidPtrTy, VoidPtrTy, Builder.getInt32Ty()},
2265 /*isVarArg=*/false);
2266
2267 Function *DupFunction = Function::Create(Ty: DupFuncTy, Linkage: Function::InternalLinkage,
2268 N: "omp_taskloop_dup", M);
2269 Value *DestTaskArg = DupFunction->getArg(i: 0);
2270 Value *SrcTaskArg = DupFunction->getArg(i: 1);
2271 Value *LastprivateFlagArg = DupFunction->getArg(i: 2);
2272 DestTaskArg->setName("dest_task");
2273 SrcTaskArg->setName("src_task");
2274 LastprivateFlagArg->setName("lastprivate_flag");
2275
2276 IRBuilderBase::InsertPointGuard Guard(Builder);
2277 Builder.SetInsertPoint(
2278 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: DupFunction));
2279
2280 auto GetTaskContextPtrFromArg = [&](Value *Arg) -> Value * {
2281 Type *TaskWithPrivatesTy =
2282 StructType::get(Context&: Builder.getContext(), Elements: {Task, PrivatesTy});
2283 Value *TaskPrivates = Builder.CreateGEP(
2284 Ty: TaskWithPrivatesTy, Ptr: Arg, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2285 Value *ContextPtr = Builder.CreateGEP(
2286 Ty: PrivatesTy, Ptr: TaskPrivates,
2287 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: PrivatesIndex)});
2288 return ContextPtr;
2289 };
2290
2291 Value *DestTaskContextPtr = GetTaskContextPtrFromArg(DestTaskArg);
2292 Value *SrcTaskContextPtr = GetTaskContextPtrFromArg(SrcTaskArg);
2293
2294 DestTaskContextPtr->setName("destPtr");
2295 SrcTaskContextPtr->setName("srcPtr");
2296
2297 InsertPointTy AllocaIP(&DupFunction->getEntryBlock(),
2298 DupFunction->getEntryBlock().begin());
2299 InsertPointTy CodeGenIP = Builder.saveIP();
2300 Expected<IRBuilderBase::InsertPoint> AfterIPOrError =
2301 DupCB(AllocaIP, CodeGenIP, DestTaskContextPtr, SrcTaskContextPtr);
2302 if (!AfterIPOrError)
2303 return AfterIPOrError.takeError();
2304 Builder.restoreIP(IP: *AfterIPOrError);
2305
2306 Builder.CreateRetVoid();
2307
2308 return DupFunction;
2309}
2310
2311OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
2312 const LocationDescription &Loc, InsertPointTy AllocaIP,
2313 ArrayRef<BasicBlock *> DeallocBlocks, BodyGenCallbackTy BodyGenCB,
2314 llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> LoopInfo,
2315 Value *LBVal, Value *UBVal, Value *StepVal, bool Untied, Value *IfCond,
2316 Value *GrainSize, bool NoGroup, int Sched, Value *Final, bool Mergeable,
2317 Value *Priority, uint64_t NumOfCollapseLoops, TaskDupCallbackTy DupCB,
2318 Value *TaskContextStructPtrVal) {
2319
2320 if (!updateToLocation(Loc))
2321 return InsertPointTy();
2322
2323 uint32_t SrcLocStrSize;
2324 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2325 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2326
2327 BasicBlock *TaskloopExitBB =
2328 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.exit");
2329 BasicBlock *TaskloopBodyBB =
2330 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.body");
2331 BasicBlock *TaskloopAllocaBB =
2332 splitBB(Builder, /*CreateBranch=*/true, Name: "taskloop.alloca");
2333
2334 InsertPointTy TaskloopAllocaIP =
2335 InsertPointTy(TaskloopAllocaBB, TaskloopAllocaBB->begin());
2336 InsertPointTy TaskloopBodyIP =
2337 InsertPointTy(TaskloopBodyBB, TaskloopBodyBB->begin());
2338
2339 if (Error Err = BodyGenCB(TaskloopAllocaIP, TaskloopBodyIP, TaskloopExitBB))
2340 return Err;
2341
2342 llvm::Expected<llvm::CanonicalLoopInfo *> result = LoopInfo();
2343 if (!result) {
2344 return result.takeError();
2345 }
2346
2347 llvm::CanonicalLoopInfo *CLI = result.get();
2348 auto OI = std::make_unique<OutlineInfo>();
2349 OI->EntryBB = TaskloopAllocaBB;
2350 OI->OuterAllocBB = AllocaIP.getBlock();
2351 OI->ExitBB = TaskloopExitBB;
2352 OI->OuterDeallocBBs.reserve(N: DeallocBlocks.size());
2353 copy(Range&: DeallocBlocks, Out: OI->OuterDeallocBBs.end());
2354
2355 // Add the thread ID argument.
2356 SmallVector<Instruction *> ToBeDeleted;
2357 // dummy instruction to be used as a fake argument
2358 OI->ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2359 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskloopAllocaIP, Name: "global.tid", AsPtr: false));
2360 Value *FakeLB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2361 InnerAllocaIP: TaskloopAllocaIP, Name: "lb", AsPtr: false, Is64Bit: true);
2362 Value *FakeUB = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2363 InnerAllocaIP: TaskloopAllocaIP, Name: "ub", AsPtr: false, Is64Bit: true);
2364 Value *FakeStep = createFakeIntVal(Builder, OuterAllocaIP: AllocaIP, ToBeDeleted,
2365 InnerAllocaIP: TaskloopAllocaIP, Name: "step", AsPtr: false, Is64Bit: true);
2366 // For Taskloop, we want to force the bounds being the first 3 inputs in the
2367 // aggregate struct
2368 OI->Inputs.insert(X: FakeLB);
2369 OI->Inputs.insert(X: FakeUB);
2370 OI->Inputs.insert(X: FakeStep);
2371 if (TaskContextStructPtrVal)
2372 OI->Inputs.insert(X: TaskContextStructPtrVal);
2373 assert(((TaskContextStructPtrVal && DupCB) ||
2374 (!TaskContextStructPtrVal && !DupCB)) &&
2375 "Task context struct ptr and duplication callback must be both set "
2376 "or both null");
2377
2378 // It isn't safe to run the duplication bodygen callback inside the post
2379 // outlining callback so this has to be run now before we know the real task
2380 // shareds structure type.
2381 unsigned ProgramAddressSpace = M.getDataLayout().getProgramAddressSpace();
2382 Type *PointerTy = PointerType::get(C&: Builder.getContext(), AddressSpace: ProgramAddressSpace);
2383 Type *FakeSharedsTy = StructType::get(
2384 Context&: Builder.getContext(),
2385 Elements: {FakeLB->getType(), FakeUB->getType(), FakeStep->getType(), PointerTy});
2386 Expected<Value *> TaskDupFnOrErr = createTaskDuplicationFunction(
2387 PrivatesTy: FakeSharedsTy,
2388 /*PrivatesIndex: the pointer after the three indices above*/ PrivatesIndex: 3, DupCB);
2389 if (!TaskDupFnOrErr) {
2390 return TaskDupFnOrErr.takeError();
2391 }
2392 Value *TaskDupFn = *TaskDupFnOrErr;
2393
2394 OI->PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
2395 TaskloopAllocaBB, CLI, TaskDupFn, ToBeDeleted, IfCond,
2396 GrainSize, NoGroup, Sched, FakeLB, FakeUB, FakeStep,
2397 FakeSharedsTy, Final, Mergeable, Priority,
2398 NumOfCollapseLoops](Function &OutlinedFn) mutable {
2399 // Replace the Stale CI by appropriate RTL function call.
2400 assert(OutlinedFn.hasOneUse() &&
2401 "there must be a single user for the outlined function");
2402 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2403
2404 /* Create the casting for the Bounds Values that can be used when outlining
2405 * to replace the uses of the fakes with real values */
2406 BasicBlock *CodeReplBB = StaleCI->getParent();
2407 Builder.SetInsertPoint(CodeReplBB->getFirstInsertionPt());
2408 Value *CastedLBVal =
2409 Builder.CreateIntCast(V: LBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "lb64");
2410 Value *CastedUBVal =
2411 Builder.CreateIntCast(V: UBVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "ub64");
2412 Value *CastedStepVal =
2413 Builder.CreateIntCast(V: StepVal, DestTy: Builder.getInt64Ty(), isSigned: true, Name: "step64");
2414
2415 Builder.SetInsertPoint(StaleCI);
2416
2417 // Gather the arguments for emitting the runtime call for
2418 // @__kmpc_omp_task_alloc
2419 Function *TaskAllocFn =
2420 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2421
2422 Value *ThreadID = getOrCreateThreadID(Ident);
2423
2424 if (!NoGroup) {
2425 // Emit runtime call for @__kmpc_taskgroup
2426 Function *TaskgroupFn =
2427 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2428 Builder.CreateCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2429 }
2430
2431 // `flags` Argument Configuration
2432 // Task is tied if (Flags & 1) == 1.
2433 // Task is untied if (Flags & 1) == 0.
2434 // Task is final if (Flags & 2) == 2.
2435 // Task is not final if (Flags & 2) == 0.
2436 // Task is mergeable if (Flags & 4) == 4.
2437 // Task is not mergeable if (Flags & 4) == 0.
2438 // Task is priority if (Flags & 32) == 32.
2439 // Task is not priority if (Flags & 32) == 0.
2440 Value *Flags = Builder.getInt32(C: Untied ? 0 : 1);
2441 if (Final)
2442 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 2), RHS: Flags);
2443 if (Mergeable)
2444 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2445 if (Priority)
2446 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2447
2448 Value *TaskSize = Builder.getInt64(
2449 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2450
2451 AllocaInst *ArgStructAlloca =
2452 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2453 assert(ArgStructAlloca &&
2454 "Unable to find the alloca instruction corresponding to arguments "
2455 "for extracted function");
2456 std::optional<TypeSize> ArgAllocSize =
2457 ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
2458 assert(ArgAllocSize &&
2459 "Unable to determine size of arguments for extracted function");
2460 Value *SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
2461
2462 // Emit the @__kmpc_omp_task_alloc runtime call
2463 // The runtime call returns a pointer to an area where the task captured
2464 // variables must be copied before the task is run (TaskData)
2465 CallInst *TaskData = Builder.CreateCall(
2466 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2467 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2468 /*task_func=*/&OutlinedFn});
2469
2470 Value *Shareds = StaleCI->getArgOperand(i: 1);
2471 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2472 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2473 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2474 Size: SharedsSize);
2475 // Get the pointer to loop lb, ub, step from task ptr
2476 // and set up the lowerbound,upperbound and step values
2477 llvm::Value *Lb = Builder.CreateGEP(
2478 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
2479
2480 llvm::Value *Ub = Builder.CreateGEP(
2481 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 1)});
2482
2483 llvm::Value *Step = Builder.CreateGEP(
2484 Ty: FakeSharedsTy, Ptr: TaskShareds, IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 2)});
2485 llvm::Value *Loadstep = Builder.CreateLoad(Ty: Builder.getInt64Ty(), Ptr: Step);
2486
2487 // set up the arguments for emitting kmpc_taskloop runtime call
2488 // setting values for ifval, nogroup, sched, grainsize, task_dup
2489 Value *IfCondVal =
2490 IfCond ? Builder.CreateIntCast(V: IfCond, DestTy: Builder.getInt32Ty(), isSigned: true)
2491 : Builder.getInt32(C: 1);
2492 // As __kmpc_taskgroup is called manually in OMPIRBuilder, NoGroupVal should
2493 // always be 1 when calling __kmpc_taskloop to ensure it is not called again
2494 Value *NoGroupVal = Builder.getInt32(C: 1);
2495 Value *SchedVal = Builder.getInt32(C: Sched);
2496 Value *GrainSizeVal =
2497 GrainSize ? Builder.CreateIntCast(V: GrainSize, DestTy: Builder.getInt64Ty(), isSigned: true)
2498 : Builder.getInt64(C: 0);
2499 Value *TaskDup = TaskDupFn;
2500
2501 Value *Args[] = {Ident, ThreadID, TaskData, IfCondVal, Lb, Ub,
2502 Loadstep, NoGroupVal, SchedVal, GrainSizeVal, TaskDup};
2503
2504 // taskloop runtime call
2505 Function *TaskloopFn =
2506 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskloop);
2507 Builder.CreateCall(Callee: TaskloopFn, Args);
2508
2509 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup if
2510 // nogroup is not defined
2511 if (!NoGroup) {
2512 Function *EndTaskgroupFn =
2513 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2514 Builder.CreateCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2515 }
2516
2517 StaleCI->eraseFromParent();
2518
2519 Builder.SetInsertPoint(TheBB: TaskloopAllocaBB, IP: TaskloopAllocaBB->begin());
2520
2521 LoadInst *SharedsOutlined =
2522 Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2523 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2524 New: SharedsOutlined,
2525 ShouldReplace: [SharedsOutlined](Use &U) { return U.getUser() != SharedsOutlined; });
2526
2527 Value *IV = CLI->getIndVar();
2528 Type *IVTy = IV->getType();
2529 Constant *One = ConstantInt::get(Ty: Builder.getInt64Ty(), V: 1);
2530
2531 // When outlining, CodeExtractor will create GEP's to the LowerBound and
2532 // UpperBound. These GEP's can be reused for loading the tasks respective
2533 // bounds.
2534 Value *TaskLB = nullptr;
2535 Value *TaskUB = nullptr;
2536 Value *TaskStep = nullptr;
2537 Value *LoadTaskLB = nullptr;
2538 Value *LoadTaskUB = nullptr;
2539 Value *LoadTaskStep = nullptr;
2540 for (Instruction &I : *TaskloopAllocaBB) {
2541 if (I.getOpcode() == Instruction::GetElementPtr) {
2542 GetElementPtrInst &Gep = cast<GetElementPtrInst>(Val&: I);
2543 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: Gep.getOperand(i_nocapture: 2))) {
2544 switch (CI->getZExtValue()) {
2545 case 0:
2546 TaskLB = &I;
2547 break;
2548 case 1:
2549 TaskUB = &I;
2550 break;
2551 case 2:
2552 TaskStep = &I;
2553 break;
2554 }
2555 }
2556 } else if (I.getOpcode() == Instruction::Load) {
2557 LoadInst &Load = cast<LoadInst>(Val&: I);
2558 if (Load.getPointerOperand() == TaskLB) {
2559 assert(TaskLB != nullptr && "Expected value for TaskLB");
2560 LoadTaskLB = &I;
2561 } else if (Load.getPointerOperand() == TaskUB) {
2562 assert(TaskUB != nullptr && "Expected value for TaskUB");
2563 LoadTaskUB = &I;
2564 } else if (Load.getPointerOperand() == TaskStep) {
2565 assert(TaskStep != nullptr && "Expected value for TaskStep");
2566 LoadTaskStep = &I;
2567 }
2568 }
2569 }
2570
2571 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2572
2573 assert(LoadTaskLB != nullptr && "Expected value for LoadTaskLB");
2574 assert(LoadTaskUB != nullptr && "Expected value for LoadTaskUB");
2575 assert(LoadTaskStep != nullptr && "Expected value for LoadTaskStep");
2576 Value *TripCountMinusOne = Builder.CreateSDiv(
2577 LHS: Builder.CreateSub(LHS: LoadTaskUB, RHS: LoadTaskLB), RHS: LoadTaskStep);
2578 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One, Name: "trip_cnt");
2579 Value *CastedTripCount = Builder.CreateIntCast(V: TripCount, DestTy: IVTy, isSigned: true);
2580 Value *CastedTaskLB = Builder.CreateIntCast(V: LoadTaskLB, DestTy: IVTy, isSigned: true);
2581 // set the trip count in the CLI
2582 CLI->setTripCount(CastedTripCount);
2583
2584 Builder.SetInsertPoint(TheBB: CLI->getBody(),
2585 IP: CLI->getBody()->getFirstInsertionPt());
2586
2587 if (NumOfCollapseLoops > 1) {
2588 llvm::SmallVector<User *> UsersToReplace;
2589 // When using the collapse clause, the bounds of the loop have to be
2590 // adjusted to properly represent the iterator of the outer loop.
2591 Value *IVPlusTaskLB = Builder.CreateAdd(
2592 LHS: CLI->getIndVar(),
2593 RHS: Builder.CreateSub(LHS: CastedTaskLB, RHS: ConstantInt::get(Ty: IVTy, V: 1)));
2594 // To ensure every Use is correctly captured, we first want to record
2595 // which users to replace the value in, and then replace the value.
2596 for (auto IVUse = CLI->getIndVar()->uses().begin();
2597 IVUse != CLI->getIndVar()->uses().end(); IVUse++) {
2598 User *IVUser = IVUse->getUser();
2599 if (auto *Op = dyn_cast<BinaryOperator>(Val: IVUser)) {
2600 if (Op->getOpcode() == Instruction::URem ||
2601 Op->getOpcode() == Instruction::UDiv) {
2602 UsersToReplace.push_back(Elt: IVUser);
2603 }
2604 }
2605 }
2606 for (User *User : UsersToReplace) {
2607 User->replaceUsesOfWith(From: CLI->getIndVar(), To: IVPlusTaskLB);
2608 }
2609 } else {
2610 // The canonical loop is generated with a fixed lower bound. We need to
2611 // update the index calculation code to use the task's lower bound. The
2612 // generated code looks like this:
2613 // %omp_loop.iv = phi ...
2614 // ...
2615 // %tmp = mul [type] %omp_loop.iv, step
2616 // %user_index = add [type] tmp, lb
2617 // OpenMPIRBuilder constructs canonical loops to have exactly three uses
2618 // of the normalised induction variable:
2619 // 1. This one: converting the normalised IV to the user IV
2620 // 2. The increment (add)
2621 // 3. The comparison against the trip count (icmp)
2622 // (1) is the only use that is a mul followed by an add so this cannot
2623 // match other IR.
2624 assert(CLI->getIndVar()->getNumUses() == 3 &&
2625 "Canonical loop should have exactly three uses of the ind var");
2626 for (User *IVUser : CLI->getIndVar()->users()) {
2627 if (auto *Mul = dyn_cast<BinaryOperator>(Val: IVUser)) {
2628 if (Mul->getOpcode() == Instruction::Mul) {
2629 for (User *MulUser : Mul->users()) {
2630 if (auto *Add = dyn_cast<BinaryOperator>(Val: MulUser)) {
2631 if (Add->getOpcode() == Instruction::Add) {
2632 Add->setOperand(i_nocapture: 1, Val_nocapture: CastedTaskLB);
2633 }
2634 }
2635 }
2636 }
2637 }
2638 }
2639 }
2640
2641 FakeLB->replaceAllUsesWith(V: CastedLBVal);
2642 FakeUB->replaceAllUsesWith(V: CastedUBVal);
2643 FakeStep->replaceAllUsesWith(V: CastedStepVal);
2644 for (Instruction *I : llvm::reverse(C&: ToBeDeleted)) {
2645 I->eraseFromParent();
2646 }
2647 };
2648
2649 addOutlineInfo(OI: std::move(OI));
2650 Builder.SetInsertPoint(TheBB: TaskloopExitBB, IP: TaskloopExitBB->begin());
2651 return Builder.saveIP();
2652}
2653
2654llvm::StructType *OpenMPIRBuilder::getKmpTaskAffinityInfoTy() {
2655 llvm::Type *IntPtrTy = llvm::Type::getIntNTy(
2656 C&: M.getContext(), N: M.getDataLayout().getPointerSizeInBits());
2657 return llvm::StructType::get(elt1: IntPtrTy, elts: IntPtrTy,
2658 elts: llvm::Type::getInt32Ty(C&: M.getContext()));
2659}
2660
2661OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
2662 const LocationDescription &Loc, InsertPointTy AllocaIP,
2663 ArrayRef<BasicBlock *> DeallocBlocks, BodyGenCallbackTy BodyGenCB,
2664 bool Tied, Value *Final, Value *IfCondition,
2665 const DependenciesInfo &Dependencies, const AffinityData &Affinities,
2666 bool Mergeable, Value *EventHandle, Value *Priority) {
2667
2668 if (!updateToLocation(Loc))
2669 return InsertPointTy();
2670
2671 uint32_t SrcLocStrSize;
2672 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2673 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2674 // The current basic block is split into four basic blocks. After outlining,
2675 // they will be mapped as follows:
2676 // ```
2677 // def current_fn() {
2678 // current_basic_block:
2679 // br label %task.exit
2680 // task.exit:
2681 // ; instructions after task
2682 // }
2683 // def outlined_fn() {
2684 // task.alloca:
2685 // br label %task.body
2686 // task.body:
2687 // ret void
2688 // }
2689 // ```
2690 BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.exit");
2691 BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "task.body");
2692 BasicBlock *TaskAllocaBB =
2693 splitBB(Builder, /*CreateBranch=*/true, Name: "task.alloca");
2694
2695 InsertPointTy TaskAllocaIP =
2696 InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
2697 InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
2698 if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP, TaskExitBB))
2699 return Err;
2700
2701 auto OI = std::make_unique<OutlineInfo>();
2702 OI->EntryBB = TaskAllocaBB;
2703 OI->OuterAllocBB = AllocaIP.getBlock();
2704 OI->ExitBB = TaskExitBB;
2705 OI->OuterDeallocBBs.reserve(N: DeallocBlocks.size());
2706 copy(Range&: DeallocBlocks, Out: OI->OuterDeallocBBs.end());
2707
2708 // Add the thread ID argument.
2709 SmallVector<Instruction *, 4> ToBeDeleted;
2710 OI->ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
2711 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TaskAllocaIP, Name: "global.tid", AsPtr: false));
2712
2713 OI->PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
2714 Affinities, Mergeable, Priority, EventHandle,
2715 TaskAllocaBB,
2716 ToBeDeleted](Function &OutlinedFn) mutable {
2717 // Replace the Stale CI by appropriate RTL function call.
2718 assert(OutlinedFn.hasOneUse() &&
2719 "there must be a single user for the outlined function");
2720 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
2721
2722 // HasShareds is true if any variables are captured in the outlined region,
2723 // false otherwise.
2724 bool HasShareds = StaleCI->arg_size() > 1;
2725 Builder.SetInsertPoint(StaleCI);
2726
2727 // Gather the arguments for emitting the runtime call for
2728 // @__kmpc_omp_task_alloc
2729 Function *TaskAllocFn =
2730 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc);
2731
2732 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
2733 // call.
2734 Value *ThreadID = getOrCreateThreadID(Ident);
2735
2736 // Argument - `flags`
2737 // Task is tied iff (Flags & 1) == 1.
2738 // Task is untied iff (Flags & 1) == 0.
2739 // Task is final iff (Flags & 2) == 2.
2740 // Task is not final iff (Flags & 2) == 0.
2741 // Task is mergeable or merged-if0 iff (Flags & 4) == 4.
2742 // Task is neither mergeable nor merged-if0 iff (Flags & 4) == 0.
2743 // Task is detachable iff (Flags & 64) == 64.
2744 // Task is not detachable iff (Flags & 64) == 0.
2745 // Task is priority iff (Flags & 32) == 32.
2746 // Task is not priority iff (Flags & 32) == 0.
2747 // TODO: Handle the other flags.
2748 Value *Flags = Builder.getInt32(C: Tied);
2749 auto *ConstIfCondition = dyn_cast_or_null<ConstantInt>(Val: IfCondition);
2750 bool UseMergedIf0Path = ConstIfCondition && ConstIfCondition->isZero();
2751 if (Final) {
2752 Value *FinalFlag =
2753 Builder.CreateSelect(C: Final, True: Builder.getInt32(C: 2), False: Builder.getInt32(C: 0));
2754 Flags = Builder.CreateOr(LHS: FinalFlag, RHS: Flags);
2755 }
2756
2757 if (Mergeable || UseMergedIf0Path)
2758 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 4), RHS: Flags);
2759 if (EventHandle)
2760 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 64), RHS: Flags);
2761 if (Priority)
2762 Flags = Builder.CreateOr(LHS: Builder.getInt32(C: 32), RHS: Flags);
2763
2764 // Argument - `sizeof_kmp_task_t` (TaskSize)
2765 // Tasksize refers to the size in bytes of kmp_task_t data structure
2766 // including private vars accessed in task.
2767 // TODO: add kmp_task_t_with_privates (privates)
2768 Value *TaskSize = Builder.getInt64(
2769 C: divideCeil(Numerator: M.getDataLayout().getTypeSizeInBits(Ty: Task), Denominator: 8));
2770
2771 // Argument - `sizeof_shareds` (SharedsSize)
2772 // SharedsSize refers to the shareds array size in the kmp_task_t data
2773 // structure.
2774 Value *SharedsSize = Builder.getInt64(C: 0);
2775 if (HasShareds) {
2776 AllocaInst *ArgStructAlloca =
2777 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: 1));
2778 assert(ArgStructAlloca &&
2779 "Unable to find the alloca instruction corresponding to arguments "
2780 "for extracted function");
2781 std::optional<TypeSize> ArgAllocSize =
2782 ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
2783 assert(ArgAllocSize &&
2784 "Unable to determine size of arguments for extracted function");
2785 SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
2786 }
2787 // Emit the @__kmpc_omp_task_alloc runtime call
2788 // The runtime call returns a pointer to an area where the task captured
2789 // variables must be copied before the task is run (TaskData)
2790 CallInst *TaskData = createRuntimeFunctionCall(
2791 Callee: TaskAllocFn, Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2792 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2793 /*task_func=*/&OutlinedFn});
2794
2795 if (Affinities.Count && Affinities.Info) {
2796 Function *RegAffFn = getOrCreateRuntimeFunctionPtr(
2797 FnID: OMPRTL___kmpc_omp_reg_task_with_affinity);
2798
2799 createRuntimeFunctionCall(Callee: RegAffFn, Args: {Ident, ThreadID, TaskData,
2800 Affinities.Count, Affinities.Info});
2801 }
2802
2803 // Emit detach clause initialization.
2804 // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
2805 // task_descriptor);
2806 if (EventHandle) {
2807 Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
2808 FnID: OMPRTL___kmpc_task_allow_completion_event);
2809 llvm::Value *EventVal =
2810 createRuntimeFunctionCall(Callee: TaskDetachFn, Args: {Ident, ThreadID, TaskData});
2811 llvm::Value *EventHandleAddr =
2812 Builder.CreatePointerBitCastOrAddrSpaceCast(V: EventHandle,
2813 DestTy: Builder.getPtrTy(AddrSpace: 0));
2814 EventVal = Builder.CreatePtrToInt(V: EventVal, DestTy: Builder.getInt64Ty());
2815 Builder.CreateStore(Val: EventVal, Ptr: EventHandleAddr);
2816 }
2817 // Copy the arguments for outlined function
2818 if (HasShareds) {
2819 Value *Shareds = StaleCI->getArgOperand(i: 1);
2820 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
2821 Value *TaskShareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: TaskData);
2822 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
2823 Size: SharedsSize);
2824 }
2825
2826 if (Priority) {
2827 //
2828 // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
2829 // we populate the priority information into the "kmp_task_t" here
2830 //
2831 // The struct "kmp_task_t" definition is available in kmp.h
2832 // kmp_task_t = { shareds, routine, part_id, data1, data2 }
2833 // data2 is used for priority
2834 //
2835 Type *Int32Ty = Builder.getInt32Ty();
2836 Constant *Zero = ConstantInt::get(Ty: Int32Ty, V: 0);
2837 // kmp_task_t* => { ptr }
2838 Type *TaskPtr = StructType::get(elt1: VoidPtr);
2839 Value *TaskGEP =
2840 Builder.CreateInBoundsGEP(Ty: TaskPtr, Ptr: TaskData, IdxList: {Zero, Zero});
2841 // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
2842 Type *TaskStructType = StructType::get(
2843 elt1: VoidPtr, elts: VoidPtr, elts: Builder.getInt32Ty(), elts: VoidPtr, elts: VoidPtr);
2844 Value *PriorityData = Builder.CreateInBoundsGEP(
2845 Ty: TaskStructType, Ptr: TaskGEP, IdxList: {Zero, ConstantInt::get(Ty: Int32Ty, V: 4)});
2846 // kmp_cmplrdata_t => { ptr, ptr }
2847 Type *CmplrStructType = StructType::get(elt1: VoidPtr, elts: VoidPtr);
2848 Value *CmplrData = Builder.CreateInBoundsGEP(Ty: CmplrStructType,
2849 Ptr: PriorityData, IdxList: {Zero, Zero});
2850 Builder.CreateStore(Val: Priority, Ptr: CmplrData);
2851 }
2852
2853 Value *DepArray = nullptr;
2854 Value *NumDeps = nullptr;
2855 if (Dependencies.DepArray) {
2856 DepArray = Dependencies.DepArray;
2857 NumDeps = Dependencies.NumDeps;
2858 } else if (!Dependencies.Deps.empty()) {
2859 DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies: Dependencies.Deps);
2860 NumDeps = Builder.getInt32(C: Dependencies.Deps.size());
2861 }
2862
2863 // In the presence of the `if` clause, the following IR is generated:
2864 // ...
2865 // %data = call @__kmpc_omp_task_alloc(...)
2866 // br i1 %if_condition, label %then, label %else
2867 // then:
2868 // call @__kmpc_omp_task(...)
2869 // br label %exit
2870 // else:
2871 // ;; Wait for resolution of dependencies, if any, before
2872 // ;; beginning the task
2873 // call @__kmpc_omp_wait_deps(...)
2874 // call @__kmpc_omp_task_begin_if0(...)
2875 // call @outlined_fn(...)
2876 // call @__kmpc_omp_task_complete_if0(...)
2877 // br label %exit
2878 // exit:
2879 // ...
2880 if (IfCondition && !UseMergedIf0Path) {
2881 // `SplitBlockAndInsertIfThenElse` requires the block to have a
2882 // terminator.
2883 splitBB(Builder, /*CreateBranch=*/true, Name: "if.end");
2884 Instruction *IfTerminator =
2885 Builder.GetInsertPoint()->getParent()->getTerminator();
2886 Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
2887 Builder.SetInsertPoint(IfTerminator);
2888 SplitBlockAndInsertIfThenElse(Cond: IfCondition, SplitBefore: IfTerminator, ThenTerm: &ThenTI,
2889 ElseTerm: &ElseTI);
2890 Builder.SetInsertPoint(ElseTI);
2891
2892 if (DepArray) {
2893 Function *TaskWaitFn =
2894 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
2895 createRuntimeFunctionCall(
2896 Callee: TaskWaitFn,
2897 Args: {Ident, ThreadID, NumDeps, DepArray,
2898 ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2899 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2900 }
2901 Function *TaskBeginFn =
2902 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
2903 Function *TaskCompleteFn =
2904 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
2905 createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
2906 CallInst *CI = nullptr;
2907 if (HasShareds)
2908 CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID, TaskData});
2909 else
2910 CI = createRuntimeFunctionCall(Callee: &OutlinedFn, Args: {ThreadID});
2911 CI->setDebugLoc(StaleCI->getDebugLoc());
2912 createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
2913 Builder.SetInsertPoint(ThenTI);
2914 }
2915
2916 if (DepArray) {
2917 Function *TaskFn =
2918 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
2919 createRuntimeFunctionCall(
2920 Callee: TaskFn,
2921 Args: {Ident, ThreadID, TaskData, NumDeps, DepArray,
2922 ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
2923 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
2924
2925 } else {
2926 // Emit the @__kmpc_omp_task runtime call to spawn the task
2927 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
2928 createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
2929 }
2930
2931 StaleCI->eraseFromParent();
2932
2933 Builder.SetInsertPoint(TheBB: TaskAllocaBB, IP: TaskAllocaBB->begin());
2934 if (HasShareds) {
2935 LoadInst *Shareds = Builder.CreateLoad(Ty: VoidPtr, Ptr: OutlinedFn.getArg(i: 1));
2936 OutlinedFn.getArg(i: 1)->replaceUsesWithIf(
2937 New: Shareds, ShouldReplace: [Shareds](Use &U) { return U.getUser() != Shareds; });
2938 }
2939
2940 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
2941 I->eraseFromParent();
2942 };
2943
2944 addOutlineInfo(OI: std::move(OI));
2945 Builder.SetInsertPoint(TheBB: TaskExitBB, IP: TaskExitBB->begin());
2946
2947 return Builder.saveIP();
2948}
2949
2950OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskgroup(
2951 const LocationDescription &Loc, InsertPointTy AllocaIP,
2952 ArrayRef<BasicBlock *> DeallocBlocks, BodyGenCallbackTy BodyGenCB) {
2953 if (!updateToLocation(Loc))
2954 return InsertPointTy();
2955
2956 uint32_t SrcLocStrSize;
2957 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2958 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2959 Value *ThreadID = getOrCreateThreadID(Ident);
2960
2961 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2962 Function *TaskgroupFn =
2963 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_taskgroup);
2964 createRuntimeFunctionCall(Callee: TaskgroupFn, Args: {Ident, ThreadID});
2965
2966 BasicBlock *TaskgroupExitBB = splitBB(Builder, CreateBranch: true, Name: "taskgroup.exit");
2967 if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP(), DeallocBlocks))
2968 return Err;
2969
2970 Builder.SetInsertPoint(TaskgroupExitBB);
2971 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2972 Function *EndTaskgroupFn =
2973 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_taskgroup);
2974 createRuntimeFunctionCall(Callee: EndTaskgroupFn, Args: {Ident, ThreadID});
2975
2976 return Builder.saveIP();
2977}
2978
2979OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
2980 const LocationDescription &Loc, InsertPointTy AllocaIP,
2981 ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2982 FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2983 assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2984
2985 if (!updateToLocation(Loc))
2986 return Loc.IP;
2987
2988 FinalizationStack.push_back(Elt: {FiniCB, OMPD_sections, IsCancellable});
2989
2990 // Each section is emitted as a switch case
2991 // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2992 // -> OMP.createSection() which generates the IR for each section
2993 // Iterate through all sections and emit a switch construct:
2994 // switch (IV) {
2995 // case 0:
2996 // <SectionStmt[0]>;
2997 // break;
2998 // ...
2999 // case <NumSection> - 1:
3000 // <SectionStmt[<NumSection> - 1]>;
3001 // break;
3002 // }
3003 // ...
3004 // section_loop.after:
3005 // <FiniCB>;
3006 auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
3007 Builder.restoreIP(IP: CodeGenIP);
3008 BasicBlock *Continue =
3009 splitBBWithSuffix(Builder, /*CreateBranch=*/false, Suffix: ".sections.after");
3010 Function *CurFn = Continue->getParent();
3011 SwitchInst *SwitchStmt = Builder.CreateSwitch(V: IndVar, Dest: Continue);
3012
3013 unsigned CaseNumber = 0;
3014 for (auto SectionCB : SectionCBs) {
3015 BasicBlock *CaseBB = BasicBlock::Create(
3016 Context&: M.getContext(), Name: "omp_section_loop.body.case", Parent: CurFn, InsertBefore: Continue);
3017 SwitchStmt->addCase(OnVal: Builder.getInt32(C: CaseNumber), Dest: CaseBB);
3018 Builder.SetInsertPoint(CaseBB);
3019 UncondBrInst *CaseEndBr = Builder.CreateBr(Dest: Continue);
3020 if (Error Err =
3021 SectionCB(InsertPointTy(),
3022 {CaseEndBr->getParent(), CaseEndBr->getIterator()}, {}))
3023 return Err;
3024 CaseNumber++;
3025 }
3026 // remove the existing terminator from body BB since there can be no
3027 // terminators after switch/case
3028 return Error::success();
3029 };
3030 // Loop body ends here
3031 // LowerBound, UpperBound, and STride for createCanonicalLoop
3032 Type *I32Ty = Type::getInt32Ty(C&: M.getContext());
3033 Value *LB = ConstantInt::get(Ty: I32Ty, V: 0);
3034 Value *UB = ConstantInt::get(Ty: I32Ty, V: SectionCBs.size());
3035 Value *ST = ConstantInt::get(Ty: I32Ty, V: 1);
3036 Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
3037 Loc, BodyGenCB: LoopBodyGenCB, Start: LB, Stop: UB, Step: ST, IsSigned: true, InclusiveStop: false, ComputeIP: AllocaIP, Name: "section_loop");
3038 if (!LoopInfo)
3039 return LoopInfo.takeError();
3040
3041 InsertPointOrErrorTy WsloopIP =
3042 applyStaticWorkshareLoop(DL: Loc.DL, CLI: *LoopInfo, AllocaIP,
3043 LoopType: WorksharingLoopType::ForStaticLoop, NeedsBarrier: !IsNowait);
3044 if (!WsloopIP)
3045 return WsloopIP.takeError();
3046 InsertPointTy AfterIP = *WsloopIP;
3047
3048 BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
3049 assert(LoopFini && "Bad structure of static workshare loop finalization");
3050
3051 // Apply the finalization callback in LoopAfterBB
3052 auto FiniInfo = FinalizationStack.pop_back_val();
3053 assert(FiniInfo.DK == OMPD_sections &&
3054 "Unexpected finalization stack state!");
3055 if (Error Err = FiniInfo.mergeFiniBB(Builder, OtherFiniBB: LoopFini))
3056 return Err;
3057
3058 return AfterIP;
3059}
3060
3061OpenMPIRBuilder::InsertPointOrErrorTy
3062OpenMPIRBuilder::createSection(const LocationDescription &Loc,
3063 BodyGenCallbackTy BodyGenCB,
3064 FinalizeCallbackTy FiniCB) {
3065 if (!updateToLocation(Loc))
3066 return Loc.IP;
3067
3068 auto FiniCBWrapper = [&](InsertPointTy IP) {
3069 if (IP.getBlock()->end() != IP.getPoint())
3070 return FiniCB(IP);
3071 // This must be done otherwise any nested constructs using FinalizeOMPRegion
3072 // will fail because that function requires the Finalization Basic Block to
3073 // have a terminator, which is already removed by EmitOMPRegionBody.
3074 // IP is currently at cancelation block.
3075 // We need to backtrack to the condition block to fetch
3076 // the exit block and create a branch from cancelation
3077 // to exit block.
3078 IRBuilder<>::InsertPointGuard IPG(Builder);
3079 Builder.restoreIP(IP);
3080 auto *CaseBB = Loc.IP.getBlock();
3081 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
3082 auto *ExitBB = CondBB->getTerminator()->getSuccessor(Idx: 1);
3083 Instruction *I = Builder.CreateBr(Dest: ExitBB);
3084 IP = InsertPointTy(I->getParent(), I->getIterator());
3085 return FiniCB(IP);
3086 };
3087
3088 Directive OMPD = Directive::OMPD_sections;
3089 // Since we are using Finalization Callback here, HasFinalize
3090 // and IsCancellable have to be true
3091 return EmitOMPInlinedRegion(OMPD, EntryCall: nullptr, ExitCall: nullptr, BodyGenCB, FiniCB: FiniCBWrapper,
3092 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true,
3093 /*IsCancellable*/ true);
3094}
3095
3096static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
3097 BasicBlock::iterator IT(I);
3098 IT++;
3099 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
3100}
3101
3102Value *OpenMPIRBuilder::getGPUThreadID() {
3103 return createRuntimeFunctionCall(
3104 Callee: getOrCreateRuntimeFunction(M,
3105 FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block),
3106 Args: {});
3107}
3108
3109Value *OpenMPIRBuilder::getGPUWarpSize() {
3110 return createRuntimeFunctionCall(
3111 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___kmpc_get_warp_size), Args: {});
3112}
3113
3114Value *OpenMPIRBuilder::getNVPTXWarpID() {
3115 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
3116 return Builder.CreateAShr(LHS: getGPUThreadID(), RHS: LaneIDBits, Name: "nvptx_warp_id");
3117}
3118
3119Value *OpenMPIRBuilder::getNVPTXLaneID() {
3120 unsigned LaneIDBits = Log2_32(Value: Config.getGridValue().GV_Warp_Size);
3121 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
3122 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
3123 return Builder.CreateAnd(LHS: getGPUThreadID(), RHS: Builder.getInt32(C: LaneIDMask),
3124 Name: "nvptx_lane_id");
3125}
3126
3127Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
3128 Type *ToType) {
3129 Type *FromType = From->getType();
3130 uint64_t FromSize = M.getDataLayout().getTypeStoreSize(Ty: FromType);
3131 uint64_t ToSize = M.getDataLayout().getTypeStoreSize(Ty: ToType);
3132 assert(FromSize > 0 && "From size must be greater than zero");
3133 assert(ToSize > 0 && "To size must be greater than zero");
3134 if (FromType == ToType)
3135 return From;
3136 if (FromSize == ToSize)
3137 return Builder.CreateBitCast(V: From, DestTy: ToType);
3138 if (ToType->isIntegerTy() && FromType->isIntegerTy())
3139 return Builder.CreateIntCast(V: From, DestTy: ToType, /*isSigned*/ true);
3140 InsertPointTy SaveIP = Builder.saveIP();
3141 Builder.restoreIP(IP: AllocaIP);
3142 Value *CastItem = Builder.CreateAlloca(Ty: ToType);
3143 Builder.restoreIP(IP: SaveIP);
3144
3145 Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
3146 V: CastItem, DestTy: Builder.getPtrTy(AddrSpace: 0));
3147 Builder.CreateStore(Val: From, Ptr: ValCastItem);
3148 return Builder.CreateLoad(Ty: ToType, Ptr: CastItem);
3149}
3150
3151Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
3152 Value *Element,
3153 Type *ElementType,
3154 Value *Offset) {
3155 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElementType);
3156 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
3157
3158 // Cast all types to 32- or 64-bit values before calling shuffle routines.
3159 Type *CastTy = Builder.getIntNTy(N: Size <= 4 ? 32 : 64);
3160 Value *ElemCast = castValueToType(AllocaIP, From: Element, ToType: CastTy);
3161 Value *WarpSize =
3162 Builder.CreateIntCast(V: getGPUWarpSize(), DestTy: Builder.getInt16Ty(), isSigned: true);
3163 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
3164 FnID: Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
3165 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
3166 Value *WarpSizeCast =
3167 Builder.CreateIntCast(V: WarpSize, DestTy: Builder.getInt16Ty(), /*isSigned=*/true);
3168 Value *ShuffleCall =
3169 createRuntimeFunctionCall(Callee: ShuffleFunc, Args: {ElemCast, Offset, WarpSizeCast});
3170 return castValueToType(AllocaIP, From: ShuffleCall, ToType: CastTy);
3171}
3172
3173void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
3174 Value *DstAddr, Type *ElemType,
3175 Value *Offset, Type *ReductionArrayTy,
3176 bool IsByRefElem) {
3177 uint64_t Size = M.getDataLayout().getTypeStoreSize(Ty: ElemType);
3178 // Create the loop over the big sized data.
3179 // ptr = (void*)Elem;
3180 // ptrEnd = (void*) Elem + 1;
3181 // Step = 8;
3182 // while (ptr + Step < ptrEnd)
3183 // shuffle((int64_t)*ptr);
3184 // Step = 4;
3185 // while (ptr + Step < ptrEnd)
3186 // shuffle((int32_t)*ptr);
3187 // ...
3188 Type *IndexTy = Builder.getIndexTy(
3189 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3190 Value *ElemPtr = DstAddr;
3191 Value *Ptr = SrcAddr;
3192 for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
3193 if (Size < IntSize)
3194 continue;
3195 Type *IntType = Builder.getIntNTy(N: IntSize * 8);
3196 Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3197 V: Ptr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: Ptr->getName() + ".ascast");
3198 Value *SrcAddrGEP =
3199 Builder.CreateGEP(Ty: ElemType, Ptr: SrcAddr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
3200 ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3201 V: ElemPtr, DestTy: Builder.getPtrTy(AddrSpace: 0), Name: ElemPtr->getName() + ".ascast");
3202
3203 Function *CurFunc = Builder.GetInsertBlock()->getParent();
3204 if ((Size / IntSize) > 1) {
3205 Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
3206 V: SrcAddrGEP, DestTy: Builder.getPtrTy());
3207 BasicBlock *PreCondBB =
3208 BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.pre_cond");
3209 BasicBlock *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.then");
3210 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: ".shuffle.exit");
3211 BasicBlock *CurrentBB = Builder.GetInsertBlock();
3212 emitBlock(BB: PreCondBB, CurFn: CurFunc);
3213 PHINode *PhiSrc =
3214 Builder.CreatePHI(Ty: Ptr->getType(), /*NumReservedValues=*/2);
3215 PhiSrc->addIncoming(V: Ptr, BB: CurrentBB);
3216 PHINode *PhiDest =
3217 Builder.CreatePHI(Ty: ElemPtr->getType(), /*NumReservedValues=*/2);
3218 PhiDest->addIncoming(V: ElemPtr, BB: CurrentBB);
3219 Ptr = PhiSrc;
3220 ElemPtr = PhiDest;
3221 Value *PtrDiff = Builder.CreatePtrDiff(
3222 ElemTy: Builder.getInt8Ty(), LHS: PtrEnd,
3223 RHS: Builder.CreatePointerBitCastOrAddrSpaceCast(V: Ptr, DestTy: Builder.getPtrTy()));
3224 Builder.CreateCondBr(
3225 Cond: Builder.CreateICmpSGT(LHS: PtrDiff, RHS: Builder.getInt64(C: IntSize - 1)), True: ThenBB,
3226 False: ExitBB);
3227 emitBlock(BB: ThenBB, CurFn: CurFunc);
3228 Value *Res = createRuntimeShuffleFunction(
3229 AllocaIP,
3230 Element: Builder.CreateAlignedLoad(
3231 Ty: IntType, Ptr, Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType)),
3232 ElementType: IntType, Offset);
3233 Builder.CreateAlignedStore(Val: Res, Ptr: ElemPtr,
3234 Align: M.getDataLayout().getPrefTypeAlign(Ty: ElemType));
3235 Value *LocalPtr =
3236 Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
3237 Value *LocalElemPtr =
3238 Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
3239 PhiSrc->addIncoming(V: LocalPtr, BB: ThenBB);
3240 PhiDest->addIncoming(V: LocalElemPtr, BB: ThenBB);
3241 emitBranch(Target: PreCondBB);
3242 emitBlock(BB: ExitBB, CurFn: CurFunc);
3243 } else {
3244 Value *Res = createRuntimeShuffleFunction(
3245 AllocaIP, Element: Builder.CreateLoad(Ty: IntType, Ptr), ElementType: IntType, Offset);
3246 if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
3247 Res->getType()->getScalarSizeInBits())
3248 Res = Builder.CreateTrunc(V: Res, DestTy: ElemType);
3249 Builder.CreateStore(Val: Res, Ptr: ElemPtr);
3250 Ptr = Builder.CreateGEP(Ty: IntType, Ptr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
3251 ElemPtr =
3252 Builder.CreateGEP(Ty: IntType, Ptr: ElemPtr, IdxList: {ConstantInt::get(Ty: IndexTy, V: 1)});
3253 }
3254 Size = Size % IntSize;
3255 }
3256}
3257
3258Error OpenMPIRBuilder::emitReductionListCopy(
3259 InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
3260 ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
3261 ArrayRef<bool> IsByRef, CopyOptionsTy CopyOptions) {
3262 Type *IndexTy = Builder.getIndexTy(
3263 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3264 Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
3265
3266 // Iterates, element-by-element, through the source Reduce list and
3267 // make a copy.
3268 for (auto En : enumerate(First&: ReductionInfos)) {
3269 const ReductionInfo &RI = En.value();
3270 Value *SrcElementAddr = nullptr;
3271 AllocaInst *DestAlloca = nullptr;
3272 Value *DestElementAddr = nullptr;
3273 Value *DestElementPtrAddr = nullptr;
3274 // Should we shuffle in an element from a remote lane?
3275 bool ShuffleInElement = false;
3276 // Set to true to update the pointer in the dest Reduce list to a
3277 // newly created element.
3278 bool UpdateDestListPtr = false;
3279
3280 // Step 1.1: Get the address for the src element in the Reduce list.
3281 Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
3282 Ty: ReductionArrayTy, Ptr: SrcBase,
3283 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3284 SrcElementAddr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrAddr);
3285
3286 // Step 1.2: Create a temporary to store the element in the destination
3287 // Reduce list.
3288 DestElementPtrAddr = Builder.CreateInBoundsGEP(
3289 Ty: ReductionArrayTy, Ptr: DestBase,
3290 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3291 bool IsByRefElem = (!IsByRef.empty() && IsByRef[En.index()]);
3292 switch (Action) {
3293 case CopyAction::RemoteLaneToThread: {
3294 InsertPointTy CurIP = Builder.saveIP();
3295 Builder.restoreIP(IP: AllocaIP);
3296
3297 Type *DestAllocaType =
3298 IsByRefElem ? RI.ByRefAllocatedType : RI.ElementType;
3299 DestAlloca = Builder.CreateAlloca(Ty: DestAllocaType, ArraySize: nullptr,
3300 Name: ".omp.reduction.element");
3301 DestAlloca->setAlignment(
3302 M.getDataLayout().getPrefTypeAlign(Ty: DestAllocaType));
3303 DestElementAddr = DestAlloca;
3304 DestElementAddr =
3305 Builder.CreateAddrSpaceCast(V: DestElementAddr, DestTy: Builder.getPtrTy(),
3306 Name: DestElementAddr->getName() + ".ascast");
3307 Builder.restoreIP(IP: CurIP);
3308 ShuffleInElement = true;
3309 UpdateDestListPtr = true;
3310 break;
3311 }
3312 case CopyAction::ThreadCopy: {
3313 DestElementAddr =
3314 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DestElementPtrAddr);
3315 break;
3316 }
3317 }
3318
3319 // Now that all active lanes have read the element in the
3320 // Reduce list, shuffle over the value from the remote lane.
3321 if (ShuffleInElement) {
3322 Type *ShuffleType = RI.ElementType;
3323 Value *ShuffleSrcAddr = SrcElementAddr;
3324 Value *ShuffleDestAddr = DestElementAddr;
3325 AllocaInst *LocalStorage = nullptr;
3326
3327 if (IsByRefElem) {
3328 assert(RI.ByRefElementType && "Expected by-ref element type to be set");
3329 assert(RI.ByRefAllocatedType &&
3330 "Expected by-ref allocated type to be set");
3331 // For by-ref reductions, we need to copy from the remote lane the
3332 // actual value of the partial reduction computed by that remote lane;
3333 // rather than, for example, a pointer to that data or, even worse, a
3334 // pointer to the descriptor of the by-ref reduction element.
3335 ShuffleType = RI.ByRefElementType;
3336
3337 if (RI.DataPtrPtrGen) {
3338 // Descriptor-based by-ref: extract data pointer from descriptor.
3339 InsertPointOrErrorTy GenResult = RI.DataPtrPtrGen(
3340 Builder.saveIP(), ShuffleSrcAddr, ShuffleSrcAddr);
3341
3342 if (!GenResult)
3343 return GenResult.takeError();
3344
3345 ShuffleSrcAddr =
3346 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ShuffleSrcAddr);
3347
3348 {
3349 InsertPointTy OldIP = Builder.saveIP();
3350 Builder.restoreIP(IP: AllocaIP);
3351
3352 LocalStorage = Builder.CreateAlloca(Ty: ShuffleType);
3353 Builder.restoreIP(IP: OldIP);
3354 ShuffleDestAddr = LocalStorage;
3355 }
3356 } else {
3357 // Non-descriptor by-ref: the pointer already references data
3358 // directly. Shuffle into the destination alloca.
3359 ShuffleDestAddr = DestElementAddr;
3360 }
3361 }
3362
3363 shuffleAndStore(AllocaIP, SrcAddr: ShuffleSrcAddr, DstAddr: ShuffleDestAddr, ElemType: ShuffleType,
3364 Offset: RemoteLaneOffset, ReductionArrayTy, IsByRefElem);
3365
3366 if (IsByRefElem && RI.DataPtrPtrGen) {
3367 // Copy descriptor from source and update base_ptr to shuffled data
3368 Value *DestDescriptorAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3369 V: DestAlloca, DestTy: Builder.getPtrTy(), Name: ".ascast");
3370
3371 InsertPointOrErrorTy GenResult = generateReductionDescriptor(
3372 DescriptorAddr: DestDescriptorAddr, DataPtr: LocalStorage, SrcDescriptorAddr: SrcElementAddr,
3373 DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);
3374
3375 if (!GenResult)
3376 return GenResult.takeError();
3377 }
3378 } else {
3379 switch (RI.EvaluationKind) {
3380 case EvalKind::Scalar: {
3381 Value *Elem = Builder.CreateLoad(Ty: RI.ElementType, Ptr: SrcElementAddr);
3382 // Store the source element value to the dest element address.
3383 Builder.CreateStore(Val: Elem, Ptr: DestElementAddr);
3384 break;
3385 }
3386 case EvalKind::Complex: {
3387 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3388 Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
3389 Value *SrcReal = Builder.CreateLoad(
3390 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
3391 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3392 Ty: RI.ElementType, Ptr: SrcElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
3393 Value *SrcImg = Builder.CreateLoad(
3394 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
3395
3396 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3397 Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 0, Name: ".realp");
3398 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3399 Ty: RI.ElementType, Ptr: DestElementAddr, Idx0: 0, Idx1: 1, Name: ".imagp");
3400 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
3401 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
3402 break;
3403 }
3404 case EvalKind::Aggregate: {
3405 Value *SizeVal = Builder.getInt64(
3406 C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
3407 Builder.CreateMemCpy(
3408 Dst: DestElementAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
3409 Src: SrcElementAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
3410 Size: SizeVal, isVolatile: false);
3411 break;
3412 }
3413 };
3414 }
3415
3416 // Step 3.1: Modify reference in dest Reduce list as needed.
3417 // Modifying the reference in Reduce list to point to the newly
3418 // created element. The element is live in the current function
3419 // scope and that of functions it invokes (i.e., reduce_function).
3420 // RemoteReduceData[i] = (void*)&RemoteElem
3421 if (UpdateDestListPtr) {
3422 Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3423 V: DestElementAddr, DestTy: Builder.getPtrTy(),
3424 Name: DestElementAddr->getName() + ".ascast");
3425 Builder.CreateStore(Val: CastDestAddr, Ptr: DestElementPtrAddr);
3426 }
3427 }
3428
3429 return Error::success();
3430}
3431
3432Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
3433 const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
3434 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3435 InsertPointTy SavedIP = Builder.saveIP();
3436 LLVMContext &Ctx = M.getContext();
3437 FunctionType *FuncTy = FunctionType::get(
3438 Result: Builder.getVoidTy(), Params: {Builder.getPtrTy(), Builder.getInt32Ty()},
3439 /* IsVarArg */ isVarArg: false);
3440 Function *WcFunc =
3441 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3442 N: "_omp_reduction_inter_warp_copy_func", M: &M);
3443 WcFunc->setCallingConv(Config.getRuntimeCC());
3444 WcFunc->setAttributes(FuncAttrs);
3445 WcFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3446 WcFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3447 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: WcFunc);
3448 Builder.SetInsertPoint(EntryBB);
3449
3450 // ReduceList: thread local Reduce list.
3451 // At the stage of the computation when this function is called, partially
3452 // aggregated values reside in the first lane of every active warp.
3453 Argument *ReduceListArg = WcFunc->getArg(i: 0);
3454 // NumWarps: number of warps active in the parallel region. This could
3455 // be smaller than 32 (max warps in a CTA) for partial block reduction.
3456 Argument *NumWarpsArg = WcFunc->getArg(i: 1);
3457
3458 // This array is used as a medium to transfer, one reduce element at a time,
3459 // the data from the first lane of every warp to lanes in the first warp
3460 // in order to perform the final step of a reduction in a parallel region
3461 // (reduction across warps). The array is placed in NVPTX __shared__ memory
3462 // for reduced latency, as well as to have a distinct copy for concurrently
3463 // executing target regions. The array is declared with common linkage so
3464 // as to be shared across compilation units.
3465 StringRef TransferMediumName =
3466 "__openmp_nvptx_data_transfer_temporary_storage";
3467 GlobalVariable *TransferMedium = M.getGlobalVariable(Name: TransferMediumName);
3468 unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
3469 ArrayType *ArrayTy = ArrayType::get(ElementType: Builder.getInt32Ty(), NumElements: WarpSize);
3470 if (!TransferMedium) {
3471 TransferMedium = new GlobalVariable(
3472 M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
3473 UndefValue::get(T: ArrayTy), TransferMediumName,
3474 /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
3475 /*AddressSpace=*/3);
3476 }
3477
3478 // Get the CUDA thread id of the current OpenMP thread on the GPU.
3479 Value *GPUThreadID = getGPUThreadID();
3480 // nvptx_lane_id = nvptx_id % warpsize
3481 Value *LaneID = getNVPTXLaneID();
3482 // nvptx_warp_id = nvptx_id / warpsize
3483 Value *WarpID = getNVPTXWarpID();
3484
3485 InsertPointTy AllocaIP =
3486 InsertPointTy(Builder.GetInsertBlock(),
3487 Builder.GetInsertBlock()->getFirstInsertionPt());
3488 Type *Arg0Type = ReduceListArg->getType();
3489 Type *Arg1Type = NumWarpsArg->getType();
3490 Builder.restoreIP(IP: AllocaIP);
3491 AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
3492 Ty: Arg0Type, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3493 AllocaInst *NumWarpsAlloca =
3494 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: NumWarpsArg->getName() + ".addr");
3495 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3496 V: ReduceListAlloca, DestTy: Arg0Type, Name: ReduceListAlloca->getName() + ".ascast");
3497 Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3498 V: NumWarpsAlloca, DestTy: Builder.getPtrTy(AddrSpace: 0),
3499 Name: NumWarpsAlloca->getName() + ".ascast");
3500 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
3501 Builder.CreateStore(Val: NumWarpsArg, Ptr: NumWarpsAddrCast);
3502 AllocaIP = getInsertPointAfterInstr(I: NumWarpsAlloca);
3503 InsertPointTy CodeGenIP =
3504 getInsertPointAfterInstr(I: &Builder.GetInsertBlock()->back());
3505 Builder.restoreIP(IP: CodeGenIP);
3506
3507 Value *ReduceList =
3508 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListAddrCast);
3509
3510 for (auto En : enumerate(First&: ReductionInfos)) {
3511 //
3512 // Warp master copies reduce element to transfer medium in __shared__
3513 // memory.
3514 //
3515 const ReductionInfo &RI = En.value();
3516 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
3517 unsigned RealTySize = M.getDataLayout().getTypeAllocSize(
3518 Ty: IsByRefElem ? RI.ByRefElementType : RI.ElementType);
3519 for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
3520 Type *CType = Builder.getIntNTy(N: TySize * 8);
3521
3522 unsigned NumIters = RealTySize / TySize;
3523 if (NumIters == 0)
3524 continue;
3525 Value *Cnt = nullptr;
3526 Value *CntAddr = nullptr;
3527 BasicBlock *PrecondBB = nullptr;
3528 BasicBlock *ExitBB = nullptr;
3529 if (NumIters > 1) {
3530 CodeGenIP = Builder.saveIP();
3531 Builder.restoreIP(IP: AllocaIP);
3532 CntAddr =
3533 Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr, Name: ".cnt.addr");
3534
3535 CntAddr = Builder.CreateAddrSpaceCast(V: CntAddr, DestTy: Builder.getPtrTy(),
3536 Name: CntAddr->getName() + ".ascast");
3537 Builder.restoreIP(IP: CodeGenIP);
3538 Builder.CreateStore(Val: Constant::getNullValue(Ty: Builder.getInt32Ty()),
3539 Ptr: CntAddr,
3540 /*Volatile=*/isVolatile: false);
3541 PrecondBB = BasicBlock::Create(Context&: Ctx, Name: "precond");
3542 ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit");
3543 BasicBlock *BodyBB = BasicBlock::Create(Context&: Ctx, Name: "body");
3544 emitBlock(BB: PrecondBB, CurFn: Builder.GetInsertBlock()->getParent());
3545 Cnt = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: CntAddr,
3546 /*Volatile=*/isVolatile: false);
3547 Value *Cmp = Builder.CreateICmpULT(
3548 LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), V: NumIters));
3549 Builder.CreateCondBr(Cond: Cmp, True: BodyBB, False: ExitBB);
3550 emitBlock(BB: BodyBB, CurFn: Builder.GetInsertBlock()->getParent());
3551 }
3552
3553 // kmpc_barrier.
3554 InsertPointOrErrorTy BarrierIP1 =
3555 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
3556 Kind: omp::Directive::OMPD_unknown,
3557 /* ForceSimpleCall */ false,
3558 /* CheckCancelFlag */ true);
3559 if (!BarrierIP1)
3560 return BarrierIP1.takeError();
3561 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
3562 BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
3563 BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
3564
3565 // if (lane_id == 0)
3566 Value *IsWarpMaster = Builder.CreateIsNull(Arg: LaneID, Name: "warp_master");
3567 Builder.CreateCondBr(Cond: IsWarpMaster, True: ThenBB, False: ElseBB);
3568 emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
3569
3570 // Reduce element = LocalReduceList[i]
3571 auto *RedListArrayTy =
3572 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3573 Type *IndexTy = Builder.getIndexTy(
3574 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3575 Value *ElemPtrPtr =
3576 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
3577 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
3578 ConstantInt::get(Ty: IndexTy, V: En.index())});
3579 // elemptr = ((CopyType*)(elemptrptr)) + I
3580 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3581
3582 if (IsByRefElem && RI.DataPtrPtrGen) {
3583 InsertPointOrErrorTy GenRes =
3584 RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
3585
3586 if (!GenRes)
3587 return GenRes.takeError();
3588
3589 ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
3590 }
3591
3592 if (NumIters > 1)
3593 ElemPtr = Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: ElemPtr, IdxList: Cnt);
3594
3595 // Get pointer to location in transfer medium.
3596 // MediumPtr = &medium[warp_id]
3597 Value *MediumPtr = Builder.CreateInBoundsGEP(
3598 Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), WarpID});
3599 // elem = *elemptr
3600 //*MediumPtr = elem
3601 Value *Elem = Builder.CreateLoad(Ty: CType, Ptr: ElemPtr);
3602 // Store the source element value to the dest element address.
3603 Builder.CreateStore(Val: Elem, Ptr: MediumPtr,
3604 /*IsVolatile*/ isVolatile: true);
3605 Builder.CreateBr(Dest: MergeBB);
3606
3607 // else
3608 emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
3609 Builder.CreateBr(Dest: MergeBB);
3610
3611 // endif
3612 emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
3613 InsertPointOrErrorTy BarrierIP2 =
3614 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
3615 Kind: omp::Directive::OMPD_unknown,
3616 /* ForceSimpleCall */ false,
3617 /* CheckCancelFlag */ true);
3618 if (!BarrierIP2)
3619 return BarrierIP2.takeError();
3620
3621 // Warp 0 copies reduce element from transfer medium
3622 BasicBlock *W0ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
3623 BasicBlock *W0ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
3624 BasicBlock *W0MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
3625
3626 Value *NumWarpsVal =
3627 Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: NumWarpsAddrCast);
3628 // Up to 32 threads in warp 0 are active.
3629 Value *IsActiveThread =
3630 Builder.CreateICmpULT(LHS: GPUThreadID, RHS: NumWarpsVal, Name: "is_active_thread");
3631 Builder.CreateCondBr(Cond: IsActiveThread, True: W0ThenBB, False: W0ElseBB);
3632
3633 emitBlock(BB: W0ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
3634
3635 // SecMediumPtr = &medium[tid]
3636 // SrcMediumVal = *SrcMediumPtr
3637 Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
3638 Ty: ArrayTy, Ptr: TransferMedium, IdxList: {Builder.getInt64(C: 0), GPUThreadID});
3639 // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
3640 Value *TargetElemPtrPtr =
3641 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
3642 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
3643 ConstantInt::get(Ty: IndexTy, V: En.index())});
3644 Value *TargetElemPtrVal =
3645 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtrPtr);
3646 Value *TargetElemPtr = TargetElemPtrVal;
3647
3648 if (IsByRefElem && RI.DataPtrPtrGen) {
3649 InsertPointOrErrorTy GenRes =
3650 RI.DataPtrPtrGen(Builder.saveIP(), TargetElemPtr, TargetElemPtr);
3651
3652 if (!GenRes)
3653 return GenRes.takeError();
3654
3655 TargetElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: TargetElemPtr);
3656 }
3657
3658 if (NumIters > 1)
3659 TargetElemPtr =
3660 Builder.CreateGEP(Ty: Builder.getInt32Ty(), Ptr: TargetElemPtr, IdxList: Cnt);
3661
3662 // *TargetElemPtr = SrcMediumVal;
3663 Value *SrcMediumValue =
3664 Builder.CreateLoad(Ty: CType, Ptr: SrcMediumPtrVal, /*IsVolatile*/ isVolatile: true);
3665 Builder.CreateStore(Val: SrcMediumValue, Ptr: TargetElemPtr);
3666 Builder.CreateBr(Dest: W0MergeBB);
3667
3668 emitBlock(BB: W0ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
3669 Builder.CreateBr(Dest: W0MergeBB);
3670
3671 emitBlock(BB: W0MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
3672
3673 if (NumIters > 1) {
3674 Cnt = Builder.CreateNSWAdd(
3675 LHS: Cnt, RHS: ConstantInt::get(Ty: Builder.getInt32Ty(), /*V=*/1));
3676 Builder.CreateStore(Val: Cnt, Ptr: CntAddr, /*Volatile=*/isVolatile: false);
3677
3678 auto *CurFn = Builder.GetInsertBlock()->getParent();
3679 emitBranch(Target: PrecondBB);
3680 emitBlock(BB: ExitBB, CurFn);
3681 }
3682 RealTySize %= TySize;
3683 }
3684 }
3685
3686 Builder.CreateRetVoid();
3687 Builder.restoreIP(IP: SavedIP);
3688
3689 return WcFunc;
3690}
3691
3692Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
3693 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3694 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3695 LLVMContext &Ctx = M.getContext();
3696 FunctionType *FuncTy =
3697 FunctionType::get(Result: Builder.getVoidTy(),
3698 Params: {Builder.getPtrTy(), Builder.getInt16Ty(),
3699 Builder.getInt16Ty(), Builder.getInt16Ty()},
3700 /* IsVarArg */ isVarArg: false);
3701 Function *SarFunc =
3702 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3703 N: "_omp_reduction_shuffle_and_reduce_func", M: &M);
3704 SarFunc->setCallingConv(Config.getRuntimeCC());
3705 SarFunc->setAttributes(FuncAttrs);
3706 SarFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3707 SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3708 SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3709 SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
3710 SarFunc->addParamAttr(ArgNo: 1, Kind: Attribute::SExt);
3711 SarFunc->addParamAttr(ArgNo: 2, Kind: Attribute::SExt);
3712 SarFunc->addParamAttr(ArgNo: 3, Kind: Attribute::SExt);
3713 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: SarFunc);
3714 Builder.SetInsertPoint(EntryBB);
3715
3716 // Thread local Reduce list used to host the values of data to be reduced.
3717 Argument *ReduceListArg = SarFunc->getArg(i: 0);
3718 // Current lane id; could be logical.
3719 Argument *LaneIDArg = SarFunc->getArg(i: 1);
3720 // Offset of the remote source lane relative to the current lane.
3721 Argument *RemoteLaneOffsetArg = SarFunc->getArg(i: 2);
3722 // Algorithm version. This is expected to be known at compile time.
3723 Argument *AlgoVerArg = SarFunc->getArg(i: 3);
3724
3725 Type *ReduceListArgType = ReduceListArg->getType();
3726 Type *LaneIDArgType = LaneIDArg->getType();
3727 Type *LaneIDArgPtrType = Builder.getPtrTy(AddrSpace: 0);
3728 Value *ReduceListAlloca = Builder.CreateAlloca(
3729 Ty: ReduceListArgType, ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3730 Value *LaneIdAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
3731 Name: LaneIDArg->getName() + ".addr");
3732 Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
3733 Ty: LaneIDArgType, ArraySize: nullptr, Name: RemoteLaneOffsetArg->getName() + ".addr");
3734 Value *AlgoVerAlloca = Builder.CreateAlloca(Ty: LaneIDArgType, ArraySize: nullptr,
3735 Name: AlgoVerArg->getName() + ".addr");
3736 ArrayType *RedListArrayTy =
3737 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3738
3739 // Create a local thread-private variable to host the Reduce list
3740 // from a remote lane.
3741 Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
3742 Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.remote_reduce_list");
3743
3744 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3745 V: ReduceListAlloca, DestTy: ReduceListArgType,
3746 Name: ReduceListAlloca->getName() + ".ascast");
3747 Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3748 V: LaneIdAlloca, DestTy: LaneIDArgPtrType, Name: LaneIdAlloca->getName() + ".ascast");
3749 Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3750 V: RemoteLaneOffsetAlloca, DestTy: LaneIDArgPtrType,
3751 Name: RemoteLaneOffsetAlloca->getName() + ".ascast");
3752 Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3753 V: AlgoVerAlloca, DestTy: LaneIDArgPtrType, Name: AlgoVerAlloca->getName() + ".ascast");
3754 Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3755 V: RemoteReductionListAlloca, DestTy: Builder.getPtrTy(),
3756 Name: RemoteReductionListAlloca->getName() + ".ascast");
3757
3758 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListAddrCast);
3759 Builder.CreateStore(Val: LaneIDArg, Ptr: LaneIdAddrCast);
3760 Builder.CreateStore(Val: RemoteLaneOffsetArg, Ptr: RemoteLaneOffsetAddrCast);
3761 Builder.CreateStore(Val: AlgoVerArg, Ptr: AlgoVerAddrCast);
3762
3763 Value *ReduceList = Builder.CreateLoad(Ty: ReduceListArgType, Ptr: ReduceListAddrCast);
3764 Value *LaneId = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: LaneIdAddrCast);
3765 Value *RemoteLaneOffset =
3766 Builder.CreateLoad(Ty: LaneIDArgType, Ptr: RemoteLaneOffsetAddrCast);
3767 Value *AlgoVer = Builder.CreateLoad(Ty: LaneIDArgType, Ptr: AlgoVerAddrCast);
3768
3769 InsertPointTy AllocaIP = getInsertPointAfterInstr(I: RemoteReductionListAlloca);
3770
3771 // This loop iterates through the list of reduce elements and copies,
3772 // element by element, from a remote lane in the warp to RemoteReduceList,
3773 // hosted on the thread's stack.
3774 Error EmitRedLsCpRes = emitReductionListCopy(
3775 AllocaIP, Action: CopyAction::RemoteLaneToThread, ReductionArrayTy: RedListArrayTy, ReductionInfos,
3776 SrcBase: ReduceList, DestBase: RemoteListAddrCast, IsByRef,
3777 CopyOptions: {.RemoteLaneOffset: RemoteLaneOffset, .ScratchpadIndex: nullptr, .ScratchpadWidth: nullptr});
3778
3779 if (EmitRedLsCpRes)
3780 return EmitRedLsCpRes;
3781
3782 // The actions to be performed on the Remote Reduce list is dependent
3783 // on the algorithm version.
3784 //
3785 // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
3786 // LaneId % 2 == 0 && Offset > 0):
3787 // do the reduction value aggregation
3788 //
3789 // The thread local variable Reduce list is mutated in place to host the
3790 // reduced data, which is the aggregated value produced from local and
3791 // remote lanes.
3792 //
3793 // Note that AlgoVer is expected to be a constant integer known at compile
3794 // time.
3795 // When AlgoVer==0, the first conjunction evaluates to true, making
3796 // the entire predicate true during compile time.
3797 // When AlgoVer==1, the second conjunction has only the second part to be
3798 // evaluated during runtime. Other conjunctions evaluates to false
3799 // during compile time.
3800 // When AlgoVer==2, the third conjunction has only the second part to be
3801 // evaluated during runtime. Other conjunctions evaluates to false
3802 // during compile time.
3803 Value *CondAlgo0 = Builder.CreateIsNull(Arg: AlgoVer);
3804 Value *Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
3805 Value *LaneComp = Builder.CreateICmpULT(LHS: LaneId, RHS: RemoteLaneOffset);
3806 Value *CondAlgo1 = Builder.CreateAnd(LHS: Algo1, RHS: LaneComp);
3807 Value *Algo2 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 2));
3808 Value *LaneIdAnd1 = Builder.CreateAnd(LHS: LaneId, RHS: Builder.getInt16(C: 1));
3809 Value *LaneIdComp = Builder.CreateIsNull(Arg: LaneIdAnd1);
3810 Value *Algo2AndLaneIdComp = Builder.CreateAnd(LHS: Algo2, RHS: LaneIdComp);
3811 Value *RemoteOffsetComp =
3812 Builder.CreateICmpSGT(LHS: RemoteLaneOffset, RHS: Builder.getInt16(C: 0));
3813 Value *CondAlgo2 = Builder.CreateAnd(LHS: Algo2AndLaneIdComp, RHS: RemoteOffsetComp);
3814 Value *CA0OrCA1 = Builder.CreateOr(LHS: CondAlgo0, RHS: CondAlgo1);
3815 Value *CondReduce = Builder.CreateOr(LHS: CA0OrCA1, RHS: CondAlgo2);
3816
3817 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
3818 BasicBlock *ElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
3819 BasicBlock *MergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
3820
3821 Builder.CreateCondBr(Cond: CondReduce, True: ThenBB, False: ElseBB);
3822 emitBlock(BB: ThenBB, CurFn: Builder.GetInsertBlock()->getParent());
3823 Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3824 V: ReduceList, DestTy: Builder.getPtrTy());
3825 Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3826 V: RemoteListAddrCast, DestTy: Builder.getPtrTy());
3827 createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListPtr, RemoteReduceListPtr})
3828 ->addFnAttr(Kind: Attribute::NoUnwind);
3829 Builder.CreateBr(Dest: MergeBB);
3830
3831 emitBlock(BB: ElseBB, CurFn: Builder.GetInsertBlock()->getParent());
3832 Builder.CreateBr(Dest: MergeBB);
3833
3834 emitBlock(BB: MergeBB, CurFn: Builder.GetInsertBlock()->getParent());
3835
3836 // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
3837 // Reduce list.
3838 Algo1 = Builder.CreateICmpEQ(LHS: AlgoVer, RHS: Builder.getInt16(C: 1));
3839 Value *LaneIdGtOffset = Builder.CreateICmpUGE(LHS: LaneId, RHS: RemoteLaneOffset);
3840 Value *CondCopy = Builder.CreateAnd(LHS: Algo1, RHS: LaneIdGtOffset);
3841
3842 BasicBlock *CpyThenBB = BasicBlock::Create(Context&: Ctx, Name: "then");
3843 BasicBlock *CpyElseBB = BasicBlock::Create(Context&: Ctx, Name: "else");
3844 BasicBlock *CpyMergeBB = BasicBlock::Create(Context&: Ctx, Name: "ifcont");
3845 Builder.CreateCondBr(Cond: CondCopy, True: CpyThenBB, False: CpyElseBB);
3846
3847 emitBlock(BB: CpyThenBB, CurFn: Builder.GetInsertBlock()->getParent());
3848
3849 EmitRedLsCpRes = emitReductionListCopy(
3850 AllocaIP, Action: CopyAction::ThreadCopy, ReductionArrayTy: RedListArrayTy, ReductionInfos,
3851 SrcBase: RemoteListAddrCast, DestBase: ReduceList, IsByRef);
3852
3853 if (EmitRedLsCpRes)
3854 return EmitRedLsCpRes;
3855
3856 Builder.CreateBr(Dest: CpyMergeBB);
3857
3858 emitBlock(BB: CpyElseBB, CurFn: Builder.GetInsertBlock()->getParent());
3859 Builder.CreateBr(Dest: CpyMergeBB);
3860
3861 emitBlock(BB: CpyMergeBB, CurFn: Builder.GetInsertBlock()->getParent());
3862
3863 Builder.CreateRetVoid();
3864
3865 return SarFunc;
3866}
3867
3868OpenMPIRBuilder::InsertPointOrErrorTy
3869OpenMPIRBuilder::generateReductionDescriptor(
3870 Value *DescriptorAddr, Value *DataPtr, Value *SrcDescriptorAddr,
3871 Type *DescriptorType,
3872 function_ref<InsertPointOrErrorTy(InsertPointTy, Value *, Value *&)>
3873 DataPtrPtrGen) {
3874
3875 // Copy the source descriptor to preserve all metadata (rank, extents,
3876 // strides, etc.)
3877 Value *DescriptorSize =
3878 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: DescriptorType));
3879 Builder.CreateMemCpy(
3880 Dst: DescriptorAddr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
3881 Src: SrcDescriptorAddr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: DescriptorType),
3882 Size: DescriptorSize);
3883
3884 // Update the base pointer field to point to the local shuffled data
3885 Value *DataPtrField;
3886 InsertPointOrErrorTy GenResult =
3887 DataPtrPtrGen(Builder.saveIP(), DescriptorAddr, DataPtrField);
3888
3889 if (!GenResult)
3890 return GenResult.takeError();
3891
3892 Builder.CreateStore(Val: Builder.CreatePointerBitCastOrAddrSpaceCast(
3893 V: DataPtr, DestTy: Builder.getPtrTy(), Name: ".ascast"),
3894 Ptr: DataPtrField);
3895
3896 return Builder.saveIP();
3897}
3898
3899Expected<Value *> OpenMPIRBuilder::createReductionDescriptorCopy(
3900 InsertPointTy AllocaIP, const ReductionInfo &RI, Value *DataPtr,
3901 Value *SrcDescriptorAddr, Type *DescriptorPtrTy, const Twine &Name) {
3902 InsertPointTy OldIP = Builder.saveIP();
3903 Builder.restoreIP(IP: AllocaIP);
3904
3905 AllocaInst *DescriptorAlloca =
3906 Builder.CreateAlloca(Ty: RI.ByRefAllocatedType, ArraySize: nullptr, Name);
3907 DescriptorAlloca->setAlignment(
3908 M.getDataLayout().getPrefTypeAlign(Ty: RI.ByRefAllocatedType));
3909 Value *DescriptorAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3910 V: DescriptorAlloca, DestTy: DescriptorPtrTy,
3911 Name: DescriptorAlloca->getName() + ".ascast");
3912
3913 Builder.restoreIP(IP: OldIP);
3914
3915 InsertPointOrErrorTy GenResult =
3916 generateReductionDescriptor(DescriptorAddr, DataPtr, SrcDescriptorAddr,
3917 DescriptorType: RI.ByRefAllocatedType, DataPtrPtrGen: RI.DataPtrPtrGen);
3918 if (!GenResult)
3919 return GenResult.takeError();
3920
3921 return DescriptorAddr;
3922}
3923
3924Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
3925 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3926 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
3927 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3928 LLVMContext &Ctx = M.getContext();
3929 FunctionType *FuncTy = FunctionType::get(
3930 Result: Builder.getVoidTy(),
3931 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3932 /* IsVarArg */ isVarArg: false);
3933 Function *LtGCFunc =
3934 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
3935 N: "_omp_reduction_list_to_global_copy_func", M: &M);
3936 LtGCFunc->setAttributes(FuncAttrs);
3937 LtGCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
3938 LtGCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
3939 LtGCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
3940
3941 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGCFunc);
3942 Builder.SetInsertPoint(EntryBlock);
3943
3944 // Buffer: global reduction buffer.
3945 Argument *BufferArg = LtGCFunc->getArg(i: 0);
3946 // Idx: index of the buffer.
3947 Argument *IdxArg = LtGCFunc->getArg(i: 1);
3948 // ReduceList: thread local Reduce list.
3949 Argument *ReduceListArg = LtGCFunc->getArg(i: 2);
3950
3951 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
3952 Name: BufferArg->getName() + ".addr");
3953 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
3954 Name: IdxArg->getName() + ".addr");
3955 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3956 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
3957 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3958 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
3959 Name: BufferArgAlloca->getName() + ".ascast");
3960 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3961 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
3962 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3963 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
3964 Name: ReduceListArgAlloca->getName() + ".ascast");
3965
3966 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
3967 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
3968 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
3969
3970 Value *LocalReduceList =
3971 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
3972 Value *BufferArgVal =
3973 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
3974 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
3975 Type *IndexTy = Builder.getIndexTy(
3976 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
3977 for (auto En : enumerate(First&: ReductionInfos)) {
3978 const ReductionInfo &RI = En.value();
3979 auto *RedListArrayTy =
3980 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
3981 // Reduce element = LocalReduceList[i]
3982 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3983 Ty: RedListArrayTy, Ptr: LocalReduceList,
3984 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
3985 // elemptr = ((CopyType*)(elemptrptr)) + I
3986 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
3987
3988 // Global = Buffer.VD[Idx];
3989 Value *BufferVD =
3990 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferArgVal, IdxList: Idxs);
3991 Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3992 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
3993
3994 switch (RI.EvaluationKind) {
3995 case EvalKind::Scalar: {
3996 Value *TargetElement;
3997
3998 if (IsByRef.empty() || !IsByRef[En.index()]) {
3999 TargetElement = Builder.CreateLoad(Ty: RI.ElementType, Ptr: ElemPtr);
4000 } else {
4001 if (RI.DataPtrPtrGen) {
4002 InsertPointOrErrorTy GenResult =
4003 RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
4004
4005 if (!GenResult)
4006 return GenResult.takeError();
4007
4008 ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
4009 }
4010 TargetElement = Builder.CreateLoad(Ty: RI.ByRefElementType, Ptr: ElemPtr);
4011 }
4012
4013 Builder.CreateStore(Val: TargetElement, Ptr: GlobVal);
4014 break;
4015 }
4016 case EvalKind::Complex: {
4017 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
4018 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
4019 Value *SrcReal = Builder.CreateLoad(
4020 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
4021 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
4022 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
4023 Value *SrcImg = Builder.CreateLoad(
4024 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
4025
4026 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
4027 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 0, Name: ".realp");
4028 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
4029 Ty: RI.ElementType, Ptr: GlobVal, Idx0: 0, Idx1: 1, Name: ".imagp");
4030 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
4031 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
4032 break;
4033 }
4034 case EvalKind::Aggregate: {
4035 Value *SizeVal =
4036 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
4037 Builder.CreateMemCpy(
4038 Dst: GlobVal, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Src: ElemPtr,
4039 SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType), Size: SizeVal, isVolatile: false);
4040 break;
4041 }
4042 }
4043 }
4044
4045 Builder.CreateRetVoid();
4046 Builder.restoreIP(IP: OldIP);
4047 return LtGCFunc;
4048}
4049
4050Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
4051 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
4052 Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
4053 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
4054 LLVMContext &Ctx = M.getContext();
4055 FunctionType *FuncTy = FunctionType::get(
4056 Result: Builder.getVoidTy(),
4057 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
4058 /* IsVarArg */ isVarArg: false);
4059 Function *LtGRFunc =
4060 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4061 N: "_omp_reduction_list_to_global_reduce_func", M: &M);
4062 LtGRFunc->setAttributes(FuncAttrs);
4063 LtGRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
4064 LtGRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
4065 LtGRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
4066
4067 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: LtGRFunc);
4068 Builder.SetInsertPoint(EntryBlock);
4069
4070 // Buffer: global reduction buffer.
4071 Argument *BufferArg = LtGRFunc->getArg(i: 0);
4072 // Idx: index of the buffer.
4073 Argument *IdxArg = LtGRFunc->getArg(i: 1);
4074 // ReduceList: thread local Reduce list.
4075 Argument *ReduceListArg = LtGRFunc->getArg(i: 2);
4076
4077 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
4078 Name: BufferArg->getName() + ".addr");
4079 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
4080 Name: IdxArg->getName() + ".addr");
4081 Value *ReduceListArgAlloca = Builder.CreateAlloca(
4082 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
4083 auto *RedListArrayTy =
4084 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
4085
4086 // 1. Build a list of reduction variables.
4087 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4088 Value *LocalReduceList =
4089 Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4090
4091 InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
4092
4093 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4094 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
4095 Name: BufferArgAlloca->getName() + ".ascast");
4096 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4097 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
4098 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4099 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
4100 Name: ReduceListArgAlloca->getName() + ".ascast");
4101 Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4102 V: LocalReduceList, DestTy: Builder.getPtrTy(),
4103 Name: LocalReduceList->getName() + ".ascast");
4104
4105 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
4106 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
4107 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
4108
4109 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
4110 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
4111 Type *IndexTy = Builder.getIndexTy(
4112 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4113 for (auto En : enumerate(First&: ReductionInfos)) {
4114 const ReductionInfo &RI = En.value();
4115
4116 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
4117 Ty: RedListArrayTy, Ptr: LocalReduceListAddrCast,
4118 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4119 Value *BufferVD =
4120 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
4121 // Global = Buffer.VD[Idx];
4122 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
4123 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
4124
4125 if (!IsByRef.empty() && IsByRef[En.index()] && RI.DataPtrPtrGen) {
4126 // Get source descriptor from the reduce list argument
4127 Value *ReduceList =
4128 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
4129 Value *SrcElementPtrPtr =
4130 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceList,
4131 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
4132 ConstantInt::get(Ty: IndexTy, V: En.index())});
4133 Value *SrcDescriptorAddr =
4134 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);
4135
4136 // Copy descriptor from source and update base_ptr to global buffer data
4137 Expected<Value *> ByRefAlloc = createReductionDescriptorCopy(
4138 AllocaIP, RI, DataPtr: GlobValPtr, SrcDescriptorAddr, DescriptorPtrTy: Builder.getPtrTy());
4139 if (!ByRefAlloc)
4140 return ByRefAlloc.takeError();
4141
4142 Builder.CreateStore(Val: *ByRefAlloc, Ptr: TargetElementPtrPtr);
4143 } else {
4144 Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
4145 }
4146 }
4147
4148 // Call reduce_function(GlobalReduceList, ReduceList)
4149 Value *ReduceList =
4150 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
4151 createRuntimeFunctionCall(Callee: ReduceFn, Args: {LocalReduceListAddrCast, ReduceList})
4152 ->addFnAttr(Kind: Attribute::NoUnwind);
4153 Builder.CreateRetVoid();
4154 Builder.restoreIP(IP: OldIP);
4155 return LtGRFunc;
4156}
4157
4158Expected<Function *> OpenMPIRBuilder::emitGlobalToListCopyFunction(
4159 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
4160 AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
4161 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
4162 LLVMContext &Ctx = M.getContext();
4163 FunctionType *FuncTy = FunctionType::get(
4164 Result: Builder.getVoidTy(),
4165 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
4166 /* IsVarArg */ isVarArg: false);
4167 Function *GtLCFunc =
4168 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4169 N: "_omp_reduction_global_to_list_copy_func", M: &M);
4170 GtLCFunc->setAttributes(FuncAttrs);
4171 GtLCFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
4172 GtLCFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
4173 GtLCFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
4174
4175 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLCFunc);
4176 Builder.SetInsertPoint(EntryBlock);
4177
4178 // Buffer: global reduction buffer.
4179 Argument *BufferArg = GtLCFunc->getArg(i: 0);
4180 // Idx: index of the buffer.
4181 Argument *IdxArg = GtLCFunc->getArg(i: 1);
4182 // ReduceList: thread local Reduce list.
4183 Argument *ReduceListArg = GtLCFunc->getArg(i: 2);
4184
4185 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
4186 Name: BufferArg->getName() + ".addr");
4187 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
4188 Name: IdxArg->getName() + ".addr");
4189 Value *ReduceListArgAlloca = Builder.CreateAlloca(
4190 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
4191 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4192 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
4193 Name: BufferArgAlloca->getName() + ".ascast");
4194 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4195 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
4196 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4197 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
4198 Name: ReduceListArgAlloca->getName() + ".ascast");
4199 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
4200 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
4201 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
4202
4203 Value *LocalReduceList =
4204 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
4205 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
4206 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
4207 Type *IndexTy = Builder.getIndexTy(
4208 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4209 for (auto En : enumerate(First&: ReductionInfos)) {
4210 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
4211 auto *RedListArrayTy =
4212 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
4213 // Reduce element = LocalReduceList[i]
4214 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
4215 Ty: RedListArrayTy, Ptr: LocalReduceList,
4216 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4217 // elemptr = ((CopyType*)(elemptrptr)) + I
4218 Value *ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtrPtr);
4219 // Global = Buffer.VD[Idx];
4220 Value *BufferVD =
4221 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
4222 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
4223 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
4224
4225 switch (RI.EvaluationKind) {
4226 case EvalKind::Scalar: {
4227 Type *ElemType = RI.ElementType;
4228
4229 if (!IsByRef.empty() && IsByRef[En.index()]) {
4230 ElemType = RI.ByRefElementType;
4231 if (RI.DataPtrPtrGen) {
4232 InsertPointOrErrorTy GenResult =
4233 RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
4234
4235 if (!GenResult)
4236 return GenResult.takeError();
4237
4238 ElemPtr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ElemPtr);
4239 }
4240 }
4241
4242 Value *TargetElement = Builder.CreateLoad(Ty: ElemType, Ptr: GlobValPtr);
4243 Builder.CreateStore(Val: TargetElement, Ptr: ElemPtr);
4244 break;
4245 }
4246 case EvalKind::Complex: {
4247 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
4248 Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 0, Name: ".realp");
4249 Value *SrcReal = Builder.CreateLoad(
4250 Ty: RI.ElementType->getStructElementType(N: 0), Ptr: SrcRealPtr, Name: ".real");
4251 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
4252 Ty: RI.ElementType, Ptr: GlobValPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
4253 Value *SrcImg = Builder.CreateLoad(
4254 Ty: RI.ElementType->getStructElementType(N: 1), Ptr: SrcImgPtr, Name: ".imag");
4255
4256 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
4257 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 0, Name: ".realp");
4258 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
4259 Ty: RI.ElementType, Ptr: ElemPtr, Idx0: 0, Idx1: 1, Name: ".imagp");
4260 Builder.CreateStore(Val: SrcReal, Ptr: DestRealPtr);
4261 Builder.CreateStore(Val: SrcImg, Ptr: DestImgPtr);
4262 break;
4263 }
4264 case EvalKind::Aggregate: {
4265 Value *SizeVal =
4266 Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: RI.ElementType));
4267 Builder.CreateMemCpy(
4268 Dst: ElemPtr, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
4269 Src: GlobValPtr, SrcAlign: M.getDataLayout().getPrefTypeAlign(Ty: RI.ElementType),
4270 Size: SizeVal, isVolatile: false);
4271 break;
4272 }
4273 }
4274 }
4275
4276 Builder.CreateRetVoid();
4277 Builder.restoreIP(IP: OldIP);
4278 return GtLCFunc;
4279}
4280
4281Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
4282 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
4283 Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
4284 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
4285 LLVMContext &Ctx = M.getContext();
4286 auto *FuncTy = FunctionType::get(
4287 Result: Builder.getVoidTy(),
4288 Params: {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
4289 /* IsVarArg */ isVarArg: false);
4290 Function *GtLRFunc =
4291 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4292 N: "_omp_reduction_global_to_list_reduce_func", M: &M);
4293 GtLRFunc->setAttributes(FuncAttrs);
4294 GtLRFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
4295 GtLRFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
4296 GtLRFunc->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
4297
4298 BasicBlock *EntryBlock = BasicBlock::Create(Context&: Ctx, Name: "entry", Parent: GtLRFunc);
4299 Builder.SetInsertPoint(EntryBlock);
4300
4301 // Buffer: global reduction buffer.
4302 Argument *BufferArg = GtLRFunc->getArg(i: 0);
4303 // Idx: index of the buffer.
4304 Argument *IdxArg = GtLRFunc->getArg(i: 1);
4305 // ReduceList: thread local Reduce list.
4306 Argument *ReduceListArg = GtLRFunc->getArg(i: 2);
4307
4308 Value *BufferArgAlloca = Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr,
4309 Name: BufferArg->getName() + ".addr");
4310 Value *IdxArgAlloca = Builder.CreateAlloca(Ty: Builder.getInt32Ty(), ArraySize: nullptr,
4311 Name: IdxArg->getName() + ".addr");
4312 Value *ReduceListArgAlloca = Builder.CreateAlloca(
4313 Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: ReduceListArg->getName() + ".addr");
4314 ArrayType *RedListArrayTy =
4315 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
4316
4317 // 1. Build a list of reduction variables.
4318 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4319 Value *LocalReduceList =
4320 Builder.CreateAlloca(Ty: RedListArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4321
4322 InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
4323
4324 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4325 V: BufferArgAlloca, DestTy: Builder.getPtrTy(),
4326 Name: BufferArgAlloca->getName() + ".ascast");
4327 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4328 V: IdxArgAlloca, DestTy: Builder.getPtrTy(), Name: IdxArgAlloca->getName() + ".ascast");
4329 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4330 V: ReduceListArgAlloca, DestTy: Builder.getPtrTy(),
4331 Name: ReduceListArgAlloca->getName() + ".ascast");
4332 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
4333 V: LocalReduceList, DestTy: Builder.getPtrTy(),
4334 Name: LocalReduceList->getName() + ".ascast");
4335
4336 Builder.CreateStore(Val: BufferArg, Ptr: BufferArgAddrCast);
4337 Builder.CreateStore(Val: IdxArg, Ptr: IdxArgAddrCast);
4338 Builder.CreateStore(Val: ReduceListArg, Ptr: ReduceListArgAddrCast);
4339
4340 Value *BufferVal = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BufferArgAddrCast);
4341 Value *Idxs[] = {Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: IdxArgAddrCast)};
4342 Type *IndexTy = Builder.getIndexTy(
4343 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4344 for (auto En : enumerate(First&: ReductionInfos)) {
4345 const ReductionInfo &RI = En.value();
4346
4347 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
4348 Ty: RedListArrayTy, Ptr: ReductionList,
4349 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4350 // Global = Buffer.VD[Idx];
4351 Value *BufferVD =
4352 Builder.CreateInBoundsGEP(Ty: ReductionsBufferTy, Ptr: BufferVal, IdxList: Idxs);
4353 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
4354 Ty: ReductionsBufferTy, Ptr: BufferVD, Idx0: 0, Idx1: En.index());
4355
4356 if (!IsByRef.empty() && IsByRef[En.index()] && RI.DataPtrPtrGen) {
4357 // Get source descriptor from the reduce list
4358 Value *ReduceListVal =
4359 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
4360 Value *SrcElementPtrPtr =
4361 Builder.CreateInBoundsGEP(Ty: RedListArrayTy, Ptr: ReduceListVal,
4362 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0),
4363 ConstantInt::get(Ty: IndexTy, V: En.index())});
4364 Value *SrcDescriptorAddr =
4365 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: SrcElementPtrPtr);
4366
4367 // Copy descriptor from source and update base_ptr to global buffer data
4368 Expected<Value *> ByRefAlloc = createReductionDescriptorCopy(
4369 AllocaIP, RI, DataPtr: GlobValPtr, SrcDescriptorAddr, DescriptorPtrTy: Builder.getPtrTy());
4370 if (!ByRefAlloc)
4371 return ByRefAlloc.takeError();
4372
4373 Builder.CreateStore(Val: *ByRefAlloc, Ptr: TargetElementPtrPtr);
4374 } else {
4375 Builder.CreateStore(Val: GlobValPtr, Ptr: TargetElementPtrPtr);
4376 }
4377 }
4378
4379 // Call reduce_function(ReduceList, GlobalReduceList)
4380 Value *ReduceList =
4381 Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: ReduceListArgAddrCast);
4382 createRuntimeFunctionCall(Callee: ReduceFn, Args: {ReduceList, ReductionList})
4383 ->addFnAttr(Kind: Attribute::NoUnwind);
4384 Builder.CreateRetVoid();
4385 Builder.restoreIP(IP: OldIP);
4386 return GtLRFunc;
4387}
4388
4389std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
4390 std::string Suffix =
4391 createPlatformSpecificName(Parts: {"omp", "reduction", "reduction_func"});
4392 return (Name + Suffix).str();
4393}
4394
4395Expected<Function *> OpenMPIRBuilder::createReductionFunction(
4396 StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
4397 ArrayRef<bool> IsByRef, ReductionGenCBKind ReductionGenCBKind,
4398 AttributeList FuncAttrs) {
4399 auto *FuncTy = FunctionType::get(Result: Builder.getVoidTy(),
4400 Params: {Builder.getPtrTy(), Builder.getPtrTy()},
4401 /* IsVarArg */ isVarArg: false);
4402 std::string Name = getReductionFuncName(Name: ReducerName);
4403 Function *ReductionFunc =
4404 Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage, N: Name, M: &M);
4405 ReductionFunc->setCallingConv(Config.getRuntimeCC());
4406 ReductionFunc->setAttributes(FuncAttrs);
4407 ReductionFunc->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
4408 ReductionFunc->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
4409 BasicBlock *EntryBB =
4410 BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: ReductionFunc);
4411 Builder.SetInsertPoint(EntryBB);
4412
4413 // Need to alloca memory here and deal with the pointers before getting
4414 // LHS/RHS pointers out
4415 Value *LHSArrayPtr = nullptr;
4416 Value *RHSArrayPtr = nullptr;
4417 Argument *Arg0 = ReductionFunc->getArg(i: 0);
4418 Argument *Arg1 = ReductionFunc->getArg(i: 1);
4419 Type *Arg0Type = Arg0->getType();
4420 Type *Arg1Type = Arg1->getType();
4421
4422 Value *LHSAlloca =
4423 Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
4424 Value *RHSAlloca =
4425 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
4426 Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4427 V: LHSAlloca, DestTy: Arg0Type, Name: LHSAlloca->getName() + ".ascast");
4428 Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
4429 V: RHSAlloca, DestTy: Arg1Type, Name: RHSAlloca->getName() + ".ascast");
4430 Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
4431 Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
4432 LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
4433 RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
4434
4435 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: ReductionInfos.size());
4436 Type *IndexTy = Builder.getIndexTy(
4437 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4438 SmallVector<Value *> LHSPtrs, RHSPtrs;
4439 for (auto En : enumerate(First&: ReductionInfos)) {
4440 const ReductionInfo &RI = En.value();
4441 Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
4442 Ty: RedArrayTy, Ptr: RHSArrayPtr,
4443 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4444 Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
4445 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4446 V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType(),
4447 Name: RHSI8Ptr->getName() + ".ascast");
4448
4449 Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
4450 Ty: RedArrayTy, Ptr: LHSArrayPtr,
4451 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4452 Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
4453 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4454 V: LHSI8Ptr, DestTy: RI.Variable->getType(), Name: LHSI8Ptr->getName() + ".ascast");
4455
4456 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4457 LHSPtrs.emplace_back(Args&: LHSPtr);
4458 RHSPtrs.emplace_back(Args&: RHSPtr);
4459 } else {
4460 Value *LHS = LHSPtr;
4461 Value *RHS = RHSPtr;
4462
4463 if (!IsByRef.empty() && !IsByRef[En.index()]) {
4464 LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
4465 RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
4466 }
4467
4468 Value *Reduced;
4469 InsertPointOrErrorTy AfterIP =
4470 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
4471 if (!AfterIP)
4472 return AfterIP.takeError();
4473 if (!Builder.GetInsertBlock())
4474 return ReductionFunc;
4475
4476 Builder.restoreIP(IP: *AfterIP);
4477
4478 if (!IsByRef.empty() && !IsByRef[En.index()])
4479 Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
4480 }
4481 }
4482
4483 if (ReductionGenCBKind == ReductionGenCBKind::Clang)
4484 for (auto En : enumerate(First&: ReductionInfos)) {
4485 unsigned Index = En.index();
4486 const ReductionInfo &RI = En.value();
4487 Value *LHSFixupPtr, *RHSFixupPtr;
4488 Builder.restoreIP(IP: RI.ReductionGenClang(
4489 Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
4490
4491 // Fix the CallBack code genereated to use the correct Values for the LHS
4492 // and RHS
4493 LHSFixupPtr->replaceUsesWithIf(
4494 New: LHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
4495 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4496 ReductionFunc;
4497 });
4498 RHSFixupPtr->replaceUsesWithIf(
4499 New: RHSPtrs[Index], ShouldReplace: [ReductionFunc](const Use &U) {
4500 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4501 ReductionFunc;
4502 });
4503 }
4504
4505 Builder.CreateRetVoid();
4506 // Compiling with `-O0`, `alloca`s emitted in non-entry blocks are not hoisted
4507 // to the entry block (this is dones for higher opt levels by later passes in
4508 // the pipeline). This has caused issues because non-entry `alloca`s force the
4509 // function to use dynamic stack allocations and we might run out of scratch
4510 // memory.
4511 hoistNonEntryAllocasToEntryBlock(Func: ReductionFunc);
4512
4513 return ReductionFunc;
4514}
4515
4516static void
4517checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4518 bool IsGPU) {
4519 for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
4520 (void)RI;
4521 assert(RI.Variable && "expected non-null variable");
4522 assert(RI.PrivateVariable && "expected non-null private variable");
4523 assert((RI.ReductionGen || RI.ReductionGenClang) &&
4524 "expected non-null reduction generator callback");
4525 if (!IsGPU) {
4526 assert(
4527 RI.Variable->getType() == RI.PrivateVariable->getType() &&
4528 "expected variables and their private equivalents to have the same "
4529 "type");
4530 }
4531 assert(RI.Variable->getType()->isPointerTy() &&
4532 "expected variables to be pointers");
4533 }
4534}
4535
4536OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
4537 const LocationDescription &Loc, InsertPointTy AllocaIP,
4538 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
4539 ArrayRef<bool> IsByRef, bool IsNoWait, bool IsTeamsReduction, bool IsSPMD,
4540 ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
4541 Value *SrcLocInfo) {
4542 if (!updateToLocation(Loc))
4543 return InsertPointTy();
4544 Builder.restoreIP(IP: CodeGenIP);
4545 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
4546 LLVMContext &Ctx = M.getContext();
4547
4548 // Source location for the ident struct
4549 if (!SrcLocInfo) {
4550 uint32_t SrcLocStrSize;
4551 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4552 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4553 }
4554
4555 if (ReductionInfos.size() == 0)
4556 return Builder.saveIP();
4557
4558 BasicBlock *ContinuationBlock = nullptr;
4559 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
4560 // Copied code from createReductions
4561 BasicBlock *InsertBlock = Loc.IP.getBlock();
4562 ContinuationBlock =
4563 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4564 InsertBlock->getTerminator()->eraseFromParent();
4565 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4566 }
4567
4568 Function *CurFunc = Builder.GetInsertBlock()->getParent();
4569 AttributeList FuncAttrs;
4570 AttrBuilder AttrBldr(Ctx);
4571 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
4572 AttrBldr.addAttribute(A: Attr);
4573 AttrBldr.removeAttribute(Val: Attribute::OptimizeNone);
4574 FuncAttrs = FuncAttrs.addFnAttributes(C&: Ctx, B: AttrBldr);
4575
4576 CodeGenIP = Builder.saveIP();
4577 Expected<Function *> ReductionResult = createReductionFunction(
4578 ReducerName: Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos, IsByRef,
4579 ReductionGenCBKind, FuncAttrs);
4580 if (!ReductionResult)
4581 return ReductionResult.takeError();
4582 Function *ReductionFunc = *ReductionResult;
4583 Builder.restoreIP(IP: CodeGenIP);
4584
4585 // Set the grid value in the config needed for lowering later on
4586 if (GridValue.has_value())
4587 Config.setGridValue(GridValue.value());
4588 else
4589 Config.setGridValue(getGridValue(T, Kernel: ReductionFunc));
4590
4591 // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
4592 // RedList, shuffle_reduce_func, interwarp_copy_func);
4593 // or
4594 // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
4595 Value *Res;
4596
4597 // 1. Build a list of reduction variables.
4598 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
4599 auto Size = ReductionInfos.size();
4600 Type *PtrTy = PointerType::get(C&: Ctx, AddressSpace: Config.getDefaultTargetAS());
4601 Type *FuncPtrTy =
4602 Builder.getPtrTy(AddrSpace: M.getDataLayout().getProgramAddressSpace());
4603 Type *RedArrayTy = ArrayType::get(ElementType: PtrTy, NumElements: Size);
4604 CodeGenIP = Builder.saveIP();
4605 Builder.restoreIP(IP: AllocaIP);
4606 Value *ReductionListAlloca =
4607 Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: ".omp.reduction.red_list");
4608 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
4609 V: ReductionListAlloca, DestTy: PtrTy, Name: ReductionListAlloca->getName() + ".ascast");
4610 Builder.restoreIP(IP: CodeGenIP);
4611 Type *IndexTy = Builder.getIndexTy(
4612 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
4613 for (auto En : enumerate(First&: ReductionInfos)) {
4614 const ReductionInfo &RI = En.value();
4615 Value *ElemPtr = Builder.CreateInBoundsGEP(
4616 Ty: RedArrayTy, Ptr: ReductionList,
4617 IdxList: {ConstantInt::get(Ty: IndexTy, V: 0), ConstantInt::get(Ty: IndexTy, V: En.index())});
4618
4619 Value *PrivateVar = RI.PrivateVariable;
4620 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
4621 if (IsByRefElem)
4622 PrivateVar = Builder.CreateLoad(Ty: RI.ElementType, Ptr: PrivateVar);
4623
4624 Value *CastElem =
4625 Builder.CreatePointerBitCastOrAddrSpaceCast(V: PrivateVar, DestTy: PtrTy);
4626 Builder.CreateStore(Val: CastElem, Ptr: ElemPtr);
4627 }
4628 CodeGenIP = Builder.saveIP();
4629 Expected<Function *> SarFunc = emitShuffleAndReduceFunction(
4630 ReductionInfos, ReduceFn: ReductionFunc, FuncAttrs, IsByRef);
4631
4632 if (!SarFunc)
4633 return SarFunc.takeError();
4634
4635 Expected<Function *> CopyResult =
4636 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs, IsByRef);
4637 if (!CopyResult)
4638 return CopyResult.takeError();
4639 Function *WcFunc = *CopyResult;
4640 Builder.restoreIP(IP: CodeGenIP);
4641
4642 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(V: ReductionList, DestTy: PtrTy);
4643
4644 // NOTE: ReductionDataSize is passed as the reduce_data_size argument to
4645 // __kmpc_nvptx_parallel_reduce_nowait_v2, but the runtime implementations do
4646 // not currently use it. It is computed here conservatively as max(element
4647 // sizes) * N rather than the exact sum, which over-calculates the size for
4648 // mixed reduction types but is harmless given the argument is unused.
4649 // TODO: Consider dropping this computation if the runtime API is ever revised
4650 // to remove the unused parameter.
4651 unsigned MaxDataSize = 0;
4652 SmallVector<Type *> ReductionTypeArgs;
4653 for (auto En : enumerate(First&: ReductionInfos)) {
4654 // Use ByRefElementType for by-ref reductions so that MaxDataSize matches
4655 // the actual data size stored in the global reduction buffer, consistent
4656 // with the ReductionsBufferTy struct used for GEP offsets below.
4657 Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
4658 ? En.value().ByRefElementType
4659 : En.value().ElementType;
4660 auto Size = M.getDataLayout().getTypeStoreSize(Ty: RedTypeArg);
4661 if (Size > MaxDataSize)
4662 MaxDataSize = Size;
4663 ReductionTypeArgs.emplace_back(Args&: RedTypeArg);
4664 }
4665 Value *ReductionDataSize =
4666 Builder.getInt64(C: MaxDataSize * ReductionInfos.size());
4667
4668 // Helper function to copy thread-local data back to the original reduction
4669 // list.
4670 Function *CopyScratchToListFunc = nullptr;
4671 // Thread-local storage for the reduction variables.
4672 Value *ScratchForCopyBack = nullptr;
4673 // RL pointer to which the final value from the per-thread scratch should be
4674 // copied back. (Basically RL, appropriately casted if necessary.)
4675 Value *RLForCopyBack = RL;
4676
4677 if (!IsTeamsReduction) {
4678 Value *SarFuncCast =
4679 Builder.CreatePointerBitCastOrAddrSpaceCast(V: *SarFunc, DestTy: FuncPtrTy);
4680 Value *WcFuncCast =
4681 Builder.CreatePointerBitCastOrAddrSpaceCast(V: WcFunc, DestTy: FuncPtrTy);
4682 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
4683 WcFuncCast};
4684 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
4685 FnID: RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
4686 Res = createRuntimeFunctionCall(Callee: Pv2Ptr, Args);
4687 } else {
4688 CodeGenIP = Builder.saveIP();
4689 StructType *ReductionsBufferTy = StructType::create(
4690 Context&: Ctx, Elements: ReductionTypeArgs, Name: "struct._globalized_locals_ty");
4691
4692 Expected<Function *> LtGCFunc = emitListToGlobalCopyFunction(
4693 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4694 if (!LtGCFunc)
4695 return LtGCFunc.takeError();
4696
4697 Expected<Function *> GtLCFunc = emitGlobalToListCopyFunction(
4698 ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
4699 if (!GtLCFunc)
4700 return GtLCFunc.takeError();
4701
4702 Expected<Function *> GtLRFunc = emitGlobalToListReduceFunction(
4703 ReductionInfos, ReduceFn: ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
4704 if (!GtLRFunc)
4705 return GtLRFunc.takeError();
4706
4707 Builder.restoreIP(IP: CodeGenIP);
4708
4709 // The runtime's cross-team final aggregate uses the storage pointed at by
4710 // its reduce-list argument as per-thread scratch. When the surrounding
4711 // kernel is already in SPMD execution mode, clang emitted each reduction
4712 // private as a per-thread `alloca addrspace(5)`, so the original red_list
4713 // (RL) is already per-thread and nothing else is needed.
4714 //
4715 // When the kernel is in Non-SPMD execution mode at codegen time, clang's
4716 // Generic-mode globalization put the reduction private into team-shared
4717 // LDS. OpenMPOpt may later upgrade the kernel to Generic-SPMD, at which
4718 // point all threads of the last team would race on the shared LDS slot.
4719 // Emit a per-thread scratch buffer and a per-thread RL, copy the team-local
4720 // value in, and hand the per-thread RL to the runtime instead. The writer
4721 // thread copies the final value from that per-thread scratch back to RL
4722 // before running the existing combine path below.
4723
4724 // Thread-local RL (might need localization below before being passed to the
4725 // runtime).
4726 Value *RuntimeRL = RL;
4727
4728 if (!IsSPMD) {
4729 CodeGenIP = Builder.saveIP();
4730 Builder.restoreIP(IP: AllocaIP);
4731 // Allocate thread-local buffer for the reduction variables.
4732 Value *PerThreadScratchAlloca = Builder.CreateAlloca(
4733 Ty: ReductionsBufferTy, /*ArraySize=*/nullptr, Name: ".omp.reduction.scratch");
4734 Value *PerThreadScratch = Builder.CreatePointerBitCastOrAddrSpaceCast(
4735 V: PerThreadScratchAlloca, DestTy: PtrTy,
4736 Name: PerThreadScratchAlloca->getName() + ".ascast");
4737 // Allocate thread-local buffer for the pointers to the reduction
4738 // variables.
4739 Value *PerThreadRedListAlloca =
4740 Builder.CreateAlloca(Ty: RedArrayTy, /*ArraySize=*/nullptr,
4741 Name: ".omp.reduction.per_thread_red_list");
4742 RuntimeRL = Builder.CreatePointerBitCastOrAddrSpaceCast(
4743 V: PerThreadRedListAlloca, DestTy: PtrTy,
4744 Name: PerThreadRedListAlloca->getName() + ".ascast");
4745 Builder.restoreIP(IP: CodeGenIP);
4746
4747 // Iterate over the reduction variables and copy the team-local value to
4748 // the thread-local buffer.
4749 for (auto En : enumerate(First&: ReductionInfos)) {
4750 const ReductionInfo &RI = En.value();
4751 bool IsByRefElem = !IsByRef.empty() && IsByRef[En.index()];
4752
4753 Value *FieldPtr = Builder.CreateConstInBoundsGEP2_32(
4754 Ty: ReductionsBufferTy, Ptr: PerThreadScratch, Idx0: 0, Idx1: En.index());
4755 Value *Slot = Builder.CreateConstInBoundsGEP2_32(Ty: RedArrayTy, Ptr: RuntimeRL,
4756 Idx0: 0, Idx1: En.index());
4757
4758 Value *RuntimeListEntry = FieldPtr;
4759 if (IsByRefElem && RI.DataPtrPtrGen) {
4760 Value *SrcDescriptor =
4761 Builder.CreateLoad(Ty: RI.ElementType, Ptr: RI.PrivateVariable);
4762 Expected<Value *> Descriptor = createReductionDescriptorCopy(
4763 AllocaIP, RI, DataPtr: FieldPtr, SrcDescriptorAddr: SrcDescriptor, DescriptorPtrTy: PtrTy);
4764 if (!Descriptor)
4765 return Descriptor.takeError();
4766 RuntimeListEntry = *Descriptor;
4767 }
4768 Builder.CreateStore(Val: RuntimeListEntry, Ptr: Slot);
4769 }
4770 // The copy helpers were emitted with default-AS (AS 0) pointer params
4771 // (see emitListToGlobalCopyFunction / emitGlobalToListCopyFunction),
4772 // but PerThreadScratch and RL live in the target's default AS, which
4773 // is non-zero on e.g. SPIRV. (See Config.getDefaultTargetAS().)
4774 Type *CopyArg0Ty = (*LtGCFunc)->getFunctionType()->getParamType(i: 0);
4775 Type *CopyArg2Ty = (*LtGCFunc)->getFunctionType()->getParamType(i: 2);
4776 ScratchForCopyBack = Builder.CreatePointerBitCastOrAddrSpaceCast(
4777 V: PerThreadScratch, DestTy: CopyArg0Ty);
4778 RLForCopyBack =
4779 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RL, DestTy: CopyArg2Ty);
4780 // Use index 0 because there is no array of target values to index into,
4781 // there is only one thread-local memory slot.
4782 // restoreIP above left a stale/empty debug location; this inlinable call
4783 // to a debug-info-bearing helper needs one or the verifier rejects the
4784 // module ("!dbg attachment points at wrong subprogram") after inlining.
4785 Builder.SetCurrentDebugLocation(Loc.DL);
4786 Builder.CreateCall(
4787 Callee: *LtGCFunc, Args: {ScratchForCopyBack, Builder.getInt32(C: 0), RLForCopyBack});
4788 CopyScratchToListFunc = *GtLCFunc;
4789 }
4790
4791 Value *Args3[] = {SrcLocInfo, RuntimeRL, *SarFunc, WcFunc,
4792 *LtGCFunc, *GtLCFunc, *GtLRFunc};
4793
4794 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
4795 FnID: RuntimeFunction::OMPRTL___kmpc_gpu_xteam_reduce_nowait);
4796 Res = createRuntimeFunctionCall(Callee: TeamsReduceFn, Args: Args3);
4797 }
4798
4799 // 5. Build if (res == 1)
4800 BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.done");
4801 BasicBlock *ThenBB = BasicBlock::Create(Context&: Ctx, Name: ".omp.reduction.then");
4802 Value *Cond = Builder.CreateICmpEQ(LHS: Res, RHS: Builder.getInt32(C: 1));
4803 Builder.CreateCondBr(Cond, True: ThenBB, False: ExitBB);
4804
4805 // 6. Build then branch: where we have reduced values in the master
4806 // thread in each team.
4807 // __kmpc_end_reduce{_nowait}(<gtid>);
4808 // break;
4809 emitBlock(BB: ThenBB, CurFn: CurFunc);
4810
4811 // Copy the writer thread's per-thread scratch result back into the original
4812 // red-list storage before the existing combine path reads RI.PrivateVariable.
4813 // Set a debug location: this inlinable call to a debug-info-bearing helper
4814 // needs one or the verifier rejects the module after inlining.
4815 if (ScratchForCopyBack) {
4816 Builder.SetCurrentDebugLocation(Loc.DL);
4817 Builder.CreateCall(
4818 Callee: CopyScratchToListFunc,
4819 Args: {ScratchForCopyBack, Builder.getInt32(C: 0), RLForCopyBack});
4820 }
4821
4822 // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
4823 for (auto En : enumerate(First&: ReductionInfos)) {
4824 const ReductionInfo &RI = En.value();
4825 Type *ValueType = RI.ElementType;
4826 Value *RedValue = RI.Variable;
4827
4828 Value *RHS =
4829 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RI.PrivateVariable, DestTy: PtrTy);
4830
4831 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
4832 Value *LHSPtr, *RHSPtr;
4833 Builder.restoreIP(IP: RI.ReductionGenClang(Builder.saveIP(), En.index(),
4834 &LHSPtr, &RHSPtr, CurFunc));
4835
4836 // Fix the CallBack code genereated to use the correct Values for the LHS
4837 // and RHS. Cast to match types before replacing (necessary to handle
4838 // different address spaces).
4839 if (LHSPtr->getType() != RedValue->getType())
4840 RedValue = Builder.CreatePointerBitCastOrAddrSpaceCast(
4841 V: RedValue, DestTy: LHSPtr->getType());
4842 if (RHSPtr->getType() != RHS->getType())
4843 RHS =
4844 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHS, DestTy: RHSPtr->getType());
4845
4846 LHSPtr->replaceUsesWithIf(New: RedValue, ShouldReplace: [ReductionFunc](const Use &U) {
4847 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4848 ReductionFunc;
4849 });
4850 RHSPtr->replaceUsesWithIf(New: RHS, ShouldReplace: [ReductionFunc](const Use &U) {
4851 return cast<Instruction>(Val: U.getUser())->getParent()->getParent() ==
4852 ReductionFunc;
4853 });
4854 } else {
4855 if (IsByRef.empty() || !IsByRef[En.index()]) {
4856 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
4857 Name: "red.value." + Twine(En.index()));
4858 }
4859 Value *PrivateRedValue = Builder.CreateLoad(
4860 Ty: ValueType, Ptr: RHS, Name: "red.private.value" + Twine(En.index()));
4861 Value *Reduced;
4862 InsertPointOrErrorTy AfterIP =
4863 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
4864 if (!AfterIP)
4865 return AfterIP.takeError();
4866 Builder.restoreIP(IP: *AfterIP);
4867
4868 if (!IsByRef.empty() && !IsByRef[En.index()])
4869 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
4870 }
4871 }
4872 emitBlock(BB: ExitBB, CurFn: CurFunc);
4873 if (ContinuationBlock) {
4874 Builder.CreateBr(Dest: ContinuationBlock);
4875 Builder.SetInsertPoint(ContinuationBlock);
4876 }
4877 Config.setEmitLLVMUsed();
4878
4879 return Builder.saveIP();
4880}
4881
4882static Function *getFreshReductionFunc(Module &M) {
4883 Type *VoidTy = Type::getVoidTy(C&: M.getContext());
4884 Type *Int8PtrTy = PointerType::getUnqual(C&: M.getContext());
4885 auto *FuncTy =
4886 FunctionType::get(Result: VoidTy, Params: {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ isVarArg: false);
4887 return Function::Create(Ty: FuncTy, Linkage: GlobalVariable::InternalLinkage,
4888 N: ".omp.reduction.func", M: &M);
4889}
4890
4891static Error populateReductionFunction(
4892 Function *ReductionFunc,
4893 ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
4894 IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
4895 Module *Module = ReductionFunc->getParent();
4896 BasicBlock *ReductionFuncBlock =
4897 BasicBlock::Create(Context&: Module->getContext(), Name: "", Parent: ReductionFunc);
4898 Builder.SetInsertPoint(ReductionFuncBlock);
4899 Value *LHSArrayPtr = nullptr;
4900 Value *RHSArrayPtr = nullptr;
4901 if (IsGPU) {
4902 // Need to alloca memory here and deal with the pointers before getting
4903 // LHS/RHS pointers out
4904 //
4905 Argument *Arg0 = ReductionFunc->getArg(i: 0);
4906 Argument *Arg1 = ReductionFunc->getArg(i: 1);
4907 Type *Arg0Type = Arg0->getType();
4908 Type *Arg1Type = Arg1->getType();
4909
4910 Value *LHSAlloca =
4911 Builder.CreateAlloca(Ty: Arg0Type, ArraySize: nullptr, Name: Arg0->getName() + ".addr");
4912 Value *RHSAlloca =
4913 Builder.CreateAlloca(Ty: Arg1Type, ArraySize: nullptr, Name: Arg1->getName() + ".addr");
4914 Value *LHSAddrCast =
4915 Builder.CreatePointerBitCastOrAddrSpaceCast(V: LHSAlloca, DestTy: Arg0Type);
4916 Value *RHSAddrCast =
4917 Builder.CreatePointerBitCastOrAddrSpaceCast(V: RHSAlloca, DestTy: Arg1Type);
4918 Builder.CreateStore(Val: Arg0, Ptr: LHSAddrCast);
4919 Builder.CreateStore(Val: Arg1, Ptr: RHSAddrCast);
4920 LHSArrayPtr = Builder.CreateLoad(Ty: Arg0Type, Ptr: LHSAddrCast);
4921 RHSArrayPtr = Builder.CreateLoad(Ty: Arg1Type, Ptr: RHSAddrCast);
4922 } else {
4923 LHSArrayPtr = ReductionFunc->getArg(i: 0);
4924 RHSArrayPtr = ReductionFunc->getArg(i: 1);
4925 }
4926
4927 unsigned NumReductions = ReductionInfos.size();
4928 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
4929
4930 for (auto En : enumerate(First&: ReductionInfos)) {
4931 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
4932 Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
4933 Ty: RedArrayTy, Ptr: LHSArrayPtr, Idx0: 0, Idx1: En.index());
4934 Value *LHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: LHSI8PtrPtr);
4935 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4936 V: LHSI8Ptr, DestTy: RI.Variable->getType());
4937 Value *LHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: LHSPtr);
4938 Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
4939 Ty: RedArrayTy, Ptr: RHSArrayPtr, Idx0: 0, Idx1: En.index());
4940 Value *RHSI8Ptr = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: RHSI8PtrPtr);
4941 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4942 V: RHSI8Ptr, DestTy: RI.PrivateVariable->getType());
4943 Value *RHS = Builder.CreateLoad(Ty: RI.ElementType, Ptr: RHSPtr);
4944 Value *Reduced;
4945 OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4946 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
4947 if (!AfterIP)
4948 return AfterIP.takeError();
4949
4950 Builder.restoreIP(IP: *AfterIP);
4951 // TODO: Consider flagging an error.
4952 if (!Builder.GetInsertBlock())
4953 return Error::success();
4954
4955 // store is inside of the reduction region when using by-ref
4956 if (!IsByRef[En.index()])
4957 Builder.CreateStore(Val: Reduced, Ptr: LHSPtr);
4958 }
4959 Builder.CreateRetVoid();
4960 return Error::success();
4961}
4962
4963OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
4964 const LocationDescription &Loc, InsertPointTy AllocaIP,
4965 ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
4966 bool IsNoWait, bool IsTeamsReduction) {
4967 assert(ReductionInfos.size() == IsByRef.size());
4968 if (Config.isGPU())
4969 return createReductionsGPU(Loc, AllocaIP, CodeGenIP: Builder.saveIP(), ReductionInfos,
4970 IsByRef, IsNoWait, IsTeamsReduction);
4971
4972 checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
4973
4974 if (!updateToLocation(Loc))
4975 return InsertPointTy();
4976
4977 if (ReductionInfos.size() == 0)
4978 return Builder.saveIP();
4979
4980 BasicBlock *InsertBlock = Loc.IP.getBlock();
4981 BasicBlock *ContinuationBlock =
4982 InsertBlock->splitBasicBlock(I: Loc.IP.getPoint(), BBName: "reduce.finalize");
4983 InsertBlock->getTerminator()->eraseFromParent();
4984
4985 // Create and populate array of type-erased pointers to private reduction
4986 // values.
4987 unsigned NumReductions = ReductionInfos.size();
4988 Type *RedArrayTy = ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: NumReductions);
4989 Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
4990 Value *RedArray = Builder.CreateAlloca(Ty: RedArrayTy, ArraySize: nullptr, Name: "red.array");
4991
4992 Builder.SetInsertPoint(TheBB: InsertBlock, IP: InsertBlock->end());
4993
4994 for (auto En : enumerate(First&: ReductionInfos)) {
4995 unsigned Index = En.index();
4996 const ReductionInfo &RI = En.value();
4997 Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
4998 Ty: RedArrayTy, Ptr: RedArray, Idx0: 0, Idx1: Index, Name: "red.array.elem." + Twine(Index));
4999 Builder.CreateStore(Val: RI.PrivateVariable, Ptr: RedArrayElemPtr);
5000 }
5001
5002 // Emit a call to the runtime function that orchestrates the reduction.
5003 // Declare the reduction function in the process.
5004 Type *IndexTy = Builder.getIndexTy(
5005 DL: M.getDataLayout(), AddrSpace: M.getDataLayout().getDefaultGlobalsAddressSpace());
5006 Function *Func = Builder.GetInsertBlock()->getParent();
5007 Module *Module = Func->getParent();
5008 uint32_t SrcLocStrSize;
5009 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5010 bool CanGenerateAtomic = all_of(Range&: ReductionInfos, P: [](const ReductionInfo &RI) {
5011 return RI.AtomicReductionGen;
5012 });
5013 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
5014 LocFlags: CanGenerateAtomic
5015 ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
5016 : IdentFlag(0));
5017 Value *ThreadId = getOrCreateThreadID(Ident);
5018 Constant *NumVariables = Builder.getInt32(C: NumReductions);
5019 const DataLayout &DL = Module->getDataLayout();
5020 unsigned RedArrayByteSize = DL.getTypeStoreSize(Ty: RedArrayTy);
5021 Constant *RedArraySize = ConstantInt::get(Ty: IndexTy, V: RedArrayByteSize);
5022 Function *ReductionFunc = getFreshReductionFunc(M&: *Module);
5023 Value *Lock = getOMPCriticalRegionLock(CriticalName: ".reduction");
5024 Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
5025 FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
5026 : RuntimeFunction::OMPRTL___kmpc_reduce);
5027 CallInst *ReduceCall =
5028 createRuntimeFunctionCall(Callee: ReduceFunc,
5029 Args: {Ident, ThreadId, NumVariables, RedArraySize,
5030 RedArray, ReductionFunc, Lock},
5031 Name: "reduce");
5032
5033 // Create final reduction entry blocks for the atomic and non-atomic case.
5034 // Emit IR that dispatches control flow to one of the blocks based on the
5035 // reduction supporting the atomic mode.
5036 BasicBlock *NonAtomicRedBlock =
5037 BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.nonatomic", Parent: Func);
5038 BasicBlock *AtomicRedBlock =
5039 BasicBlock::Create(Context&: Module->getContext(), Name: "reduce.switch.atomic", Parent: Func);
5040 SwitchInst *Switch =
5041 Builder.CreateSwitch(V: ReduceCall, Dest: ContinuationBlock, /* NumCases */ 2);
5042 Switch->addCase(OnVal: Builder.getInt32(C: 1), Dest: NonAtomicRedBlock);
5043 Switch->addCase(OnVal: Builder.getInt32(C: 2), Dest: AtomicRedBlock);
5044
5045 // Populate the non-atomic reduction using the elementwise reduction function.
5046 // This loads the elements from the global and private variables and reduces
5047 // them before storing back the result to the global variable.
5048 Builder.SetInsertPoint(NonAtomicRedBlock);
5049 for (auto En : enumerate(First&: ReductionInfos)) {
5050 const ReductionInfo &RI = En.value();
5051 Type *ValueType = RI.ElementType;
5052 // We have one less load for by-ref case because that load is now inside of
5053 // the reduction region
5054 Value *RedValue = RI.Variable;
5055 if (!IsByRef[En.index()]) {
5056 RedValue = Builder.CreateLoad(Ty: ValueType, Ptr: RI.Variable,
5057 Name: "red.value." + Twine(En.index()));
5058 }
5059 Value *PrivateRedValue =
5060 Builder.CreateLoad(Ty: ValueType, Ptr: RI.PrivateVariable,
5061 Name: "red.private.value." + Twine(En.index()));
5062 Value *Reduced;
5063 InsertPointOrErrorTy AfterIP =
5064 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
5065 if (!AfterIP)
5066 return AfterIP.takeError();
5067 Builder.restoreIP(IP: *AfterIP);
5068
5069 if (!Builder.GetInsertBlock())
5070 return InsertPointTy();
5071 // for by-ref case, the load is inside of the reduction region
5072 if (!IsByRef[En.index()])
5073 Builder.CreateStore(Val: Reduced, Ptr: RI.Variable);
5074 }
5075 Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
5076 FnID: IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
5077 : RuntimeFunction::OMPRTL___kmpc_end_reduce);
5078 createRuntimeFunctionCall(Callee: EndReduceFunc, Args: {Ident, ThreadId, Lock});
5079 Builder.CreateBr(Dest: ContinuationBlock);
5080
5081 // Populate the atomic reduction using the atomic elementwise reduction
5082 // function. There are no loads/stores here because they will be happening
5083 // inside the atomic elementwise reduction.
5084 Builder.SetInsertPoint(AtomicRedBlock);
5085 if (CanGenerateAtomic && llvm::none_of(Range&: IsByRef, P: [](bool P) { return P; })) {
5086 for (const ReductionInfo &RI : ReductionInfos) {
5087 InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
5088 Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
5089 if (!AfterIP)
5090 return AfterIP.takeError();
5091 Builder.restoreIP(IP: *AfterIP);
5092 if (!Builder.GetInsertBlock())
5093 return InsertPointTy();
5094 }
5095 Builder.CreateBr(Dest: ContinuationBlock);
5096 } else {
5097 Builder.CreateUnreachable();
5098 }
5099
5100 // Populate the outlined reduction function using the elementwise reduction
5101 // function. Partial values are extracted from the type-erased array of
5102 // pointers to private variables.
5103 Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
5104 IsByRef, /*isGPU=*/IsGPU: false);
5105 if (Err)
5106 return Err;
5107
5108 if (!Builder.GetInsertBlock())
5109 return InsertPointTy();
5110
5111 Builder.SetInsertPoint(ContinuationBlock);
5112 return Builder.saveIP();
5113}
5114
5115OpenMPIRBuilder::InsertPointOrErrorTy
5116OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
5117 BodyGenCallbackTy BodyGenCB,
5118 FinalizeCallbackTy FiniCB) {
5119 if (!updateToLocation(Loc))
5120 return Loc.IP;
5121
5122 Directive OMPD = Directive::OMPD_master;
5123 uint32_t SrcLocStrSize;
5124 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5125 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5126 Value *ThreadId = getOrCreateThreadID(Ident);
5127 Value *Args[] = {Ident, ThreadId};
5128
5129 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_master);
5130 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
5131
5132 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_master);
5133 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
5134
5135 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5136 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
5137}
5138
5139OpenMPIRBuilder::InsertPointOrErrorTy
5140OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
5141 BodyGenCallbackTy BodyGenCB,
5142 FinalizeCallbackTy FiniCB, Value *Filter) {
5143 if (!updateToLocation(Loc))
5144 return Loc.IP;
5145
5146 Directive OMPD = Directive::OMPD_masked;
5147 uint32_t SrcLocStrSize;
5148 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5149 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5150 Value *ThreadId = getOrCreateThreadID(Ident);
5151 Value *Args[] = {Ident, ThreadId, Filter};
5152 Value *ArgsEnd[] = {Ident, ThreadId};
5153
5154 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_masked);
5155 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
5156
5157 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_masked);
5158 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args: ArgsEnd);
5159
5160 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5161 /*Conditional*/ true, /*hasFinalize*/ HasFinalize: true);
5162}
5163
5164static llvm::CallInst *emitNoUnwindRuntimeCall(IRBuilder<> &Builder,
5165 llvm::FunctionCallee Callee,
5166 ArrayRef<llvm::Value *> Args,
5167 const llvm::Twine &Name) {
5168 llvm::CallInst *Call = Builder.CreateCall(
5169 Callee, Args, OpBundles: SmallVector<llvm::OperandBundleDef, 1>(), Name);
5170 Call->setDoesNotThrow();
5171 return Call;
5172}
5173
5174// Expects input basic block is dominated by BeforeScanBB.
5175// Once Scan directive is encountered, the code after scan directive should be
5176// dominated by AfterScanBB. Scan directive splits the code sequence to
5177// scan and input phase. Based on whether inclusive or exclusive
5178// clause is used in the scan directive and whether input loop or scan loop
5179// is lowered, it adds jumps to input and scan phase. First Scan loop is the
5180// input loop and second is the scan loop. The code generated handles only
5181// inclusive scans now.
5182OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
5183 const LocationDescription &Loc, InsertPointTy AllocaIP,
5184 ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
5185 bool IsInclusive, ScanInfo *ScanRedInfo) {
5186 if (ScanRedInfo->OMPFirstScanLoop) {
5187 llvm::Error Err = emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars,
5188 ScanVarsType, ScanRedInfo);
5189 if (Err)
5190 return Err;
5191 }
5192 if (!updateToLocation(Loc))
5193 return Loc.IP;
5194
5195 llvm::Value *IV = ScanRedInfo->IV;
5196
5197 if (ScanRedInfo->OMPFirstScanLoop) {
5198 // Emit buffer[i] = red; at the end of the input phase.
5199 for (size_t i = 0; i < ScanVars.size(); i++) {
5200 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
5201 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
5202 Type *DestTy = ScanVarsType[i];
5203 Value *Val = Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
5204 Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: ScanVars[i]);
5205
5206 Builder.CreateStore(Val: Src, Ptr: Val);
5207 }
5208 }
5209 Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
5210 emitBlock(BB: ScanRedInfo->OMPScanDispatch,
5211 CurFn: Builder.GetInsertBlock()->getParent());
5212
5213 if (!ScanRedInfo->OMPFirstScanLoop) {
5214 IV = ScanRedInfo->IV;
5215 // Emit red = buffer[i]; at the entrance to the scan phase.
5216 // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
5217 for (size_t i = 0; i < ScanVars.size(); i++) {
5218 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]];
5219 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
5220 Type *DestTy = ScanVarsType[i];
5221 Value *SrcPtr =
5222 Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
5223 Value *Src = Builder.CreateLoad(Ty: DestTy, Ptr: SrcPtr);
5224 Builder.CreateStore(Val: Src, Ptr: ScanVars[i]);
5225 }
5226 }
5227
5228 // TODO: Update it to CreateBr and remove dead blocks
5229 llvm::Value *CmpI = Builder.getInt1(V: true);
5230 if (ScanRedInfo->OMPFirstScanLoop == IsInclusive) {
5231 Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPBeforeScanBlock,
5232 False: ScanRedInfo->OMPAfterScanBlock);
5233 } else {
5234 Builder.CreateCondBr(Cond: CmpI, True: ScanRedInfo->OMPAfterScanBlock,
5235 False: ScanRedInfo->OMPBeforeScanBlock);
5236 }
5237 emitBlock(BB: ScanRedInfo->OMPAfterScanBlock,
5238 CurFn: Builder.GetInsertBlock()->getParent());
5239 Builder.SetInsertPoint(ScanRedInfo->OMPAfterScanBlock);
5240 return Builder.saveIP();
5241}
5242
5243Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
5244 InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
5245 ArrayRef<Type *> ScanVarsType, ScanInfo *ScanRedInfo) {
5246
5247 Builder.restoreIP(IP: AllocaIP);
5248 // Create the shared pointer at alloca IP.
5249 for (size_t i = 0; i < ScanVars.size(); i++) {
5250 llvm::Value *BuffPtr =
5251 Builder.CreateAlloca(Ty: Builder.getPtrTy(), ArraySize: nullptr, Name: "vla");
5252 (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]] = BuffPtr;
5253 }
5254
5255 // Allocate temporary buffer by master thread
5256 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
5257 ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
5258 Builder.restoreIP(IP: CodeGenIP);
5259 Value *AllocSpan =
5260 Builder.CreateAdd(LHS: ScanRedInfo->Span, RHS: Builder.getInt32(C: 1));
5261 for (size_t i = 0; i < ScanVars.size(); i++) {
5262 Type *IntPtrTy = Builder.getInt32Ty();
5263 Constant *Allocsize = ConstantExpr::getSizeOf(Ty: ScanVarsType[i]);
5264 Allocsize = ConstantExpr::getTruncOrBitCast(C: Allocsize, Ty: IntPtrTy);
5265 Value *Buff = Builder.CreateMalloc(IntPtrTy, AllocTy: ScanVarsType[i], AllocSize: Allocsize,
5266 ArraySize: AllocSpan, MallocF: nullptr, Name: "arr");
5267 Builder.CreateStore(Val: Buff, Ptr: (*(ScanRedInfo->ScanBuffPtrs))[ScanVars[i]]);
5268 }
5269 return Error::success();
5270 };
5271 // TODO: Perform finalization actions for variables. This has to be
5272 // called for variables which have destructors/finalizers.
5273 auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
5274
5275 Builder.SetInsertPoint(ScanRedInfo->OMPScanInit->getTerminator());
5276 llvm::Value *FilterVal = Builder.getInt32(C: 0);
5277 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
5278 createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);
5279
5280 if (!AfterIP)
5281 return AfterIP.takeError();
5282 Builder.restoreIP(IP: *AfterIP);
5283 BasicBlock *InputBB = Builder.GetInsertBlock();
5284 if (InputBB->hasTerminator())
5285 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
5286 AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
5287 if (!AfterIP)
5288 return AfterIP.takeError();
5289 Builder.restoreIP(IP: *AfterIP);
5290
5291 return Error::success();
5292}
5293
5294Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
5295 ArrayRef<ReductionInfo> ReductionInfos, ScanInfo *ScanRedInfo) {
5296 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
5297 ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
5298 Builder.restoreIP(IP: CodeGenIP);
5299 for (ReductionInfo RedInfo : ReductionInfos) {
5300 Value *PrivateVar = RedInfo.PrivateVariable;
5301 Value *OrigVar = RedInfo.Variable;
5302 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[PrivateVar];
5303 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
5304
5305 Type *SrcTy = RedInfo.ElementType;
5306 Value *Val = Builder.CreateInBoundsGEP(Ty: SrcTy, Ptr: Buff, IdxList: ScanRedInfo->Span,
5307 Name: "arrayOffset");
5308 Value *Src = Builder.CreateLoad(Ty: SrcTy, Ptr: Val);
5309
5310 Builder.CreateStore(Val: Src, Ptr: OrigVar);
5311 Builder.CreateFree(Source: Buff);
5312 }
5313 return Error::success();
5314 };
5315 // TODO: Perform finalization actions for variables. This has to be
5316 // called for variables which have destructors/finalizers.
5317 auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
5318
5319 if (Instruction *TI = ScanRedInfo->OMPScanFinish->getTerminatorOrNull())
5320 Builder.SetInsertPoint(TI);
5321 else
5322 Builder.SetInsertPoint(ScanRedInfo->OMPScanFinish);
5323
5324 llvm::Value *FilterVal = Builder.getInt32(C: 0);
5325 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
5326 createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);
5327
5328 if (!AfterIP)
5329 return AfterIP.takeError();
5330 Builder.restoreIP(IP: *AfterIP);
5331 BasicBlock *InputBB = Builder.GetInsertBlock();
5332 if (InputBB->hasTerminator())
5333 Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
5334 AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
5335 if (!AfterIP)
5336 return AfterIP.takeError();
5337 Builder.restoreIP(IP: *AfterIP);
5338 return Error::success();
5339}
5340
5341OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
5342 const LocationDescription &Loc,
5343 ArrayRef<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos,
5344 ScanInfo *ScanRedInfo) {
5345
5346 if (!updateToLocation(Loc))
5347 return Loc.IP;
5348 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
5349 ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
5350 Builder.restoreIP(IP: CodeGenIP);
5351 Function *CurFn = Builder.GetInsertBlock()->getParent();
5352 // for (int k = 0; k <= ceil(log2(n)); ++k)
5353 llvm::BasicBlock *LoopBB =
5354 BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.outer.log.scan.body");
5355 llvm::BasicBlock *ExitBB =
5356 splitBB(Builder, CreateBranch: false, Name: "omp.outer.log.scan.exit");
5357 llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
5358 M: Builder.GetInsertBlock()->getModule(),
5359 id: (llvm::Intrinsic::ID)llvm::Intrinsic::log2, OverloadTys: Builder.getDoubleTy());
5360 llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
5361 llvm::Value *Arg =
5362 Builder.CreateUIToFP(V: ScanRedInfo->Span, DestTy: Builder.getDoubleTy());
5363 llvm::Value *LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: Arg, Name: "");
5364 F = llvm::Intrinsic::getOrInsertDeclaration(
5365 M: Builder.GetInsertBlock()->getModule(),
5366 id: (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, OverloadTys: Builder.getDoubleTy());
5367 LogVal = emitNoUnwindRuntimeCall(Builder, Callee: F, Args: LogVal, Name: "");
5368 LogVal = Builder.CreateFPToUI(V: LogVal, DestTy: Builder.getInt32Ty());
5369 llvm::Value *NMin1 = Builder.CreateNUWSub(
5370 LHS: ScanRedInfo->Span,
5371 RHS: llvm::ConstantInt::get(Ty: ScanRedInfo->Span->getType(), V: 1));
5372 Builder.SetInsertPoint(InputBB);
5373 Builder.CreateBr(Dest: LoopBB);
5374 emitBlock(BB: LoopBB, CurFn);
5375 Builder.SetInsertPoint(LoopBB);
5376
5377 PHINode *Counter = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
5378 // size pow2k = 1;
5379 PHINode *Pow2K = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
5380 Counter->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
5381 BB: InputBB);
5382 Pow2K->addIncoming(V: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1),
5383 BB: InputBB);
5384 // for (size i = n - 1; i >= 2 ^ k; --i)
5385 // tmp[i] op= tmp[i-pow2k];
5386 llvm::BasicBlock *InnerLoopBB =
5387 BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.body");
5388 llvm::BasicBlock *InnerExitBB =
5389 BasicBlock::Create(Context&: CurFn->getContext(), Name: "omp.inner.log.scan.exit");
5390 llvm::Value *CmpI = Builder.CreateICmpUGE(LHS: NMin1, RHS: Pow2K);
5391 Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
5392 emitBlock(BB: InnerLoopBB, CurFn);
5393 Builder.SetInsertPoint(InnerLoopBB);
5394 PHINode *IVal = Builder.CreatePHI(Ty: Builder.getInt32Ty(), NumReservedValues: 2);
5395 IVal->addIncoming(V: NMin1, BB: LoopBB);
5396 for (ReductionInfo RedInfo : ReductionInfos) {
5397 Value *ReductionVal = RedInfo.PrivateVariable;
5398 Value *BuffPtr = (*(ScanRedInfo->ScanBuffPtrs))[ReductionVal];
5399 Value *Buff = Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: BuffPtr);
5400 Type *DestTy = RedInfo.ElementType;
5401 Value *IV = Builder.CreateAdd(LHS: IVal, RHS: Builder.getInt32(C: 1));
5402 Value *LHSPtr =
5403 Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: IV, Name: "arrayOffset");
5404 Value *OffsetIval = Builder.CreateNUWSub(LHS: IV, RHS: Pow2K);
5405 Value *RHSPtr =
5406 Builder.CreateInBoundsGEP(Ty: DestTy, Ptr: Buff, IdxList: OffsetIval, Name: "arrayOffset");
5407 Value *LHS = Builder.CreateLoad(Ty: DestTy, Ptr: LHSPtr);
5408 Value *RHS = Builder.CreateLoad(Ty: DestTy, Ptr: RHSPtr);
5409 llvm::Value *Result;
5410 InsertPointOrErrorTy AfterIP =
5411 RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
5412 if (!AfterIP)
5413 return AfterIP.takeError();
5414 Builder.CreateStore(Val: Result, Ptr: LHSPtr);
5415 }
5416 llvm::Value *NextIVal = Builder.CreateNUWSub(
5417 LHS: IVal, RHS: llvm::ConstantInt::get(Ty: Builder.getInt32Ty(), V: 1));
5418 IVal->addIncoming(V: NextIVal, BB: Builder.GetInsertBlock());
5419 CmpI = Builder.CreateICmpUGE(LHS: NextIVal, RHS: Pow2K);
5420 Builder.CreateCondBr(Cond: CmpI, True: InnerLoopBB, False: InnerExitBB);
5421 emitBlock(BB: InnerExitBB, CurFn);
5422 llvm::Value *Next = Builder.CreateNUWAdd(
5423 LHS: Counter, RHS: llvm::ConstantInt::get(Ty: Counter->getType(), V: 1));
5424 Counter->addIncoming(V: Next, BB: Builder.GetInsertBlock());
5425 // pow2k <<= 1;
5426 llvm::Value *NextPow2K = Builder.CreateShl(LHS: Pow2K, RHS: 1, Name: "", /*HasNUW=*/true);
5427 Pow2K->addIncoming(V: NextPow2K, BB: Builder.GetInsertBlock());
5428 llvm::Value *Cmp = Builder.CreateICmpNE(LHS: Next, RHS: LogVal);
5429 Builder.CreateCondBr(Cond: Cmp, True: LoopBB, False: ExitBB);
5430 Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
5431 return Error::success();
5432 };
5433
5434 // TODO: Perform finalization actions for variables. This has to be
5435 // called for variables which have destructors/finalizers.
5436 auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
5437
5438 llvm::Value *FilterVal = Builder.getInt32(C: 0);
5439 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
5440 createMasked(Loc: Builder.saveIP(), BodyGenCB, FiniCB, Filter: FilterVal);
5441
5442 if (!AfterIP)
5443 return AfterIP.takeError();
5444 Builder.restoreIP(IP: *AfterIP);
5445 AfterIP = createBarrier(Loc: Builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
5446
5447 if (!AfterIP)
5448 return AfterIP.takeError();
5449 Builder.restoreIP(IP: *AfterIP);
5450 Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos, ScanRedInfo);
5451 if (Err)
5452 return Err;
5453
5454 return AfterIP;
5455}
5456
5457Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
5458 llvm::function_ref<Error()> InputLoopGen,
5459 llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen,
5460 ScanInfo *ScanRedInfo) {
5461
5462 {
5463 // Emit loop with input phase:
5464 // for (i: 0..<num_iters>) {
5465 // <input phase>;
5466 // buffer[i] = red;
5467 // }
5468 ScanRedInfo->OMPFirstScanLoop = true;
5469 Error Err = InputLoopGen();
5470 if (Err)
5471 return Err;
5472 }
5473 {
5474 // Emit loop with scan phase:
5475 // for (i: 0..<num_iters>) {
5476 // red = buffer[i];
5477 // <scan phase>;
5478 // }
5479 ScanRedInfo->OMPFirstScanLoop = false;
5480 Error Err = ScanLoopGen(Builder.saveIP());
5481 if (Err)
5482 return Err;
5483 }
5484 return Error::success();
5485}
5486
5487void OpenMPIRBuilder::createScanBBs(ScanInfo *ScanRedInfo) {
5488 Function *Fun = Builder.GetInsertBlock()->getParent();
5489 ScanRedInfo->OMPScanDispatch =
5490 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.inscan.dispatch");
5491 ScanRedInfo->OMPAfterScanBlock =
5492 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.after.scan.bb");
5493 ScanRedInfo->OMPBeforeScanBlock =
5494 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.before.scan.bb");
5495 ScanRedInfo->OMPScanLoopExit =
5496 BasicBlock::Create(Context&: Fun->getContext(), Name: "omp.scan.loop.exit");
5497}
5498CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
5499 DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
5500 BasicBlock *PostInsertBefore, const Twine &Name) {
5501 Module *M = F->getParent();
5502 LLVMContext &Ctx = M->getContext();
5503 Type *IndVarTy = TripCount->getType();
5504
5505 // Create the basic block structure.
5506 BasicBlock *Preheader =
5507 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".preheader", Parent: F, InsertBefore: PreInsertBefore);
5508 BasicBlock *Header =
5509 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".header", Parent: F, InsertBefore: PreInsertBefore);
5510 BasicBlock *Cond =
5511 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".cond", Parent: F, InsertBefore: PreInsertBefore);
5512 BasicBlock *Body =
5513 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".body", Parent: F, InsertBefore: PreInsertBefore);
5514 BasicBlock *Latch =
5515 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".inc", Parent: F, InsertBefore: PostInsertBefore);
5516 BasicBlock *Exit =
5517 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".exit", Parent: F, InsertBefore: PostInsertBefore);
5518 BasicBlock *After =
5519 BasicBlock::Create(Context&: Ctx, Name: "omp_" + Name + ".after", Parent: F, InsertBefore: PostInsertBefore);
5520
5521 // Use specified DebugLoc for new instructions.
5522 Builder.SetCurrentDebugLocation(DL);
5523
5524 Builder.SetInsertPoint(Preheader);
5525 Builder.CreateBr(Dest: Header);
5526
5527 Builder.SetInsertPoint(Header);
5528 PHINode *IndVarPHI = Builder.CreatePHI(Ty: IndVarTy, NumReservedValues: 2, Name: "omp_" + Name + ".iv");
5529 IndVarPHI->addIncoming(V: ConstantInt::get(Ty: IndVarTy, V: 0), BB: Preheader);
5530 Builder.CreateBr(Dest: Cond);
5531
5532 Builder.SetInsertPoint(Cond);
5533 Value *Cmp =
5534 Builder.CreateICmpULT(LHS: IndVarPHI, RHS: TripCount, Name: "omp_" + Name + ".cmp");
5535 Builder.CreateCondBr(Cond: Cmp, True: Body, False: Exit);
5536
5537 Builder.SetInsertPoint(Body);
5538 Builder.CreateBr(Dest: Latch);
5539
5540 Builder.SetInsertPoint(Latch);
5541 Value *Next = Builder.CreateAdd(LHS: IndVarPHI, RHS: ConstantInt::get(Ty: IndVarTy, V: 1),
5542 Name: "omp_" + Name + ".next", /*HasNUW=*/true);
5543 Builder.CreateBr(Dest: Header);
5544 IndVarPHI->addIncoming(V: Next, BB: Latch);
5545
5546 Builder.SetInsertPoint(Exit);
5547 Builder.CreateBr(Dest: After);
5548
5549 // Remember and return the canonical control flow.
5550 LoopInfos.emplace_front();
5551 CanonicalLoopInfo *CL = &LoopInfos.front();
5552
5553 CL->Header = Header;
5554 CL->Cond = Cond;
5555 CL->Latch = Latch;
5556 CL->Exit = Exit;
5557
5558#ifndef NDEBUG
5559 CL->assertOK();
5560#endif
5561 return CL;
5562}
5563
5564Expected<CanonicalLoopInfo *>
5565OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
5566 LoopBodyGenCallbackTy BodyGenCB,
5567 Value *TripCount, const Twine &Name) {
5568 BasicBlock *BB = Loc.IP.getBlock();
5569 BasicBlock *NextBB = BB->getNextNode();
5570
5571 CanonicalLoopInfo *CL = createLoopSkeleton(DL: Loc.DL, TripCount, F: BB->getParent(),
5572 PreInsertBefore: NextBB, PostInsertBefore: NextBB, Name);
5573 BasicBlock *After = CL->getAfter();
5574
5575 // If location is not set, don't connect the loop.
5576 if (updateToLocation(Loc)) {
5577 // Split the loop at the insertion point: Branch to the preheader and move
5578 // every following instruction to after the loop (the After BB). Also, the
5579 // new successor is the loop's after block.
5580 spliceBB(Builder, New: After, /*CreateBranch=*/false);
5581 Builder.CreateBr(Dest: CL->getPreheader());
5582 }
5583
5584 // Emit the body content. We do it after connecting the loop to the CFG to
5585 // avoid that the callback encounters degenerate BBs.
5586 if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
5587 return Err;
5588
5589#ifndef NDEBUG
5590 CL->assertOK();
5591#endif
5592 return CL;
5593}
5594
5595Expected<ScanInfo *> OpenMPIRBuilder::scanInfoInitialize() {
5596 ScanInfos.emplace_front();
5597 ScanInfo *Result = &ScanInfos.front();
5598 return Result;
5599}
5600
5601Expected<SmallVector<llvm::CanonicalLoopInfo *>>
5602OpenMPIRBuilder::createCanonicalScanLoops(
5603 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
5604 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
5605 InsertPointTy ComputeIP, const Twine &Name, ScanInfo *ScanRedInfo) {
5606 LocationDescription ComputeLoc =
5607 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
5608 updateToLocation(Loc: ComputeLoc);
5609
5610 SmallVector<CanonicalLoopInfo *> Result;
5611
5612 Value *TripCount = calculateCanonicalLoopTripCount(
5613 Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
5614 ScanRedInfo->Span = TripCount;
5615 ScanRedInfo->OMPScanInit = splitBB(Builder, CreateBranch: true, Name: "scan.init");
5616 Builder.SetInsertPoint(ScanRedInfo->OMPScanInit);
5617
5618 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
5619 Builder.restoreIP(IP: CodeGenIP);
5620 ScanRedInfo->IV = IV;
5621 createScanBBs(ScanRedInfo);
5622 BasicBlock *InputBlock = Builder.GetInsertBlock();
5623 Instruction *Terminator = InputBlock->getTerminator();
5624 assert(Terminator->getNumSuccessors() == 1);
5625 BasicBlock *ContinueBlock = Terminator->getSuccessor(Idx: 0);
5626 Terminator->setSuccessor(Idx: 0, BB: ScanRedInfo->OMPScanDispatch);
5627 emitBlock(BB: ScanRedInfo->OMPBeforeScanBlock,
5628 CurFn: Builder.GetInsertBlock()->getParent());
5629 Builder.CreateBr(Dest: ScanRedInfo->OMPScanLoopExit);
5630 emitBlock(BB: ScanRedInfo->OMPScanLoopExit,
5631 CurFn: Builder.GetInsertBlock()->getParent());
5632 Builder.CreateBr(Dest: ContinueBlock);
5633 Builder.SetInsertPoint(
5634 ScanRedInfo->OMPBeforeScanBlock->getFirstInsertionPt());
5635 return BodyGenCB(Builder.saveIP(), IV);
5636 };
5637
5638 const auto &&InputLoopGen = [&]() -> Error {
5639 Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
5640 Loc: Builder.saveIP(), BodyGenCB: BodyGen, Start, Stop, Step, IsSigned, InclusiveStop,
5641 ComputeIP, Name, InScan: true, ScanRedInfo);
5642 if (!LoopInfo)
5643 return LoopInfo.takeError();
5644 Result.push_back(Elt: *LoopInfo);
5645 Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
5646 return Error::success();
5647 };
5648 const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
5649 Expected<CanonicalLoopInfo *> LoopInfo =
5650 createCanonicalLoop(Loc, BodyGenCB: BodyGen, Start, Stop, Step, IsSigned,
5651 InclusiveStop, ComputeIP, Name, InScan: true, ScanRedInfo);
5652 if (!LoopInfo)
5653 return LoopInfo.takeError();
5654 Result.push_back(Elt: *LoopInfo);
5655 Builder.restoreIP(IP: (*LoopInfo)->getAfterIP());
5656 ScanRedInfo->OMPScanFinish = Builder.GetInsertBlock();
5657 return Error::success();
5658 };
5659 Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen, ScanRedInfo);
5660 if (Err)
5661 return Err;
5662 return Result;
5663}
5664
5665Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
5666 const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
5667 bool IsSigned, bool InclusiveStop, const Twine &Name) {
5668
5669 // Consider the following difficulties (assuming 8-bit signed integers):
5670 // * Adding \p Step to the loop counter which passes \p Stop may overflow:
5671 // DO I = 1, 100, 50
5672 /// * A \p Step of INT_MIN cannot not be normalized to a positive direction:
5673 // DO I = 100, 0, -128
5674
5675 // Start, Stop and Step must be of the same integer type.
5676 auto *IndVarTy = cast<IntegerType>(Val: Start->getType());
5677 assert(IndVarTy == Stop->getType() && "Stop type mismatch");
5678 assert(IndVarTy == Step->getType() && "Step type mismatch");
5679
5680 updateToLocation(Loc);
5681
5682 ConstantInt *Zero = ConstantInt::get(Ty: IndVarTy, V: 0);
5683 ConstantInt *One = ConstantInt::get(Ty: IndVarTy, V: 1);
5684
5685 // Like Step, but always positive.
5686 Value *Incr = Step;
5687
5688 // Distance between Start and Stop; always positive.
5689 Value *Span;
5690
5691 // Condition whether there are no iterations are executed at all, e.g. because
5692 // UB < LB.
5693 Value *ZeroCmp;
5694
5695 if (IsSigned) {
5696 // Ensure that increment is positive. If not, negate and invert LB and UB.
5697 Value *IsNeg = Builder.CreateICmpSLT(LHS: Step, RHS: Zero);
5698 Incr = Builder.CreateSelect(C: IsNeg, True: Builder.CreateNeg(V: Step), False: Step);
5699 Value *LB = Builder.CreateSelect(C: IsNeg, True: Stop, False: Start);
5700 Value *UB = Builder.CreateSelect(C: IsNeg, True: Start, False: Stop);
5701 Span = Builder.CreateSub(LHS: UB, RHS: LB, Name: "", HasNUW: false, HasNSW: true);
5702 ZeroCmp = Builder.CreateICmp(
5703 P: InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, LHS: UB, RHS: LB);
5704 } else {
5705 Span = Builder.CreateSub(LHS: Stop, RHS: Start, Name: "", HasNUW: true);
5706 ZeroCmp = Builder.CreateICmp(
5707 P: InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, LHS: Stop, RHS: Start);
5708 }
5709
5710 Value *CountIfLooping;
5711 if (InclusiveStop) {
5712 CountIfLooping = Builder.CreateAdd(LHS: Builder.CreateUDiv(LHS: Span, RHS: Incr), RHS: One);
5713 } else {
5714 // Avoid incrementing past stop since it could overflow.
5715 Value *CountIfTwo = Builder.CreateAdd(
5716 LHS: Builder.CreateUDiv(LHS: Builder.CreateSub(LHS: Span, RHS: One), RHS: Incr), RHS: One);
5717 Value *OneCmp = Builder.CreateICmp(P: CmpInst::ICMP_ULE, LHS: Span, RHS: Incr);
5718 CountIfLooping = Builder.CreateSelect(C: OneCmp, True: One, False: CountIfTwo);
5719 }
5720
5721 return Builder.CreateSelect(C: ZeroCmp, True: Zero, False: CountIfLooping,
5722 Name: "omp_" + Name + ".tripcount");
5723}
5724
5725Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
5726 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
5727 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
5728 InsertPointTy ComputeIP, const Twine &Name, bool InScan,
5729 ScanInfo *ScanRedInfo) {
5730 LocationDescription ComputeLoc =
5731 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
5732
5733 Value *TripCount = calculateCanonicalLoopTripCount(
5734 Loc: ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
5735
5736 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
5737 Builder.restoreIP(IP: CodeGenIP);
5738 Value *Span = Builder.CreateMul(LHS: IV, RHS: Step);
5739 Value *IndVar = Builder.CreateAdd(LHS: Span, RHS: Start);
5740 if (InScan)
5741 ScanRedInfo->IV = IndVar;
5742 return BodyGenCB(Builder.saveIP(), IndVar);
5743 };
5744 LocationDescription LoopLoc =
5745 ComputeIP.isSet()
5746 ? Loc
5747 : LocationDescription(Builder.saveIP(),
5748 Builder.getCurrentDebugLocation());
5749 return createCanonicalLoop(Loc: LoopLoc, BodyGenCB: BodyGen, TripCount, Name);
5750}
5751
5752// Returns an LLVM function to call for initializing loop bounds using OpenMP
5753// static scheduling for composite `distribute parallel for` depending on
5754// `type`. Only i32 and i64 are supported by the runtime. Always interpret
5755// integers as unsigned similarly to CanonicalLoopInfo.
5756static FunctionCallee
5757getKmpcDistForStaticInitForType(Type *Ty, Module &M,
5758 OpenMPIRBuilder &OMPBuilder) {
5759 unsigned Bitwidth = Ty->getIntegerBitWidth();
5760 if (Bitwidth == 32)
5761 return OMPBuilder.getOrCreateRuntimeFunction(
5762 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
5763 if (Bitwidth == 64)
5764 return OMPBuilder.getOrCreateRuntimeFunction(
5765 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
5766 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5767}
5768
5769// Returns an LLVM function to call for initializing loop bounds using OpenMP
5770// static scheduling depending on `type`. Only i32 and i64 are supported by the
5771// runtime. Always interpret integers as unsigned similarly to
5772// CanonicalLoopInfo.
5773static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
5774 OpenMPIRBuilder &OMPBuilder) {
5775 unsigned Bitwidth = Ty->getIntegerBitWidth();
5776 if (Bitwidth == 32)
5777 return OMPBuilder.getOrCreateRuntimeFunction(
5778 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
5779 if (Bitwidth == 64)
5780 return OMPBuilder.getOrCreateRuntimeFunction(
5781 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
5782 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
5783}
5784
5785OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
5786 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5787 WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
5788 OMPScheduleType DistScheduleSchedType) {
5789 assert(CLI->isValid() && "Requires a valid canonical loop");
5790 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
5791 "Require dedicated allocate IP");
5792
5793 // Set up the source location value for OpenMP runtime.
5794 Builder.restoreIP(IP: CLI->getPreheaderIP());
5795 Builder.SetCurrentDebugLocation(DL);
5796
5797 uint32_t SrcLocStrSize;
5798 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
5799 IdentFlag Flag = IdentFlag(0);
5800 switch (LoopType) {
5801 case WorksharingLoopType::ForStaticLoop:
5802 Flag = OMP_IDENT_FLAG_WORK_LOOP;
5803 break;
5804 case WorksharingLoopType::DistributeStaticLoop:
5805 Flag = OMP_IDENT_FLAG_WORK_DISTRIBUTE;
5806 break;
5807 case WorksharingLoopType::DistributeForStaticLoop:
5808 Flag = OMP_IDENT_FLAG_WORK_DISTRIBUTE | OMP_IDENT_FLAG_WORK_LOOP;
5809 break;
5810 }
5811 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: Flag);
5812
5813 // Declare useful OpenMP runtime functions.
5814 Value *IV = CLI->getIndVar();
5815 Type *IVTy = IV->getType();
5816 FunctionCallee StaticInit =
5817 LoopType == WorksharingLoopType::DistributeForStaticLoop
5818 ? getKmpcDistForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this)
5819 : getKmpcForStaticInitForType(Ty: IVTy, M, OMPBuilder&: *this);
5820 FunctionCallee StaticFini =
5821 getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);
5822
5823 // Allocate space for computed loop bounds as expected by the "init" function.
5824 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
5825
5826 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
5827 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
5828 Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
5829 Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
5830 Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
5831 CLI->setLastIter(PLastIter);
5832
5833 // At the end of the preheader, prepare for calling the "init" function by
5834 // storing the current loop bounds into the allocated space. A canonical loop
5835 // always iterates from 0 to trip-count with step 1. Note that "init" expects
5836 // and produces an inclusive upper bound.
5837 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
5838 Constant *Zero = ConstantInt::get(Ty: IVTy, V: 0);
5839 Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
5840 Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
5841 Value *UpperBound = Builder.CreateSub(LHS: CLI->getTripCount(), RHS: One);
5842 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
5843 Builder.CreateStore(Val: One, Ptr: PStride);
5844
5845 Value *ThreadNum =
5846 getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize));
5847
5848 OMPScheduleType SchedType =
5849 (LoopType == WorksharingLoopType::DistributeStaticLoop)
5850 ? OMPScheduleType::OrderedDistribute
5851 : OMPScheduleType::UnorderedStatic;
5852 Constant *SchedulingType =
5853 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
5854
5855 // Call the "init" function and update the trip count of the loop with the
5856 // value it produced.
5857 auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
5858 PUpperBound, IVTy, PStride, One, Zero, StaticInit,
5859 this](Value *SchedulingType, auto &Builder) {
5860 SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
5861 PLowerBound, PUpperBound});
5862 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
5863 Value *PDistUpperBound =
5864 Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
5865 Args.push_back(Elt: PDistUpperBound);
5866 }
5867 Args.append(IL: {PStride, One, Zero});
5868 createRuntimeFunctionCall(Callee: StaticInit, Args);
5869 };
5870 BuildInitCall(SchedulingType, Builder);
5871 if (HasDistSchedule &&
5872 LoopType != WorksharingLoopType::DistributeStaticLoop) {
5873 Constant *DistScheduleSchedType = ConstantInt::get(
5874 Ty: I32Type, V: static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
5875 // We want to emit a second init function call for the dist_schedule clause
5876 // to the Distribute construct. This should only be done however if a
5877 // Workshare Loop is nested within a Distribute Construct
5878 BuildInitCall(DistScheduleSchedType, Builder);
5879 }
5880 Value *LowerBound = Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound);
5881 Value *InclusiveUpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound);
5882 Value *TripCountMinusOne = Builder.CreateSub(LHS: InclusiveUpperBound, RHS: LowerBound);
5883 Value *TripCount = Builder.CreateAdd(LHS: TripCountMinusOne, RHS: One);
5884 CLI->setTripCount(TripCount);
5885
5886 // Update all uses of the induction variable except the one in the condition
5887 // block that compares it with the actual upper bound, and the increment in
5888 // the latch block.
5889
5890 CLI->mapIndVar(Updater: [&](Instruction *OldIV) -> Value * {
5891 Builder.SetInsertPoint(TheBB: CLI->getBody(),
5892 IP: CLI->getBody()->getFirstInsertionPt());
5893 Builder.SetCurrentDebugLocation(DL);
5894 return Builder.CreateAdd(LHS: OldIV, RHS: LowerBound);
5895 });
5896
5897 // In the "exit" block, call the "fini" function.
5898 Builder.SetInsertPoint(TheBB: CLI->getExit(),
5899 IP: CLI->getExit()->getTerminator()->getIterator());
5900 createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});
5901
5902 // Add the barrier if requested.
5903 if (NeedsBarrier) {
5904 InsertPointOrErrorTy BarrierIP =
5905 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
5906 Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
5907 /* CheckCancelFlag */ false);
5908 if (!BarrierIP)
5909 return BarrierIP.takeError();
5910 }
5911
5912 InsertPointTy AfterIP = CLI->getAfterIP();
5913 CLI->invalidate();
5914
5915 return AfterIP;
5916}
5917
5918static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
5919 LoopInfo &LI);
5920static void addLoopMetadata(CanonicalLoopInfo *Loop,
5921 ArrayRef<Metadata *> Properties);
5922
5923static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
5924 LLVMContext &Ctx, Loop *Loop,
5925 LoopInfo &LoopInfo,
5926 SmallVector<Metadata *> &LoopMDList) {
5927 SmallSet<BasicBlock *, 8> Reachable;
5928
5929 // Get the basic blocks from the loop in which memref instructions
5930 // can be found.
5931 // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
5932 // preferably without running any passes.
5933 for (BasicBlock *Block : Loop->getBlocks()) {
5934 if (Block == CLI->getCond() || Block == CLI->getHeader())
5935 continue;
5936 Reachable.insert(Ptr: Block);
5937 }
5938
5939 // Add access group metadata to memory-access instructions.
5940 MDNode *AccessGroup = MDNode::getDistinct(Context&: Ctx, MDs: {});
5941 for (BasicBlock *BB : Reachable)
5942 addAccessGroupMetadata(Block: BB, AccessGroup, LI&: LoopInfo);
5943 // TODO: If the loop has existing parallel access metadata, have
5944 // to combine two lists.
5945 LoopMDList.push_back(Elt: MDNode::get(
5946 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.parallel_accesses"), AccessGroup}));
5947}
5948
5949OpenMPIRBuilder::InsertPointOrErrorTy
5950OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
5951 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
5952 bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
5953 Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
5954 assert(CLI->isValid() && "Requires a valid canonical loop");
5955 assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");
5956
5957 LLVMContext &Ctx = CLI->getFunction()->getContext();
5958 Value *IV = CLI->getIndVar();
5959 Value *OrigTripCount = CLI->getTripCount();
5960 Type *IVTy = IV->getType();
5961 assert(IVTy->getIntegerBitWidth() <= 64 &&
5962 "Max supported tripcount bitwidth is 64 bits");
5963 Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(C&: Ctx)
5964 : Type::getInt64Ty(C&: Ctx);
5965 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
5966 Constant *Zero = ConstantInt::get(Ty: InternalIVTy, V: 0);
5967 Constant *One = ConstantInt::get(Ty: InternalIVTy, V: 1);
5968
5969 Function *F = CLI->getFunction();
5970 // Blocks must have terminators.
5971 // FIXME: Don't run analyses on incomplete/invalid IR.
5972 SmallVector<Instruction *> UIs;
5973 for (BasicBlock &BB : *F)
5974 if (!BB.hasTerminator())
5975 UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));
5976 FunctionAnalysisManager FAM;
5977 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
5978 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
5979 LoopAnalysis LIA;
5980 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
5981 for (Instruction *I : UIs)
5982 I->eraseFromParent();
5983 Loop *L = LI.getLoopFor(BB: CLI->getHeader());
5984 SmallVector<Metadata *> LoopMDList;
5985 if (ChunkSize || DistScheduleChunkSize)
5986 applyParallelAccessesMetadata(CLI, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
5987 addLoopMetadata(Loop: CLI, Properties: LoopMDList);
5988
5989 // Declare useful OpenMP runtime functions.
5990 FunctionCallee StaticInit =
5991 getKmpcForStaticInitForType(Ty: InternalIVTy, M, OMPBuilder&: *this);
5992 FunctionCallee StaticFini =
5993 getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_for_static_fini);
5994
5995 // Allocate space for computed loop bounds as expected by the "init" function.
5996 Builder.restoreIP(IP: AllocaIP);
5997 Builder.SetCurrentDebugLocation(DL);
5998 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
5999 Value *PLowerBound =
6000 Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.lowerbound");
6001 Value *PUpperBound =
6002 Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.upperbound");
6003 Value *PStride = Builder.CreateAlloca(Ty: InternalIVTy, ArraySize: nullptr, Name: "p.stride");
6004 CLI->setLastIter(PLastIter);
6005
6006 // Set up the source location value for the OpenMP runtime.
6007 Builder.restoreIP(IP: CLI->getPreheaderIP());
6008 Builder.SetCurrentDebugLocation(DL);
6009
6010 // TODO: Detect overflow in ubsan or max-out with current tripcount.
6011 Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
6012 V: ChunkSize ? ChunkSize : Zero, DestTy: InternalIVTy, Name: "chunksize");
6013 Value *CastedDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
6014 V: DistScheduleChunkSize ? DistScheduleChunkSize : Zero, DestTy: InternalIVTy,
6015 Name: "distschedulechunksize");
6016 Value *CastedTripCount =
6017 Builder.CreateZExt(V: OrigTripCount, DestTy: InternalIVTy, Name: "tripcount");
6018
6019 Constant *SchedulingType =
6020 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
6021 Constant *DistSchedulingType =
6022 ConstantInt::get(Ty: I32Type, V: static_cast<int>(DistScheduleSchedType));
6023 Builder.CreateStore(Val: Zero, Ptr: PLowerBound);
6024 Value *OrigUpperBound = Builder.CreateSub(LHS: CastedTripCount, RHS: One);
6025 Value *IsTripCountZero = Builder.CreateICmpEQ(LHS: CastedTripCount, RHS: Zero);
6026 Value *UpperBound =
6027 Builder.CreateSelect(C: IsTripCountZero, True: Zero, False: OrigUpperBound);
6028 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
6029 Builder.CreateStore(Val: One, Ptr: PStride);
6030
6031 // Call the "init" function and update the trip count of the loop with the
6032 // value it produced.
6033 uint32_t SrcLocStrSize;
6034 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
6035 IdentFlag Flag = OMP_IDENT_FLAG_WORK_LOOP;
6036 if (DistScheduleSchedType != OMPScheduleType::None) {
6037 Flag |= OMP_IDENT_FLAG_WORK_DISTRIBUTE;
6038 }
6039 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: Flag);
6040 Value *ThreadNum =
6041 getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize));
6042 auto BuildInitCall = [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound,
6043 PUpperBound, PStride, One,
6044 this](Value *SchedulingType, Value *ChunkSize,
6045 auto &Builder) {
6046 createRuntimeFunctionCall(
6047 Callee: StaticInit, Args: {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
6048 /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
6049 /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
6050 /*pstride=*/PStride, /*incr=*/One,
6051 /*chunk=*/ChunkSize});
6052 };
6053 BuildInitCall(SchedulingType, CastedChunkSize, Builder);
6054 if (DistScheduleSchedType != OMPScheduleType::None &&
6055 SchedType != OMPScheduleType::OrderedDistributeChunked &&
6056 SchedType != OMPScheduleType::OrderedDistribute) {
6057 // We want to emit a second init function call for the dist_schedule clause
6058 // to the Distribute construct. This should only be done however if a
6059 // Workshare Loop is nested within a Distribute Construct
6060 BuildInitCall(DistSchedulingType, CastedDistScheduleChunkSize, Builder);
6061 }
6062
6063 // Load values written by the "init" function.
6064 Value *FirstChunkStart =
6065 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PLowerBound, Name: "omp_firstchunk.lb");
6066 Value *FirstChunkStop =
6067 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PUpperBound, Name: "omp_firstchunk.ub");
6068 Value *FirstChunkEnd = Builder.CreateAdd(LHS: FirstChunkStop, RHS: One);
6069 Value *ChunkRange =
6070 Builder.CreateSub(LHS: FirstChunkEnd, RHS: FirstChunkStart, Name: "omp_chunk.range");
6071 Value *NextChunkStride =
6072 Builder.CreateLoad(Ty: InternalIVTy, Ptr: PStride, Name: "omp_dispatch.stride");
6073
6074 // Create outer "dispatch" loop for enumerating the chunks.
6075 BasicBlock *DispatchEnter = splitBB(Builder, CreateBranch: true);
6076 Value *DispatchCounter;
6077
6078 // It is safe to assume this didn't return an error because the callback
6079 // passed into createCanonicalLoop is the only possible error source, and it
6080 // always returns success.
6081 CanonicalLoopInfo *DispatchCLI = cantFail(ValOrErr: createCanonicalLoop(
6082 Loc: {Builder.saveIP(), DL},
6083 BodyGenCB: [&](InsertPointTy BodyIP, Value *Counter) {
6084 DispatchCounter = Counter;
6085 return Error::success();
6086 },
6087 Start: FirstChunkStart, Stop: CastedTripCount, Step: NextChunkStride,
6088 /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
6089 Name: "dispatch"));
6090
6091 // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
6092 // not have to preserve the canonical invariant.
6093 BasicBlock *DispatchBody = DispatchCLI->getBody();
6094 BasicBlock *DispatchLatch = DispatchCLI->getLatch();
6095 BasicBlock *DispatchExit = DispatchCLI->getExit();
6096 BasicBlock *DispatchAfter = DispatchCLI->getAfter();
6097 DispatchCLI->invalidate();
6098
6099 // Rewire the original loop to become the chunk loop inside the dispatch loop.
6100 redirectTo(Source: DispatchAfter, Target: CLI->getAfter(), DL);
6101 redirectTo(Source: CLI->getExit(), Target: DispatchLatch, DL);
6102 redirectTo(Source: DispatchBody, Target: DispatchEnter, DL);
6103
6104 // Prepare the prolog of the chunk loop.
6105 Builder.restoreIP(IP: CLI->getPreheaderIP());
6106 Builder.SetCurrentDebugLocation(DL);
6107
6108 // Compute the number of iterations of the chunk loop.
6109 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
6110 Value *ChunkEnd = Builder.CreateAdd(LHS: DispatchCounter, RHS: ChunkRange);
6111 Value *IsLastChunk =
6112 Builder.CreateICmpUGE(LHS: ChunkEnd, RHS: CastedTripCount, Name: "omp_chunk.is_last");
6113 Value *CountUntilOrigTripCount =
6114 Builder.CreateSub(LHS: CastedTripCount, RHS: DispatchCounter);
6115 Value *ChunkTripCount = Builder.CreateSelect(
6116 C: IsLastChunk, True: CountUntilOrigTripCount, False: ChunkRange, Name: "omp_chunk.tripcount");
6117 Value *BackcastedChunkTC =
6118 Builder.CreateTrunc(V: ChunkTripCount, DestTy: IVTy, Name: "omp_chunk.tripcount.trunc");
6119 CLI->setTripCount(BackcastedChunkTC);
6120
6121 // Update all uses of the induction variable except the one in the condition
6122 // block that compares it with the actual upper bound, and the increment in
6123 // the latch block.
6124 Value *BackcastedDispatchCounter =
6125 Builder.CreateTrunc(V: DispatchCounter, DestTy: IVTy, Name: "omp_dispatch.iv.trunc");
6126 CLI->mapIndVar(Updater: [&](Instruction *) -> Value * {
6127 Builder.restoreIP(IP: CLI->getBodyIP());
6128 return Builder.CreateAdd(LHS: IV, RHS: BackcastedDispatchCounter);
6129 });
6130
6131 // In the "exit" block, call the "fini" function.
6132 Builder.SetInsertPoint(TheBB: DispatchExit, IP: DispatchExit->getFirstInsertionPt());
6133 createRuntimeFunctionCall(Callee: StaticFini, Args: {SrcLoc, ThreadNum});
6134
6135 // Add the barrier if requested.
6136 if (NeedsBarrier) {
6137 InsertPointOrErrorTy AfterIP =
6138 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL), Kind: OMPD_for,
6139 /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
6140 if (!AfterIP)
6141 return AfterIP.takeError();
6142 }
6143
6144#ifndef NDEBUG
6145 // Even though we currently do not support applying additional methods to it,
6146 // the chunk loop should remain a canonical loop.
6147 CLI->assertOK();
6148#endif
6149
6150 return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
6151}
6152
6153// Returns an LLVM function to call for executing an OpenMP static worksharing
6154// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
6155// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
6156static FunctionCallee
6157getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
6158 WorksharingLoopType LoopType) {
6159 unsigned Bitwidth = Ty->getIntegerBitWidth();
6160 Module &M = OMPBuilder->M;
6161 switch (LoopType) {
6162 case WorksharingLoopType::ForStaticLoop:
6163 if (Bitwidth == 32)
6164 return OMPBuilder->getOrCreateRuntimeFunction(
6165 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
6166 if (Bitwidth == 64)
6167 return OMPBuilder->getOrCreateRuntimeFunction(
6168 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
6169 break;
6170 case WorksharingLoopType::DistributeStaticLoop:
6171 if (Bitwidth == 32)
6172 return OMPBuilder->getOrCreateRuntimeFunction(
6173 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
6174 if (Bitwidth == 64)
6175 return OMPBuilder->getOrCreateRuntimeFunction(
6176 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
6177 break;
6178 case WorksharingLoopType::DistributeForStaticLoop:
6179 if (Bitwidth == 32)
6180 return OMPBuilder->getOrCreateRuntimeFunction(
6181 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
6182 if (Bitwidth == 64)
6183 return OMPBuilder->getOrCreateRuntimeFunction(
6184 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
6185 break;
6186 }
6187 if (Bitwidth != 32 && Bitwidth != 64) {
6188 llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
6189 }
6190 llvm_unreachable("Unknown type of OpenMP worksharing loop");
6191}
6192
6193// Inserts a call to proper OpenMP Device RTL function which handles
6194// loop worksharing.
6195static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
6196 WorksharingLoopType LoopType,
6197 BasicBlock *InsertBlock, Value *Ident,
6198 Value *LoopBodyArg, Value *TripCount,
6199 Function &LoopBodyFn, bool NoLoop) {
6200 Type *TripCountTy = TripCount->getType();
6201 Module &M = OMPBuilder->M;
6202 IRBuilder<> &Builder = OMPBuilder->Builder;
6203 FunctionCallee RTLFn =
6204 getKmpcForStaticLoopForType(Ty: TripCountTy, OMPBuilder, LoopType);
6205 SmallVector<Value *, 8> RealArgs;
6206 RealArgs.push_back(Elt: Ident);
6207 RealArgs.push_back(Elt: &LoopBodyFn);
6208 RealArgs.push_back(Elt: LoopBodyArg);
6209 RealArgs.push_back(Elt: TripCount);
6210 if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
6211 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
6212 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
6213 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
6214 OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
6215 return;
6216 }
6217 FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
6218 M, FnID: omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
6219 Builder.restoreIP(IP: {InsertBlock, std::prev(x: InsertBlock->end())});
6220 Value *NumThreads = OMPBuilder->createRuntimeFunctionCall(Callee: RTLNumThreads, Args: {});
6221
6222 RealArgs.push_back(
6223 Elt: Builder.CreateZExtOrTrunc(V: NumThreads, DestTy: TripCountTy, Name: "num.threads.cast"));
6224 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
6225 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
6226 RealArgs.push_back(Elt: ConstantInt::get(Ty: TripCountTy, V: 0));
6227 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: NoLoop));
6228 } else {
6229 RealArgs.push_back(Elt: ConstantInt::get(Ty: Builder.getInt8Ty(), V: 0));
6230 }
6231
6232 OMPBuilder->createRuntimeFunctionCall(Callee: RTLFn, Args: RealArgs);
6233}
6234
6235static void workshareLoopTargetCallback(
6236 OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
6237 Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
6238 WorksharingLoopType LoopType, bool NoLoop) {
6239 IRBuilder<> &Builder = OMPIRBuilder->Builder;
6240 BasicBlock *Preheader = CLI->getPreheader();
6241 Value *TripCount = CLI->getTripCount();
6242
6243 // After loop body outling, the loop body contains only set up
6244 // of loop body argument structure and the call to the outlined
6245 // loop body function. Firstly, we need to move setup of loop body args
6246 // into loop preheader.
6247 Preheader->splice(ToIt: std::prev(x: Preheader->end()), FromBB: CLI->getBody(),
6248 FromBeginIt: CLI->getBody()->begin(), FromEndIt: std::prev(x: CLI->getBody()->end()));
6249
6250 // The next step is to remove the whole loop. We do not it need anymore.
6251 // That's why make an unconditional branch from loop preheader to loop
6252 // exit block
6253 Builder.restoreIP(IP: {Preheader, Preheader->end()});
6254 Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
6255 Preheader->getTerminator()->eraseFromParent();
6256 Builder.CreateBr(Dest: CLI->getExit());
6257
6258 // Delete dead loop blocks
6259 OpenMPIRBuilder::OutlineInfo CleanUpInfo;
6260 SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
6261 SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
6262 CleanUpInfo.EntryBB = CLI->getHeader();
6263 CleanUpInfo.ExitBB = CLI->getExit();
6264 CleanUpInfo.collectBlocks(BlockSet&: RegionBlockSet, BlockVector&: BlocksToBeRemoved);
6265 DeleteDeadBlocks(BBs: BlocksToBeRemoved);
6266
6267 // Find the instruction which corresponds to loop body argument structure
6268 // and remove the call to loop body function instruction.
6269 Value *LoopBodyArg;
6270 User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
6271 assert(OutlinedFnUser &&
6272 "Expected unique undroppable user of outlined function");
6273 CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(Val: OutlinedFnUser);
6274 assert(OutlinedFnCallInstruction && "Expected outlined function call");
6275 assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
6276 "Expected outlined function call to be located in loop preheader");
6277 // Check in case no argument structure has been passed.
6278 if (OutlinedFnCallInstruction->arg_size() > 1)
6279 LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(i: 1);
6280 else
6281 LoopBodyArg = Constant::getNullValue(Ty: Builder.getPtrTy());
6282 OutlinedFnCallInstruction->eraseFromParent();
6283
6284 createTargetLoopWorkshareCall(OMPBuilder: OMPIRBuilder, LoopType, InsertBlock: Preheader, Ident,
6285 LoopBodyArg, TripCount, LoopBodyFn&: OutlinedFn, NoLoop);
6286
6287 for (auto &ToBeDeletedItem : ToBeDeleted)
6288 ToBeDeletedItem->eraseFromParent();
6289 CLI->invalidate();
6290}
6291
6292OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
6293 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
6294 WorksharingLoopType LoopType, bool NoLoop) {
6295 uint32_t SrcLocStrSize;
6296 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
6297 IdentFlag Flag = IdentFlag(0);
6298 switch (LoopType) {
6299 case WorksharingLoopType::ForStaticLoop:
6300 Flag = OMP_IDENT_FLAG_WORK_LOOP;
6301 break;
6302 case WorksharingLoopType::DistributeStaticLoop:
6303 Flag = OMP_IDENT_FLAG_WORK_DISTRIBUTE;
6304 break;
6305 case WorksharingLoopType::DistributeForStaticLoop:
6306 Flag = OMP_IDENT_FLAG_WORK_DISTRIBUTE | OMP_IDENT_FLAG_WORK_LOOP;
6307 break;
6308 }
6309 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: Flag);
6310
6311 auto OI = std::make_unique<OutlineInfo>();
6312 OI->OuterAllocBB = CLI->getPreheader();
6313 Function *OuterFn = CLI->getPreheader()->getParent();
6314
6315 // Instructions which need to be deleted at the end of code generation
6316 SmallVector<Instruction *, 4> ToBeDeleted;
6317
6318 OI->OuterAllocBB = AllocaIP.getBlock();
6319
6320 // Mark the body loop as region which needs to be extracted
6321 OI->EntryBB = CLI->getBody();
6322 OI->ExitBB = CLI->getLatch()->splitBasicBlockBefore(I: CLI->getLatch()->begin(),
6323 BBName: "omp.prelatch");
6324
6325 // Prepare loop body for extraction
6326 Builder.restoreIP(IP: {CLI->getPreheader(), CLI->getPreheader()->begin()});
6327
6328 // Insert new loop counter variable which will be used only in loop
6329 // body.
6330 AllocaInst *NewLoopCnt = Builder.CreateAlloca(Ty: CLI->getIndVarType(), ArraySize: 0, Name: "");
6331 Instruction *NewLoopCntLoad =
6332 Builder.CreateLoad(Ty: CLI->getIndVarType(), Ptr: NewLoopCnt);
6333 // New loop counter instructions are redundant in the loop preheader when
6334 // code generation for workshare loop is finshed. That's why mark them as
6335 // ready for deletion.
6336 ToBeDeleted.push_back(Elt: NewLoopCntLoad);
6337 ToBeDeleted.push_back(Elt: NewLoopCnt);
6338
6339 // Analyse loop body region. Find all input variables which are used inside
6340 // loop body region.
6341 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
6342 SmallVector<BasicBlock *, 32> Blocks;
6343 OI->collectBlocks(BlockSet&: ParallelRegionBlockSet, BlockVector&: Blocks);
6344
6345 CodeExtractorAnalysisCache CEAC(*OuterFn);
6346 CodeExtractor Extractor(Blocks,
6347 /* DominatorTree */ nullptr,
6348 /* AggregateArgs */ true,
6349 /* BlockFrequencyInfo */ nullptr,
6350 /* BranchProbabilityInfo */ nullptr,
6351 /* AssumptionCache */ nullptr,
6352 /* AllowVarArgs */ true,
6353 /* AllowAlloca */ true,
6354 /* AllocationBlock */ CLI->getPreheader(),
6355 /* DeallocationBlocks */ {},
6356 /* Suffix */ ".omp_wsloop",
6357 /* AggrArgsIn0AddrSpace */ true);
6358
6359 BasicBlock *CommonExit = nullptr;
6360 SetVector<Value *> SinkingCands, HoistingCands;
6361
6362 // Find allocas outside the loop body region which are used inside loop
6363 // body
6364 Extractor.findAllocas(CEAC, SinkCands&: SinkingCands, HoistCands&: HoistingCands, ExitBlock&: CommonExit);
6365
6366 // We need to model loop body region as the function f(cnt, loop_arg).
6367 // That's why we replace loop induction variable by the new counter
6368 // which will be one of loop body function argument
6369 SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
6370 CLI->getIndVar()->user_end());
6371 for (auto Use : Users) {
6372 if (Instruction *Inst = dyn_cast<Instruction>(Val: Use)) {
6373 if (ParallelRegionBlockSet.count(Ptr: Inst->getParent())) {
6374 Inst->replaceUsesOfWith(From: CLI->getIndVar(), To: NewLoopCntLoad);
6375 }
6376 }
6377 }
6378 // Make sure that loop counter variable is not merged into loop body
6379 // function argument structure and it is passed as separate variable
6380 OI->ExcludeArgsFromAggregate.push_back(Elt: NewLoopCntLoad);
6381
6382 // PostOutline CB is invoked when loop body function is outlined and
6383 // loop body is replaced by call to outlined function. We need to add
6384 // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
6385 // function will handle loop control logic.
6386 //
6387 OI->PostOutlineCB = [=, ToBeDeletedVec =
6388 std::move(ToBeDeleted)](Function &OutlinedFn) {
6389 workshareLoopTargetCallback(OMPIRBuilder: this, CLI, Ident, OutlinedFn, ToBeDeleted: ToBeDeletedVec,
6390 LoopType, NoLoop);
6391 };
6392 addOutlineInfo(OI: std::move(OI));
6393 return CLI->getAfterIP();
6394}
6395
6396OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
6397 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
6398 bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
6399 bool HasSimdModifier, bool HasMonotonicModifier,
6400 bool HasNonmonotonicModifier, bool HasOrderedClause,
6401 WorksharingLoopType LoopType, bool NoLoop, bool HasDistSchedule,
6402 Value *DistScheduleChunkSize) {
6403 if (Config.isTargetDevice())
6404 return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop);
6405 OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
6406 ClauseKind: SchedKind, HasChunks: ChunkSize, HasSimdModifier, HasMonotonicModifier,
6407 HasNonmonotonicModifier, HasOrderedClause, HasDistScheduleChunks: DistScheduleChunkSize);
6408
6409 bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
6410 OMPScheduleType::ModifierOrdered;
6411 OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
6412 if (HasDistSchedule) {
6413 DistScheduleSchedType = DistScheduleChunkSize
6414 ? OMPScheduleType::OrderedDistributeChunked
6415 : OMPScheduleType::OrderedDistribute;
6416 }
6417 switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
6418 case OMPScheduleType::BaseStatic:
6419 case OMPScheduleType::BaseDistribute:
6420 assert((!ChunkSize || !DistScheduleChunkSize) &&
6421 "No chunk size with static-chunked schedule");
6422 if (IsOrdered && !HasDistSchedule)
6423 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
6424 NeedsBarrier, Chunk: ChunkSize);
6425 // FIXME: Monotonicity ignored?
6426 if (DistScheduleChunkSize)
6427 return applyStaticChunkedWorkshareLoop(
6428 DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
6429 DistScheduleChunkSize, DistScheduleSchedType);
6430 return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
6431 HasDistSchedule);
6432
6433 case OMPScheduleType::BaseStaticChunked:
6434 case OMPScheduleType::BaseDistributeChunked:
6435 if (IsOrdered && !HasDistSchedule)
6436 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
6437 NeedsBarrier, Chunk: ChunkSize);
6438 // FIXME: Monotonicity ignored?
6439 return applyStaticChunkedWorkshareLoop(
6440 DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, SchedType: EffectiveScheduleType,
6441 DistScheduleChunkSize, DistScheduleSchedType);
6442
6443 case OMPScheduleType::BaseRuntime:
6444 case OMPScheduleType::BaseAuto:
6445 case OMPScheduleType::BaseGreedy:
6446 case OMPScheduleType::BaseBalanced:
6447 case OMPScheduleType::BaseSteal:
6448 case OMPScheduleType::BaseRuntimeSimd:
6449 assert(!ChunkSize &&
6450 "schedule type does not support user-defined chunk sizes");
6451 [[fallthrough]];
6452 case OMPScheduleType::BaseGuidedSimd:
6453 case OMPScheduleType::BaseDynamicChunked:
6454 case OMPScheduleType::BaseGuidedChunked:
6455 case OMPScheduleType::BaseGuidedIterativeChunked:
6456 case OMPScheduleType::BaseGuidedAnalyticalChunked:
6457 case OMPScheduleType::BaseStaticBalancedChunked:
6458 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, SchedType: EffectiveScheduleType,
6459 NeedsBarrier, Chunk: ChunkSize);
6460
6461 default:
6462 llvm_unreachable("Unknown/unimplemented schedule kind");
6463 }
6464}
6465
6466/// Returns an LLVM function to call for initializing loop bounds using OpenMP
6467/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6468/// the runtime. Always interpret integers as unsigned similarly to
6469/// CanonicalLoopInfo.
6470static FunctionCallee
6471getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6472 unsigned Bitwidth = Ty->getIntegerBitWidth();
6473 if (Bitwidth == 32)
6474 return OMPBuilder.getOrCreateRuntimeFunction(
6475 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
6476 if (Bitwidth == 64)
6477 return OMPBuilder.getOrCreateRuntimeFunction(
6478 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
6479 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6480}
6481
6482/// Returns an LLVM function to call for updating the next loop using OpenMP
6483/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
6484/// the runtime. Always interpret integers as unsigned similarly to
6485/// CanonicalLoopInfo.
6486static FunctionCallee
6487getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6488 unsigned Bitwidth = Ty->getIntegerBitWidth();
6489 if (Bitwidth == 32)
6490 return OMPBuilder.getOrCreateRuntimeFunction(
6491 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
6492 if (Bitwidth == 64)
6493 return OMPBuilder.getOrCreateRuntimeFunction(
6494 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
6495 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6496}
6497
6498/// Returns an LLVM function to call for finalizing the dynamic loop using
6499/// depending on `type`. Only i32 and i64 are supported by the runtime. Always
6500/// interpret integers as unsigned similarly to CanonicalLoopInfo.
6501static FunctionCallee
6502getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
6503 unsigned Bitwidth = Ty->getIntegerBitWidth();
6504 if (Bitwidth == 32)
6505 return OMPBuilder.getOrCreateRuntimeFunction(
6506 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
6507 if (Bitwidth == 64)
6508 return OMPBuilder.getOrCreateRuntimeFunction(
6509 M, FnID: omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
6510 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
6511}
6512
6513OpenMPIRBuilder::InsertPointOrErrorTy
6514OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
6515 InsertPointTy AllocaIP,
6516 OMPScheduleType SchedType,
6517 bool NeedsBarrier, Value *Chunk) {
6518 assert(CLI->isValid() && "Requires a valid canonical loop");
6519 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
6520 "Require dedicated allocate IP");
6521 assert(isValidWorkshareLoopScheduleType(SchedType) &&
6522 "Require valid schedule type");
6523
6524 bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
6525 OMPScheduleType::ModifierOrdered;
6526
6527 // Set up the source location value for OpenMP runtime.
6528 Builder.SetCurrentDebugLocation(DL);
6529
6530 uint32_t SrcLocStrSize;
6531 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
6532 Value *SrcLoc =
6533 getOrCreateIdent(SrcLocStr, SrcLocStrSize, LocFlags: OMP_IDENT_FLAG_WORK_LOOP);
6534
6535 // Declare useful OpenMP runtime functions.
6536 Value *IV = CLI->getIndVar();
6537 Type *IVTy = IV->getType();
6538 FunctionCallee DynamicInit = getKmpcForDynamicInitForType(Ty: IVTy, M, OMPBuilder&: *this);
6539 FunctionCallee DynamicNext = getKmpcForDynamicNextForType(Ty: IVTy, M, OMPBuilder&: *this);
6540
6541 // Allocate space for computed loop bounds as expected by the "init" function.
6542 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
6543 Type *I32Type = Type::getInt32Ty(C&: M.getContext());
6544 Value *PLastIter = Builder.CreateAlloca(Ty: I32Type, ArraySize: nullptr, Name: "p.lastiter");
6545 Value *PLowerBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.lowerbound");
6546 Value *PUpperBound = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.upperbound");
6547 Value *PStride = Builder.CreateAlloca(Ty: IVTy, ArraySize: nullptr, Name: "p.stride");
6548 CLI->setLastIter(PLastIter);
6549
6550 // At the end of the preheader, prepare for calling the "init" function by
6551 // storing the current loop bounds into the allocated space. A canonical loop
6552 // always iterates from 0 to trip-count with step 1. Note that "init" expects
6553 // and produces an inclusive upper bound.
6554 BasicBlock *PreHeader = CLI->getPreheader();
6555 Builder.SetInsertPoint(PreHeader->getTerminator());
6556 Constant *One = ConstantInt::get(Ty: IVTy, V: 1);
6557 Builder.CreateStore(Val: One, Ptr: PLowerBound);
6558 Value *UpperBound = CLI->getTripCount();
6559 Builder.CreateStore(Val: UpperBound, Ptr: PUpperBound);
6560 Builder.CreateStore(Val: One, Ptr: PStride);
6561
6562 BasicBlock *Header = CLI->getHeader();
6563 BasicBlock *Exit = CLI->getExit();
6564 BasicBlock *Cond = CLI->getCond();
6565 BasicBlock *Latch = CLI->getLatch();
6566 InsertPointTy AfterIP = CLI->getAfterIP();
6567
6568 // The CLI will be "broken" in the code below, as the loop is no longer
6569 // a valid canonical loop.
6570
6571 if (!Chunk)
6572 Chunk = One;
6573
6574 Value *ThreadNum =
6575 getOrCreateThreadID(Ident: getOrCreateIdent(SrcLocStr, SrcLocStrSize));
6576
6577 Constant *SchedulingType =
6578 ConstantInt::get(Ty: I32Type, V: static_cast<int>(SchedType));
6579
6580 // Call the "init" function.
6581 createRuntimeFunctionCall(Callee: DynamicInit, Args: {SrcLoc, ThreadNum, SchedulingType,
6582 /* LowerBound */ One, UpperBound,
6583 /* step */ One, Chunk});
6584
6585 // An outer loop around the existing one.
6586 BasicBlock *OuterCond = BasicBlock::Create(
6587 Context&: PreHeader->getContext(), Name: Twine(PreHeader->getName()) + ".outer.cond",
6588 Parent: PreHeader->getParent());
6589 // This needs to be 32-bit always, so can't use the IVTy Zero above.
6590 Builder.SetInsertPoint(TheBB: OuterCond, IP: OuterCond->getFirstInsertionPt());
6591 Value *Res = createRuntimeFunctionCall(
6592 Callee: DynamicNext,
6593 Args: {SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound, PStride});
6594 Constant *Zero32 = ConstantInt::get(Ty: I32Type, V: 0);
6595 Value *MoreWork = Builder.CreateCmp(Pred: CmpInst::ICMP_NE, LHS: Res, RHS: Zero32);
6596 Value *LowerBound =
6597 Builder.CreateSub(LHS: Builder.CreateLoad(Ty: IVTy, Ptr: PLowerBound), RHS: One, Name: "lb");
6598 Builder.CreateCondBr(Cond: MoreWork, True: Header, False: Exit);
6599
6600 // Change PHI-node in loop header to use outer cond rather than preheader,
6601 // and set IV to the LowerBound.
6602 Instruction *Phi = &Header->front();
6603 auto *PI = cast<PHINode>(Val: Phi);
6604 PI->setIncomingBlock(i: 0, BB: OuterCond);
6605 PI->setIncomingValue(i: 0, V: LowerBound);
6606
6607 // Then set the pre-header to jump to the OuterCond
6608 Instruction *Term = PreHeader->getTerminator();
6609 auto *Br = cast<UncondBrInst>(Val: Term);
6610 Br->setSuccessor(OuterCond);
6611
6612 // Modify the inner condition:
6613 // * Use the UpperBound returned from the DynamicNext call.
6614 // * jump to the loop outer loop when done with one of the inner loops.
6615 Builder.SetInsertPoint(TheBB: Cond, IP: Cond->getFirstInsertionPt());
6616 UpperBound = Builder.CreateLoad(Ty: IVTy, Ptr: PUpperBound, Name: "ub");
6617 Instruction *Comp = &*Builder.GetInsertPoint();
6618 auto *CI = cast<CmpInst>(Val: Comp);
6619 CI->setOperand(i_nocapture: 1, Val_nocapture: UpperBound);
6620 // Redirect the inner exit to branch to outer condition.
6621 Instruction *Branch = &Cond->back();
6622 auto *BI = cast<CondBrInst>(Val: Branch);
6623 assert(BI->getSuccessor(1) == Exit);
6624 BI->setSuccessor(idx: 1, NewSucc: OuterCond);
6625
6626 // Call the "fini" function if "ordered" is present in wsloop directive.
6627 if (Ordered) {
6628 Builder.SetInsertPoint(&Latch->back());
6629 FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(Ty: IVTy, M, OMPBuilder&: *this);
6630 createRuntimeFunctionCall(Callee: DynamicFini, Args: {SrcLoc, ThreadNum});
6631 }
6632
6633 // Add the barrier if requested.
6634 if (NeedsBarrier) {
6635 Builder.SetInsertPoint(&Exit->back());
6636 InsertPointOrErrorTy BarrierIP =
6637 createBarrier(Loc: LocationDescription(Builder.saveIP(), DL),
6638 Kind: omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
6639 /* CheckCancelFlag */ false);
6640 if (!BarrierIP)
6641 return BarrierIP.takeError();
6642 }
6643
6644 CLI->invalidate();
6645 return AfterIP;
6646}
6647
6648/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
6649/// after this \p OldTarget will be orphaned.
6650static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
6651 BasicBlock *NewTarget, DebugLoc DL) {
6652 for (BasicBlock *Pred : make_early_inc_range(Range: predecessors(BB: OldTarget)))
6653 redirectTo(Source: Pred, Target: NewTarget, DL);
6654}
6655
6656static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
6657 SmallPtrSet<BasicBlock *, 8> InternalBBs(from_range, BBs);
6658 // We add a block to BBsToKeep iff we have proven it has an external use.
6659 SmallPtrSet<BasicBlock *, 8> BBsToKeep;
6660
6661 while (true) {
6662 bool Changed = false;
6663
6664 for (BasicBlock *BB : BBs) {
6665 if (BBsToKeep.contains(Ptr: BB))
6666 continue;
6667
6668 for (Use &U : BB->uses()) {
6669 auto *UseInst = dyn_cast<Instruction>(Val: U.getUser());
6670 if (!UseInst)
6671 continue;
6672 BasicBlock *UseBB = UseInst->getParent();
6673 if (!InternalBBs.contains(Ptr: UseBB) || BBsToKeep.contains(Ptr: UseBB)) {
6674 BBsToKeep.insert(Ptr: BB);
6675 Changed = true;
6676 break;
6677 }
6678 }
6679 }
6680
6681 if (!Changed)
6682 break;
6683 }
6684
6685 SmallVector<BasicBlock *> BBsToDelete = filter_to_vector(
6686 C&: BBs, Pred: [&BBsToKeep](BasicBlock *BB) { return !BBsToKeep.contains(Ptr: BB); });
6687 DeleteDeadBlocks(BBs: BBsToDelete);
6688}
6689
6690CanonicalLoopInfo *
6691OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6692 InsertPointTy ComputeIP) {
6693 assert(Loops.size() >= 1 && "At least one loop required");
6694 size_t NumLoops = Loops.size();
6695
6696 // Nothing to do if there is already just one loop.
6697 if (NumLoops == 1)
6698 return Loops.front();
6699
6700 CanonicalLoopInfo *Outermost = Loops.front();
6701 CanonicalLoopInfo *Innermost = Loops.back();
6702 BasicBlock *OrigPreheader = Outermost->getPreheader();
6703 BasicBlock *OrigAfter = Outermost->getAfter();
6704 Function *F = OrigPreheader->getParent();
6705
6706 // Loop control blocks that may become orphaned later.
6707 SmallVector<BasicBlock *, 12> OldControlBBs;
6708 OldControlBBs.reserve(N: 6 * Loops.size());
6709 for (CanonicalLoopInfo *Loop : Loops)
6710 Loop->collectControlBlocks(BBs&: OldControlBBs);
6711
6712 // Setup the IRBuilder for inserting the trip count computation.
6713 Builder.SetCurrentDebugLocation(DL);
6714 if (ComputeIP.isSet())
6715 Builder.restoreIP(IP: ComputeIP);
6716 else
6717 Builder.restoreIP(IP: Outermost->getPreheaderIP());
6718
6719 // Derive the collapsed' loop trip count.
6720 // TODO: Find common/largest indvar type.
6721 Value *CollapsedTripCount = nullptr;
6722 for (CanonicalLoopInfo *L : Loops) {
6723 assert(L->isValid() &&
6724 "All loops to collapse must be valid canonical loops");
6725 Value *OrigTripCount = L->getTripCount();
6726 if (!CollapsedTripCount) {
6727 CollapsedTripCount = OrigTripCount;
6728 continue;
6729 }
6730
6731 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
6732 CollapsedTripCount =
6733 Builder.CreateNUWMul(LHS: CollapsedTripCount, RHS: OrigTripCount);
6734 }
6735
6736 // Create the collapsed loop control flow.
6737 CanonicalLoopInfo *Result =
6738 createLoopSkeleton(DL, TripCount: CollapsedTripCount, F,
6739 PreInsertBefore: OrigPreheader->getNextNode(), PostInsertBefore: OrigAfter, Name: "collapsed");
6740
6741 // Build the collapsed loop body code.
6742 // Start with deriving the input loop induction variables from the collapsed
6743 // one, using a divmod scheme. To preserve the original loops' order, the
6744 // innermost loop use the least significant bits.
6745 Builder.restoreIP(IP: Result->getBodyIP());
6746
6747 Value *Leftover = Result->getIndVar();
6748 SmallVector<Value *> NewIndVars;
6749 NewIndVars.resize(N: NumLoops);
6750 for (int i = NumLoops - 1; i >= 1; --i) {
6751 Value *OrigTripCount = Loops[i]->getTripCount();
6752
6753 Value *NewIndVar = Builder.CreateURem(LHS: Leftover, RHS: OrigTripCount);
6754 NewIndVars[i] = NewIndVar;
6755
6756 Leftover = Builder.CreateUDiv(LHS: Leftover, RHS: OrigTripCount);
6757 }
6758 // Outermost loop gets all the remaining bits.
6759 NewIndVars[0] = Leftover;
6760
6761 // Construct the loop body control flow.
6762 // We progressively construct the branch structure following in direction of
6763 // the control flow, from the leading in-between code, the loop nest body, the
6764 // trailing in-between code, and rejoining the collapsed loop's latch.
6765 // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
6766 // the ContinueBlock is set, continue with that block. If ContinuePred, use
6767 // its predecessors as sources.
6768 BasicBlock *ContinueBlock = Result->getBody();
6769 BasicBlock *ContinuePred = nullptr;
6770 auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
6771 BasicBlock *NextSrc) {
6772 if (ContinueBlock)
6773 redirectTo(Source: ContinueBlock, Target: Dest, DL);
6774 else
6775 redirectAllPredecessorsTo(OldTarget: ContinuePred, NewTarget: Dest, DL);
6776
6777 ContinueBlock = nullptr;
6778 ContinuePred = NextSrc;
6779 };
6780
6781 // The code before the nested loop of each level.
6782 // Because we are sinking it into the nest, it will be executed more often
6783 // that the original loop. More sophisticated schemes could keep track of what
6784 // the in-between code is and instantiate it only once per thread.
6785 for (size_t i = 0; i < NumLoops - 1; ++i)
6786 ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
6787
6788 // Connect the loop nest body.
6789 ContinueWith(Innermost->getBody(), Innermost->getLatch());
6790
6791 // The code after the nested loop at each level.
6792 for (size_t i = NumLoops - 1; i > 0; --i)
6793 ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
6794
6795 // Connect the finished loop to the collapsed loop latch.
6796 ContinueWith(Result->getLatch(), nullptr);
6797
6798 // Replace the input loops with the new collapsed loop.
6799 redirectTo(Source: Outermost->getPreheader(), Target: Result->getPreheader(), DL);
6800 redirectTo(Source: Result->getAfter(), Target: Outermost->getAfter(), DL);
6801
6802 // Replace the input loop indvars with the derived ones.
6803 for (size_t i = 0; i < NumLoops; ++i)
6804 Loops[i]->getIndVar()->replaceAllUsesWith(V: NewIndVars[i]);
6805
6806 // Remove unused parts of the input loops.
6807 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6808
6809 for (CanonicalLoopInfo *L : Loops)
6810 L->invalidate();
6811
6812#ifndef NDEBUG
6813 Result->assertOK();
6814#endif
6815 return Result;
6816}
6817
6818std::vector<CanonicalLoopInfo *>
6819OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
6820 ArrayRef<Value *> TileSizes) {
6821 assert(TileSizes.size() == Loops.size() &&
6822 "Must pass as many tile sizes as there are loops");
6823 int NumLoops = Loops.size();
6824 assert(NumLoops >= 1 && "At least one loop to tile required");
6825
6826 CanonicalLoopInfo *OutermostLoop = Loops.front();
6827 CanonicalLoopInfo *InnermostLoop = Loops.back();
6828 Function *F = OutermostLoop->getBody()->getParent();
6829 BasicBlock *InnerEnter = InnermostLoop->getBody();
6830 BasicBlock *InnerLatch = InnermostLoop->getLatch();
6831
6832 // Loop control blocks that may become orphaned later.
6833 SmallVector<BasicBlock *, 12> OldControlBBs;
6834 OldControlBBs.reserve(N: 6 * Loops.size());
6835 for (CanonicalLoopInfo *Loop : Loops)
6836 Loop->collectControlBlocks(BBs&: OldControlBBs);
6837
6838 // Collect original trip counts and induction variable to be accessible by
6839 // index. Also, the structure of the original loops is not preserved during
6840 // the construction of the tiled loops, so do it before we scavenge the BBs of
6841 // any original CanonicalLoopInfo.
6842 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
6843 for (CanonicalLoopInfo *L : Loops) {
6844 assert(L->isValid() && "All input loops must be valid canonical loops");
6845 OrigTripCounts.push_back(Elt: L->getTripCount());
6846 OrigIndVars.push_back(Elt: L->getIndVar());
6847 }
6848
6849 // Collect the code between loop headers. These may contain SSA definitions
6850 // that are used in the loop nest body. To be usable with in the innermost
6851 // body, these BasicBlocks will be sunk into the loop nest body. That is,
6852 // these instructions may be executed more often than before the tiling.
6853 // TODO: It would be sufficient to only sink them into body of the
6854 // corresponding tile loop.
6855 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
6856 for (int i = 0; i < NumLoops - 1; ++i) {
6857 CanonicalLoopInfo *Surrounding = Loops[i];
6858 CanonicalLoopInfo *Nested = Loops[i + 1];
6859
6860 BasicBlock *EnterBB = Surrounding->getBody();
6861 BasicBlock *ExitBB = Nested->getHeader();
6862 InbetweenCode.emplace_back(Args&: EnterBB, Args&: ExitBB);
6863 }
6864
6865 // Compute the trip counts of the floor loops.
6866 Builder.SetCurrentDebugLocation(DL);
6867 Builder.restoreIP(IP: OutermostLoop->getPreheaderIP());
6868 SmallVector<Value *, 4> FloorCompleteCount, FloorCount, FloorRems;
6869 for (int i = 0; i < NumLoops; ++i) {
6870 Value *TileSize = TileSizes[i];
6871 Value *OrigTripCount = OrigTripCounts[i];
6872 Type *IVType = OrigTripCount->getType();
6873
6874 Value *FloorCompleteTripCount = Builder.CreateUDiv(LHS: OrigTripCount, RHS: TileSize);
6875 Value *FloorTripRem = Builder.CreateURem(LHS: OrigTripCount, RHS: TileSize);
6876
6877 // 0 if tripcount divides the tilesize, 1 otherwise.
6878 // 1 means we need an additional iteration for a partial tile.
6879 //
6880 // Unfortunately we cannot just use the roundup-formula
6881 // (tripcount + tilesize - 1)/tilesize
6882 // because the summation might overflow. We do not want introduce undefined
6883 // behavior when the untiled loop nest did not.
6884 Value *FloorTripOverflow =
6885 Builder.CreateICmpNE(LHS: FloorTripRem, RHS: ConstantInt::get(Ty: IVType, V: 0));
6886
6887 FloorTripOverflow = Builder.CreateZExt(V: FloorTripOverflow, DestTy: IVType);
6888 Value *FloorTripCount =
6889 Builder.CreateAdd(LHS: FloorCompleteTripCount, RHS: FloorTripOverflow,
6890 Name: "omp_floor" + Twine(i) + ".tripcount", HasNUW: true);
6891
6892 // Remember some values for later use.
6893 FloorCompleteCount.push_back(Elt: FloorCompleteTripCount);
6894 FloorCount.push_back(Elt: FloorTripCount);
6895 FloorRems.push_back(Elt: FloorTripRem);
6896 }
6897
6898 // Generate the new loop nest, from the outermost to the innermost.
6899 std::vector<CanonicalLoopInfo *> Result;
6900 Result.reserve(n: NumLoops * 2);
6901
6902 // The basic block of the surrounding loop that enters the nest generated
6903 // loop.
6904 BasicBlock *Enter = OutermostLoop->getPreheader();
6905
6906 // The basic block of the surrounding loop where the inner code should
6907 // continue.
6908 BasicBlock *Continue = OutermostLoop->getAfter();
6909
6910 // Where the next loop basic block should be inserted.
6911 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
6912
6913 auto EmbeddNewLoop =
6914 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
6915 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
6916 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
6917 DL, TripCount, F, PreInsertBefore: InnerEnter, PostInsertBefore: OutroInsertBefore, Name);
6918 redirectTo(Source: Enter, Target: EmbeddedLoop->getPreheader(), DL);
6919 redirectTo(Source: EmbeddedLoop->getAfter(), Target: Continue, DL);
6920
6921 // Setup the position where the next embedded loop connects to this loop.
6922 Enter = EmbeddedLoop->getBody();
6923 Continue = EmbeddedLoop->getLatch();
6924 OutroInsertBefore = EmbeddedLoop->getLatch();
6925 return EmbeddedLoop;
6926 };
6927
6928 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
6929 const Twine &NameBase) {
6930 for (auto P : enumerate(First&: TripCounts)) {
6931 CanonicalLoopInfo *EmbeddedLoop =
6932 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
6933 Result.push_back(x: EmbeddedLoop);
6934 }
6935 };
6936
6937 EmbeddNewLoops(FloorCount, "floor");
6938
6939 // Within the innermost floor loop, emit the code that computes the tile
6940 // sizes.
6941 Builder.SetInsertPoint(Enter->getTerminator());
6942 SmallVector<Value *, 4> TileCounts;
6943 for (int i = 0; i < NumLoops; ++i) {
6944 CanonicalLoopInfo *FloorLoop = Result[i];
6945 Value *TileSize = TileSizes[i];
6946
6947 Value *FloorIsEpilogue =
6948 Builder.CreateICmpEQ(LHS: FloorLoop->getIndVar(), RHS: FloorCompleteCount[i]);
6949 Value *TileTripCount =
6950 Builder.CreateSelect(C: FloorIsEpilogue, True: FloorRems[i], False: TileSize);
6951
6952 TileCounts.push_back(Elt: TileTripCount);
6953 }
6954
6955 // Create the tile loops.
6956 EmbeddNewLoops(TileCounts, "tile");
6957
6958 // Insert the inbetween code into the body.
6959 BasicBlock *BodyEnter = Enter;
6960 BasicBlock *BodyEntered = nullptr;
6961 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
6962 BasicBlock *EnterBB = P.first;
6963 BasicBlock *ExitBB = P.second;
6964
6965 if (BodyEnter)
6966 redirectTo(Source: BodyEnter, Target: EnterBB, DL);
6967 else
6968 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: EnterBB, DL);
6969
6970 BodyEnter = nullptr;
6971 BodyEntered = ExitBB;
6972 }
6973
6974 // Append the original loop nest body into the generated loop nest body.
6975 if (BodyEnter)
6976 redirectTo(Source: BodyEnter, Target: InnerEnter, DL);
6977 else
6978 redirectAllPredecessorsTo(OldTarget: BodyEntered, NewTarget: InnerEnter, DL);
6979 redirectAllPredecessorsTo(OldTarget: InnerLatch, NewTarget: Continue, DL);
6980
6981 // Replace the original induction variable with an induction variable computed
6982 // from the tile and floor induction variables.
6983 Builder.restoreIP(IP: Result.back()->getBodyIP());
6984 for (int i = 0; i < NumLoops; ++i) {
6985 CanonicalLoopInfo *FloorLoop = Result[i];
6986 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
6987 Value *OrigIndVar = OrigIndVars[i];
6988 Value *Size = TileSizes[i];
6989
6990 Value *Scale =
6991 Builder.CreateMul(LHS: Size, RHS: FloorLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6992 Value *Shift =
6993 Builder.CreateAdd(LHS: Scale, RHS: TileLoop->getIndVar(), Name: {}, /*HasNUW=*/true);
6994 OrigIndVar->replaceAllUsesWith(V: Shift);
6995 }
6996
6997 // Remove unused parts of the original loops.
6998 removeUnusedBlocksFromParent(BBs: OldControlBBs);
6999
7000 for (CanonicalLoopInfo *L : Loops)
7001 L->invalidate();
7002
7003#ifndef NDEBUG
7004 for (CanonicalLoopInfo *GenL : Result)
7005 GenL->assertOK();
7006#endif
7007 return Result;
7008}
7009
7010/// Attach metadata \p Properties to the basic block described by \p BB. If the
7011/// basic block already has metadata, the basic block properties are appended.
7012static void addBasicBlockMetadata(BasicBlock *BB,
7013 ArrayRef<Metadata *> Properties) {
7014 // Nothing to do if no property to attach.
7015 if (Properties.empty())
7016 return;
7017
7018 LLVMContext &Ctx = BB->getContext();
7019 SmallVector<Metadata *> NewProperties;
7020 NewProperties.push_back(Elt: nullptr);
7021
7022 // If the basic block already has metadata, prepend it to the new metadata.
7023 MDNode *Existing = BB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop);
7024 if (Existing)
7025 append_range(C&: NewProperties, R: drop_begin(RangeOrContainer: Existing->operands(), N: 1));
7026
7027 append_range(C&: NewProperties, R&: Properties);
7028 MDNode *BasicBlockID = MDNode::getDistinct(Context&: Ctx, MDs: NewProperties);
7029 BasicBlockID->replaceOperandWith(I: 0, New: BasicBlockID);
7030
7031 BB->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: BasicBlockID);
7032}
7033
7034/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
7035/// loop already has metadata, the loop properties are appended.
7036static void addLoopMetadata(CanonicalLoopInfo *Loop,
7037 ArrayRef<Metadata *> Properties) {
7038 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
7039
7040 // Attach metadata to the loop's latch
7041 BasicBlock *Latch = Loop->getLatch();
7042 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
7043 addBasicBlockMetadata(BB: Latch, Properties);
7044}
7045
7046/// Attach llvm.access.group metadata to the memref instructions of \p Block
7047static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
7048 LoopInfo &LI) {
7049 for (Instruction &I : *Block) {
7050 if (I.mayReadOrWriteMemory()) {
7051 // TODO: This instruction may already have access group from
7052 // other pragmas e.g. #pragma clang loop vectorize. Append
7053 // so that the existing metadata is not overwritten.
7054 I.setMetadata(KindID: LLVMContext::MD_access_group, Node: AccessGroup);
7055 }
7056 }
7057}
7058
7059CanonicalLoopInfo *
7060OpenMPIRBuilder::fuseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops) {
7061 CanonicalLoopInfo *firstLoop = Loops.front();
7062 CanonicalLoopInfo *lastLoop = Loops.back();
7063 Function *F = firstLoop->getPreheader()->getParent();
7064
7065 // Loop control blocks that will become orphaned later
7066 SmallVector<BasicBlock *> oldControlBBs;
7067 for (CanonicalLoopInfo *Loop : Loops)
7068 Loop->collectControlBlocks(BBs&: oldControlBBs);
7069
7070 // Collect original trip counts
7071 SmallVector<Value *> origTripCounts;
7072 for (CanonicalLoopInfo *L : Loops) {
7073 assert(L->isValid() && "All input loops must be valid canonical loops");
7074 origTripCounts.push_back(Elt: L->getTripCount());
7075 }
7076
7077 Builder.SetCurrentDebugLocation(DL);
7078
7079 // Compute max trip count.
7080 // The fused loop will be from 0 to max(origTripCounts)
7081 BasicBlock *TCBlock = BasicBlock::Create(Context&: F->getContext(), Name: "omp.fuse.comp.tc",
7082 Parent: F, InsertBefore: firstLoop->getHeader());
7083 Builder.SetInsertPoint(TCBlock);
7084 Value *fusedTripCount = nullptr;
7085 for (CanonicalLoopInfo *L : Loops) {
7086 assert(L->isValid() && "All loops to fuse must be valid canonical loops");
7087 Value *origTripCount = L->getTripCount();
7088 if (!fusedTripCount) {
7089 fusedTripCount = origTripCount;
7090 continue;
7091 }
7092 Value *condTP = Builder.CreateICmpSGT(LHS: fusedTripCount, RHS: origTripCount);
7093 fusedTripCount = Builder.CreateSelect(C: condTP, True: fusedTripCount, False: origTripCount,
7094 Name: ".omp.fuse.tc");
7095 }
7096
7097 // Generate new loop
7098 CanonicalLoopInfo *fused =
7099 createLoopSkeleton(DL, TripCount: fusedTripCount, F, PreInsertBefore: firstLoop->getBody(),
7100 PostInsertBefore: lastLoop->getLatch(), Name: "fused");
7101
7102 // Replace original loops with the fused loop
7103 // Preheader and After are not considered inside the CLI.
7104 // These are used to compute the individual TCs of the loops
7105 // so they have to be put before the resulting fused loop.
7106 // Moving them up for readability.
7107 for (size_t i = 0; i < Loops.size() - 1; ++i) {
7108 Loops[i]->getPreheader()->moveBefore(MovePos: TCBlock);
7109 Loops[i]->getAfter()->moveBefore(MovePos: TCBlock);
7110 }
7111 lastLoop->getPreheader()->moveBefore(MovePos: TCBlock);
7112
7113 for (size_t i = 0; i < Loops.size() - 1; ++i) {
7114 redirectTo(Source: Loops[i]->getPreheader(), Target: Loops[i]->getAfter(), DL);
7115 redirectTo(Source: Loops[i]->getAfter(), Target: Loops[i + 1]->getPreheader(), DL);
7116 }
7117 redirectTo(Source: lastLoop->getPreheader(), Target: TCBlock, DL);
7118 redirectTo(Source: TCBlock, Target: fused->getPreheader(), DL);
7119 redirectTo(Source: fused->getAfter(), Target: lastLoop->getAfter(), DL);
7120
7121 // Build the fused body
7122 // Create new Blocks with conditions that jump to the original loop bodies
7123 SmallVector<BasicBlock *> condBBs;
7124 SmallVector<Value *> condValues;
7125 for (size_t i = 0; i < Loops.size(); ++i) {
7126 BasicBlock *condBlock = BasicBlock::Create(
7127 Context&: F->getContext(), Name: "omp.fused.inner.cond", Parent: F, InsertBefore: Loops[i]->getBody());
7128 Builder.SetInsertPoint(condBlock);
7129 Value *condValue =
7130 Builder.CreateICmpSLT(LHS: fused->getIndVar(), RHS: origTripCounts[i]);
7131 condBBs.push_back(Elt: condBlock);
7132 condValues.push_back(Elt: condValue);
7133 }
7134 // Join the condition blocks with the bodies of the original loops
7135 redirectTo(Source: fused->getBody(), Target: condBBs[0], DL);
7136 for (size_t i = 0; i < Loops.size() - 1; ++i) {
7137 Builder.SetInsertPoint(condBBs[i]);
7138 Builder.CreateCondBr(Cond: condValues[i], True: Loops[i]->getBody(), False: condBBs[i + 1]);
7139 redirectAllPredecessorsTo(OldTarget: Loops[i]->getLatch(), NewTarget: condBBs[i + 1], DL);
7140 // Replace the IV with the fused IV
7141 Loops[i]->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());
7142 }
7143 // Last body jumps to the created end body block
7144 Builder.SetInsertPoint(condBBs.back());
7145 Builder.CreateCondBr(Cond: condValues.back(), True: lastLoop->getBody(),
7146 False: fused->getLatch());
7147 redirectAllPredecessorsTo(OldTarget: lastLoop->getLatch(), NewTarget: fused->getLatch(), DL);
7148 // Replace the IV with the fused IV
7149 lastLoop->getIndVar()->replaceAllUsesWith(V: fused->getIndVar());
7150
7151 // The loop latch must have only one predecessor. Currently it is branched to
7152 // from both the last condition block and the last loop body
7153 fused->getLatch()->splitBasicBlockBefore(I: fused->getLatch()->begin(),
7154 BBName: "omp.fused.pre_latch");
7155
7156 // Remove unused parts
7157 removeUnusedBlocksFromParent(BBs: oldControlBBs);
7158
7159 // Invalidate old CLIs
7160 for (CanonicalLoopInfo *L : Loops)
7161 L->invalidate();
7162
7163#ifndef NDEBUG
7164 fused->assertOK();
7165#endif
7166 return fused;
7167}
7168
7169void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
7170 LLVMContext &Ctx = Builder.getContext();
7171 addLoopMetadata(
7172 Loop, Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
7173 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.full"))});
7174}
7175
7176void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
7177 LLVMContext &Ctx = Builder.getContext();
7178 addLoopMetadata(
7179 Loop, Properties: {
7180 MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
7181 });
7182}
7183
7184void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
7185 Value *IfCond, ValueToValueMapTy &VMap,
7186 LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
7187 const Twine &NamePrefix) {
7188 Function *F = CanonicalLoop->getFunction();
7189
7190 // We can't do
7191 // if (cond) {
7192 // simd_loop;
7193 // } else {
7194 // non_simd_loop;
7195 // }
7196 // because then the CanonicalLoopInfo would only point to one of the loops:
7197 // leading to other constructs operating on the same loop to malfunction.
7198 // Instead generate
7199 // while (...) {
7200 // if (cond) {
7201 // simd_body;
7202 // } else {
7203 // not_simd_body;
7204 // }
7205 // }
7206 // At least for simple loops, LLVM seems able to hoist the if out of the loop
7207 // body at -O3
7208
7209 // Define where if branch should be inserted
7210 auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();
7211
7212 // Create additional blocks for the if statement
7213 BasicBlock *Cond = SplitBeforeIt->getParent();
7214 llvm::LLVMContext &C = Cond->getContext();
7215 llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
7216 Context&: C, Name: NamePrefix + ".if.then", Parent: Cond->getParent(), InsertBefore: Cond->getNextNode());
7217 llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
7218 Context&: C, Name: NamePrefix + ".if.else", Parent: Cond->getParent(), InsertBefore: CanonicalLoop->getExit());
7219
7220 // Create if condition branch.
7221 Builder.SetInsertPoint(SplitBeforeIt);
7222 Instruction *BrInstr =
7223 Builder.CreateCondBr(Cond: IfCond, True: ThenBlock, /*ifFalse*/ False: ElseBlock);
7224 InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
7225 // Then block contains branch to omp loop body which needs to be vectorized
7226 spliceBB(IP, New: ThenBlock, CreateBranch: false, DL: Builder.getCurrentDebugLocation());
7227 ThenBlock->replaceSuccessorsPhiUsesWith(Old: Cond, New: ThenBlock);
7228
7229 Builder.SetInsertPoint(ElseBlock);
7230
7231 // Clone loop for the else branch
7232 SmallVector<BasicBlock *, 8> NewBlocks;
7233
7234 SmallVector<BasicBlock *, 8> ExistingBlocks;
7235 ExistingBlocks.reserve(N: L->getNumBlocks() + 1);
7236 ExistingBlocks.push_back(Elt: ThenBlock);
7237 ExistingBlocks.append(in_start: L->block_begin(), in_end: L->block_end());
7238 // Cond is the block that has the if clause condition
7239 // LoopCond is omp_loop.cond
7240 // LoopHeader is omp_loop.header
7241 BasicBlock *LoopCond = Cond->getUniquePredecessor();
7242 BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
7243 assert(LoopCond && LoopHeader && "Invalid loop structure");
7244 for (BasicBlock *Block : ExistingBlocks) {
7245 if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
7246 Block == LoopHeader || Block == LoopCond || Block == Cond) {
7247 continue;
7248 }
7249 BasicBlock *NewBB = CloneBasicBlock(BB: Block, VMap, NameSuffix: "", F);
7250
7251 // fix name not to be omp.if.then
7252 if (Block == ThenBlock)
7253 NewBB->setName(NamePrefix + ".if.else");
7254
7255 NewBB->moveBefore(MovePos: CanonicalLoop->getExit());
7256 VMap[Block] = NewBB;
7257 NewBlocks.push_back(Elt: NewBB);
7258 }
7259 remapInstructionsInBlocks(Blocks: NewBlocks, VMap);
7260 Builder.CreateBr(Dest: NewBlocks.front());
7261
7262 // The loop latch must have only one predecessor. Currently it is branched to
7263 // from both the 'then' and 'else' branches.
7264 L->getLoopLatch()->splitBasicBlockBefore(I: L->getLoopLatch()->begin(),
7265 BBName: NamePrefix + ".pre_latch");
7266
7267 // Ensure that the then block is added to the loop so we add the attributes in
7268 // the next step
7269 L->addBasicBlockToLoop(NewBB: ThenBlock, LI);
7270}
7271
7272unsigned
7273OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
7274 const StringMap<bool> &Features) {
7275 if (TargetTriple.isX86()) {
7276 if (Features.lookup(Key: "avx512f"))
7277 return 512;
7278 else if (Features.lookup(Key: "avx"))
7279 return 256;
7280 return 128;
7281 }
7282 if (TargetTriple.isPPC())
7283 return 128;
7284 if (TargetTriple.isWasm())
7285 return 128;
7286 return 0;
7287}
7288
7289void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
7290 MapVector<Value *, Value *> AlignedVars,
7291 Value *IfCond, OrderKind Order,
7292 ConstantInt *Simdlen, ConstantInt *Safelen) {
7293 LLVMContext &Ctx = Builder.getContext();
7294
7295 Function *F = CanonicalLoop->getFunction();
7296
7297 // Blocks must have terminators.
7298 // FIXME: Don't run analyses on incomplete/invalid IR.
7299 SmallVector<Instruction *> UIs;
7300 for (BasicBlock &BB : *F)
7301 if (!BB.hasTerminator())
7302 UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));
7303
7304 // TODO: We should not rely on pass manager. Currently we use pass manager
7305 // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
7306 // object. We should have a method which returns all blocks between
7307 // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
7308 FunctionAnalysisManager FAM;
7309 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
7310 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
7311 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
7312
7313 LoopAnalysis LIA;
7314 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
7315
7316 for (Instruction *I : UIs)
7317 I->eraseFromParent();
7318
7319 Loop *L = LI.getLoopFor(BB: CanonicalLoop->getHeader());
7320 if (AlignedVars.size()) {
7321 InsertPointTy IP = Builder.saveIP();
7322 for (auto &AlignedItem : AlignedVars) {
7323 Value *AlignedPtr = AlignedItem.first;
7324 Value *Alignment = AlignedItem.second;
7325 Instruction *loadInst = dyn_cast<Instruction>(Val: AlignedPtr);
7326 Builder.SetInsertPoint(loadInst->getNextNode());
7327 Builder.CreateAlignmentAssumption(DL: F->getDataLayout(), PtrValue: AlignedPtr,
7328 Alignment);
7329 }
7330 Builder.restoreIP(IP);
7331 }
7332
7333 if (IfCond) {
7334 ValueToValueMapTy VMap;
7335 createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, NamePrefix: "simd");
7336 }
7337
7338 SmallPtrSet<BasicBlock *, 8> Reachable;
7339
7340 // Get the basic blocks from the loop in which memref instructions
7341 // can be found.
7342 // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
7343 // preferably without running any passes.
7344 for (BasicBlock *Block : L->getBlocks()) {
7345 if (Block == CanonicalLoop->getCond() ||
7346 Block == CanonicalLoop->getHeader())
7347 continue;
7348 Reachable.insert(Ptr: Block);
7349 }
7350
7351 SmallVector<Metadata *> LoopMDList;
7352
7353 // In presence of finite 'safelen', it may be unsafe to mark all
7354 // the memory instructions parallel, because loop-carried
7355 // dependences of 'safelen' iterations are possible.
7356 // If clause order(concurrent) is specified then the memory instructions
7357 // are marked parallel even if 'safelen' is finite.
7358 if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
7359 applyParallelAccessesMetadata(CLI: CanonicalLoop, Ctx, Loop: L, LoopInfo&: LI, LoopMDList);
7360
7361 // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
7362 // versions so we can't add the loop attributes in that case.
7363 if (IfCond) {
7364 // we can still add llvm.loop.parallel_access
7365 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
7366 return;
7367 }
7368
7369 // Use the above access group metadata to create loop level
7370 // metadata, which should be distinct for each loop.
7371 ConstantAsMetadata *BoolConst =
7372 ConstantAsMetadata::get(C: ConstantInt::getTrue(Ty: Type::getInt1Ty(C&: Ctx)));
7373 LoopMDList.push_back(Elt: MDNode::get(
7374 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.enable"), BoolConst}));
7375
7376 if (Simdlen || Safelen) {
7377 // If both simdlen and safelen clauses are specified, the value of the
7378 // simdlen parameter must be less than or equal to the value of the safelen
7379 // parameter. Therefore, use safelen only in the absence of simdlen.
7380 ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
7381 LoopMDList.push_back(
7382 Elt: MDNode::get(Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.vectorize.width"),
7383 ConstantAsMetadata::get(C: VectorizeWidth)}));
7384 }
7385
7386 addLoopMetadata(Loop: CanonicalLoop, Properties: LoopMDList);
7387}
7388
7389/// Create the TargetMachine object to query the backend for optimization
7390/// preferences.
7391///
7392/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
7393/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
7394/// needed for the LLVM pass pipline. We use some default options to avoid
7395/// having to pass too many settings from the frontend that probably do not
7396/// matter.
7397///
7398/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
7399/// method. If we are going to use TargetMachine for more purposes, especially
7400/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
7401/// might become be worth requiring front-ends to pass on their TargetMachine,
7402/// or at least cache it between methods. Note that while fontends such as Clang
7403/// have just a single main TargetMachine per translation unit, "target-cpu" and
7404/// "target-features" that determine the TargetMachine are per-function and can
7405/// be overrided using __attribute__((target("OPTIONS"))).
7406static std::unique_ptr<TargetMachine>
7407createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
7408 Module *M = F->getParent();
7409
7410 StringRef CPU = F->getFnAttribute(Kind: "target-cpu").getValueAsString();
7411 StringRef Features = F->getFnAttribute(Kind: "target-features").getValueAsString();
7412 const llvm::Triple &Triple = M->getTargetTriple();
7413
7414 std::string Error;
7415 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(TheTriple: Triple, Error);
7416 if (!TheTarget)
7417 return {};
7418
7419 llvm::TargetOptions Options;
7420 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
7421 TT: Triple, CPU, Features, Options, /*RelocModel=*/RM: std::nullopt,
7422 /*CodeModel=*/CM: std::nullopt, OL: OptLevel));
7423}
7424
7425/// Heuristically determine the best-performant unroll factor for \p CLI. This
7426/// depends on the target processor. We are re-using the same heuristics as the
7427/// LoopUnrollPass.
7428static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
7429 Function *F = CLI->getFunction();
7430
7431 // Assume the user requests the most aggressive unrolling, even if the rest of
7432 // the code is optimized using a lower setting.
7433 CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
7434 std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
7435
7436 // Blocks must have terminators.
7437 // FIXME: Don't run analyses on incomplete/invalid IR.
7438 SmallVector<Instruction *> UIs;
7439 for (BasicBlock &BB : *F)
7440 if (!BB.hasTerminator())
7441 UIs.push_back(Elt: new UnreachableInst(F->getContext(), &BB));
7442
7443 FunctionAnalysisManager FAM;
7444 FAM.registerPass(PassBuilder: []() { return TargetLibraryAnalysis(); });
7445 FAM.registerPass(PassBuilder: []() { return AssumptionAnalysis(); });
7446 FAM.registerPass(PassBuilder: []() { return DominatorTreeAnalysis(); });
7447 FAM.registerPass(PassBuilder: []() { return LoopAnalysis(); });
7448 FAM.registerPass(PassBuilder: []() { return ScalarEvolutionAnalysis(); });
7449 FAM.registerPass(PassBuilder: []() { return PassInstrumentationAnalysis(); });
7450 TargetIRAnalysis TIRA;
7451 if (TM)
7452 TIRA = TargetIRAnalysis(
7453 [&](const Function &F) { return TM->getTargetTransformInfo(F); });
7454 FAM.registerPass(PassBuilder: [&]() { return TIRA; });
7455
7456 TargetIRAnalysis::Result &&TTI = TIRA.run(F: *F, FAM);
7457 ScalarEvolutionAnalysis SEA;
7458 ScalarEvolution &&SE = SEA.run(F&: *F, AM&: FAM);
7459 DominatorTreeAnalysis DTA;
7460 DominatorTree &&DT = DTA.run(F&: *F, FAM);
7461 LoopAnalysis LIA;
7462 LoopInfo &&LI = LIA.run(F&: *F, AM&: FAM);
7463 AssumptionAnalysis ACT;
7464 AssumptionCache &&AC = ACT.run(F&: *F, FAM);
7465 OptimizationRemarkEmitter ORE{F};
7466
7467 for (Instruction *I : UIs)
7468 I->eraseFromParent();
7469
7470 Loop *L = LI.getLoopFor(BB: CLI->getHeader());
7471 assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
7472
7473 TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
7474 L, SE, TTI,
7475 /*BlockFrequencyInfo=*/BFI: nullptr,
7476 /*ProfileSummaryInfo=*/PSI: nullptr, ORE, OptLevel: static_cast<int>(OptLevel),
7477 /*UserThreshold=*/std::nullopt,
7478 /*UserCount=*/std::nullopt,
7479 /*UserAllowPartial=*/true,
7480 /*UserAllowRuntime=*/UserRuntime: true,
7481 /*UserUpperBound=*/std::nullopt,
7482 /*UserFullUnrollMaxCount=*/std::nullopt);
7483
7484 UP.Force = true;
7485
7486 // Account for additional optimizations taking place before the LoopUnrollPass
7487 // would unroll the loop.
7488 UP.Threshold *= UnrollThresholdFactor;
7489 UP.PartialThreshold *= UnrollThresholdFactor;
7490
7491 // Use normal unroll factors even if the rest of the code is optimized for
7492 // size.
7493 UP.OptSizeThreshold = UP.Threshold;
7494 UP.PartialOptSizeThreshold = UP.PartialThreshold;
7495
7496 LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
7497 << " Threshold=" << UP.Threshold << "\n"
7498 << " PartialThreshold=" << UP.PartialThreshold << "\n"
7499 << " OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
7500 << " PartialOptSizeThreshold="
7501 << UP.PartialOptSizeThreshold << "\n");
7502
7503 // Disable peeling.
7504 TargetTransformInfo::PeelingPreferences PP =
7505 gatherPeelingPreferences(L, SE, TTI,
7506 /*UserAllowPeeling=*/false,
7507 /*UserAllowProfileBasedPeeling=*/false,
7508 /*UnrollingSpecficValues=*/false);
7509
7510 SmallPtrSet<const Value *, 32> EphValues;
7511 CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);
7512
7513 // Assume that reads and writes to stack variables can be eliminated by
7514 // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
7515 // size.
7516 for (BasicBlock *BB : L->blocks()) {
7517 for (Instruction &I : *BB) {
7518 Value *Ptr;
7519 if (auto *Load = dyn_cast<LoadInst>(Val: &I)) {
7520 Ptr = Load->getPointerOperand();
7521 } else if (auto *Store = dyn_cast<StoreInst>(Val: &I)) {
7522 Ptr = Store->getPointerOperand();
7523 } else
7524 continue;
7525
7526 Ptr = Ptr->stripPointerCasts();
7527
7528 if (auto *Alloca = dyn_cast<AllocaInst>(Val: Ptr)) {
7529 if (Alloca->getParent() == &F->getEntryBlock())
7530 EphValues.insert(Ptr: &I);
7531 }
7532 }
7533 }
7534
7535 UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
7536
7537 // Loop is not unrollable if the loop contains certain instructions.
7538 if (!UCE.canUnroll()) {
7539 LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
7540 return 1;
7541 }
7542
7543 LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
7544 << "\n");
7545
7546 // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
7547 // be able to use it.
7548 int TripCount = 0;
7549 int MaxTripCount = 0;
7550 bool MaxOrZero = false;
7551 unsigned TripMultiple = 0;
7552
7553 computeUnrollCount(L, TTI, DT, LI: &LI, AC: &AC, SE, EphValues, ORE: &ORE, TripCount,
7554 MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP);
7555 unsigned Factor = UP.Count;
7556 LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
7557
7558 // This function returns 1 to signal to not unroll a loop.
7559 if (Factor == 0)
7560 return 1;
7561 return Factor;
7562}
7563
7564void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
7565 int32_t Factor,
7566 CanonicalLoopInfo **UnrolledCLI) {
7567 assert(Factor >= 0 && "Unroll factor must not be negative");
7568
7569 Function *F = Loop->getFunction();
7570 LLVMContext &Ctx = F->getContext();
7571
7572 // If the unrolled loop is not used for another loop-associated directive, it
7573 // is sufficient to add metadata for the LoopUnrollPass.
7574 if (!UnrolledCLI) {
7575 SmallVector<Metadata *, 2> LoopMetadata;
7576 LoopMetadata.push_back(
7577 Elt: MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")));
7578
7579 if (Factor >= 1) {
7580 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
7581 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
7582 LoopMetadata.push_back(Elt: MDNode::get(
7583 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst}));
7584 }
7585
7586 addLoopMetadata(Loop, Properties: LoopMetadata);
7587 return;
7588 }
7589
7590 // Heuristically determine the unroll factor.
7591 if (Factor == 0)
7592 Factor = computeHeuristicUnrollFactor(CLI: Loop);
7593
7594 // No change required with unroll factor 1.
7595 if (Factor == 1) {
7596 *UnrolledCLI = Loop;
7597 return;
7598 }
7599
7600 assert(Factor >= 2 &&
7601 "unrolling only makes sense with a factor of 2 or larger");
7602
7603 Type *IndVarTy = Loop->getIndVarType();
7604
7605 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
7606 // unroll the inner loop.
7607 Value *FactorVal =
7608 ConstantInt::get(Ty: IndVarTy, V: APInt(IndVarTy->getIntegerBitWidth(), Factor,
7609 /*isSigned=*/false));
7610 std::vector<CanonicalLoopInfo *> LoopNest =
7611 tileLoops(DL, Loops: {Loop}, TileSizes: {FactorVal});
7612 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
7613 *UnrolledCLI = LoopNest[0];
7614 CanonicalLoopInfo *InnerLoop = LoopNest[1];
7615
7616 // LoopUnrollPass can only fully unroll loops with constant trip count.
7617 // Unroll by the unroll factor with a fallback epilog for the remainder
7618 // iterations if necessary.
7619 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
7620 C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: APInt(32, Factor)));
7621 addLoopMetadata(
7622 Loop: InnerLoop,
7623 Properties: {MDNode::get(Context&: Ctx, MDs: MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.enable")),
7624 MDNode::get(
7625 Context&: Ctx, MDs: {MDString::get(Context&: Ctx, Str: "llvm.loop.unroll.count"), FactorConst})});
7626
7627#ifndef NDEBUG
7628 (*UnrolledCLI)->assertOK();
7629#endif
7630}
7631
7632OpenMPIRBuilder::InsertPointTy
7633OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
7634 llvm::Value *BufSize, llvm::Value *CpyBuf,
7635 llvm::Value *CpyFn, llvm::Value *DidIt) {
7636 if (!updateToLocation(Loc))
7637 return Loc.IP;
7638
7639 uint32_t SrcLocStrSize;
7640 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7641 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7642 Value *ThreadId = getOrCreateThreadID(Ident);
7643
7644 llvm::Value *DidItLD = Builder.CreateLoad(Ty: Builder.getInt32Ty(), Ptr: DidIt);
7645
7646 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
7647
7648 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_copyprivate);
7649 createRuntimeFunctionCall(Callee: Fn, Args);
7650
7651 return Builder.saveIP();
7652}
7653
7654OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
7655 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7656 FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
7657 ArrayRef<llvm::Function *> CPFuncs) {
7658
7659 if (!updateToLocation(Loc))
7660 return Loc.IP;
7661
7662 // If needed allocate and initialize `DidIt` with 0.
7663 // DidIt: flag variable: 1=single thread; 0=not single thread.
7664 llvm::Value *DidIt = nullptr;
7665 if (!CPVars.empty()) {
7666 DidIt = Builder.CreateAlloca(Ty: llvm::Type::getInt32Ty(C&: Builder.getContext()));
7667 Builder.CreateStore(Val: Builder.getInt32(C: 0), Ptr: DidIt);
7668 }
7669
7670 Directive OMPD = Directive::OMPD_single;
7671 uint32_t SrcLocStrSize;
7672 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7673 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7674 Value *ThreadId = getOrCreateThreadID(Ident);
7675 Value *Args[] = {Ident, ThreadId};
7676
7677 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_single);
7678 Instruction *EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7679
7680 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_single);
7681 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7682
7683 auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
7684 if (Error Err = FiniCB(IP))
7685 return Err;
7686
7687 // The thread that executes the single region must set `DidIt` to 1.
7688 // This is used by __kmpc_copyprivate, to know if the caller is the
7689 // single thread or not.
7690 if (DidIt)
7691 Builder.CreateStore(Val: Builder.getInt32(C: 1), Ptr: DidIt);
7692
7693 return Error::success();
7694 };
7695
7696 // generates the following:
7697 // if (__kmpc_single()) {
7698 // .... single region ...
7699 // __kmpc_end_single
7700 // }
7701 // __kmpc_copyprivate
7702 // __kmpc_barrier
7703
7704 InsertPointOrErrorTy AfterIP =
7705 EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB: FiniCBWrapper,
7706 /*Conditional*/ true,
7707 /*hasFinalize*/ HasFinalize: true);
7708 if (!AfterIP)
7709 return AfterIP.takeError();
7710
7711 if (DidIt) {
7712 for (size_t I = 0, E = CPVars.size(); I < E; ++I)
7713 // NOTE BufSize is currently unused, so just pass 0.
7714 createCopyPrivate(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
7715 /*BufSize=*/ConstantInt::get(Ty: Int64, V: 0), CpyBuf: CPVars[I],
7716 CpyFn: CPFuncs[I], DidIt);
7717 // NOTE __kmpc_copyprivate already inserts a barrier
7718 } else if (!IsNowait) {
7719 InsertPointOrErrorTy AfterIP =
7720 createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
7721 Kind: omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
7722 /* CheckCancelFlag */ false);
7723 if (!AfterIP)
7724 return AfterIP.takeError();
7725 }
7726 return Builder.saveIP();
7727}
7728
7729OpenMPIRBuilder::InsertPointOrErrorTy
7730OpenMPIRBuilder::createScope(const LocationDescription &Loc,
7731 BodyGenCallbackTy BodyGenCB,
7732 FinalizeCallbackTy FiniCB, bool IsNowait) {
7733
7734 if (!updateToLocation(Loc))
7735 return Loc.IP;
7736
7737 // All threads execute the scope body — no conditional entry.
7738 InsertPointOrErrorTy AfterIP = EmitOMPInlinedRegion(
7739 OMPD: Directive::OMPD_scope, /*EntryCall=*/nullptr, /*ExitCall=*/nullptr,
7740 BodyGenCB, FiniCB, /*Conditional=*/false, /*HasFinalize=*/true,
7741 /*IsCancellable=*/false);
7742 if (!AfterIP)
7743 return AfterIP.takeError();
7744
7745 Builder.restoreIP(IP: *AfterIP);
7746 if (!IsNowait) {
7747 AfterIP = createBarrier(Loc: LocationDescription(Builder.saveIP(), Loc.DL),
7748 Kind: omp::Directive::OMPD_unknown,
7749 /*ForceSimpleCall=*/false,
7750 /*CheckCancelFlag=*/false);
7751 if (!AfterIP)
7752 return AfterIP.takeError();
7753 }
7754 return Builder.saveIP();
7755}
7756
7757OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
7758 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7759 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
7760
7761 if (!updateToLocation(Loc))
7762 return Loc.IP;
7763
7764 Directive OMPD = Directive::OMPD_critical;
7765 uint32_t SrcLocStrSize;
7766 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7767 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7768 Value *ThreadId = getOrCreateThreadID(Ident);
7769 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
7770 Value *Args[] = {Ident, ThreadId, LockVar};
7771
7772 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(arr&: Args), std::end(arr&: Args));
7773 Function *RTFn = nullptr;
7774 if (HintInst) {
7775 // Add Hint to entry Args and create call
7776 EnterArgs.push_back(Elt: HintInst);
7777 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical_with_hint);
7778 } else {
7779 RTFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_critical);
7780 }
7781 Instruction *EntryCall = createRuntimeFunctionCall(Callee: RTFn, Args: EnterArgs);
7782
7783 Function *ExitRTLFn =
7784 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_critical);
7785 Instruction *ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7786
7787 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7788 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7789}
7790
7791OpenMPIRBuilder::InsertPointTy
7792OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
7793 InsertPointTy AllocaIP, unsigned NumLoops,
7794 ArrayRef<llvm::Value *> StoreValues,
7795 const Twine &Name, bool IsDependSource) {
7796 assert(
7797 llvm::all_of(StoreValues,
7798 [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
7799 "OpenMP runtime requires depend vec with i64 type");
7800
7801 if (!updateToLocation(Loc))
7802 return Loc.IP;
7803
7804 // Allocate space for vector and generate alloc instruction.
7805 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumLoops);
7806 Builder.restoreIP(IP: AllocaIP);
7807 AllocaInst *ArgsBase = Builder.CreateAlloca(Ty: ArrI64Ty, ArraySize: nullptr, Name);
7808 ArgsBase->setAlignment(Align(8));
7809 updateToLocation(Loc);
7810
7811 // Store the index value with offset in depend vector.
7812 for (unsigned I = 0; I < NumLoops; ++I) {
7813 Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
7814 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: I)});
7815 StoreInst *STInst = Builder.CreateStore(Val: StoreValues[I], Ptr: DependAddrGEPIter);
7816 STInst->setAlignment(Align(8));
7817 }
7818
7819 Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
7820 Ty: ArrI64Ty, Ptr: ArgsBase, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: 0)});
7821
7822 uint32_t SrcLocStrSize;
7823 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7824 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7825 Value *ThreadId = getOrCreateThreadID(Ident);
7826 Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
7827
7828 Function *RTLFn = nullptr;
7829 if (IsDependSource)
7830 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_post);
7831 else
7832 RTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_doacross_wait);
7833 createRuntimeFunctionCall(Callee: RTLFn, Args);
7834
7835 return Builder.saveIP();
7836}
7837
7838OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
7839 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
7840 FinalizeCallbackTy FiniCB, bool IsThreads) {
7841 if (!updateToLocation(Loc))
7842 return Loc.IP;
7843
7844 Directive OMPD = Directive::OMPD_ordered;
7845 Instruction *EntryCall = nullptr;
7846 Instruction *ExitCall = nullptr;
7847
7848 if (IsThreads) {
7849 uint32_t SrcLocStrSize;
7850 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
7851 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7852 Value *ThreadId = getOrCreateThreadID(Ident);
7853 Value *Args[] = {Ident, ThreadId};
7854
7855 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_ordered);
7856 EntryCall = createRuntimeFunctionCall(Callee: EntryRTLFn, Args);
7857
7858 Function *ExitRTLFn =
7859 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_end_ordered);
7860 ExitCall = createRuntimeFunctionCall(Callee: ExitRTLFn, Args);
7861 }
7862
7863 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
7864 /*Conditional*/ false, /*hasFinalize*/ HasFinalize: true);
7865}
7866
7867OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
7868 Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
7869 BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
7870 bool HasFinalize, bool IsCancellable) {
7871
7872 if (HasFinalize)
7873 FinalizationStack.push_back(Elt: {FiniCB, OMPD, IsCancellable});
7874
7875 // Create inlined region's entry and body blocks, in preparation
7876 // for conditional creation
7877 BasicBlock *EntryBB = Builder.GetInsertBlock();
7878 Instruction *SplitPos = EntryBB->getTerminatorOrNull();
7879 if (!isa_and_nonnull<UncondBrInst, CondBrInst>(Val: SplitPos))
7880 SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
7881 BasicBlock *ExitBB = EntryBB->splitBasicBlock(I: SplitPos, BBName: "omp_region.end");
7882 BasicBlock *FiniBB =
7883 EntryBB->splitBasicBlock(I: EntryBB->getTerminator(), BBName: "omp_region.finalize");
7884
7885 Builder.SetInsertPoint(EntryBB->getTerminator());
7886 emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
7887
7888 // generate body
7889 if (Error Err =
7890 BodyGenCB(/* AllocaIP */ InsertPointTy(),
7891 /* CodeGenIP */ Builder.saveIP(), /* DeallocBlocks */ {}))
7892 return Err;
7893
7894 // emit exit call and do any needed finalization.
7895 auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
7896 assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
7897 FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
7898 "Unexpected control flow graph state!!");
7899 InsertPointOrErrorTy AfterIP =
7900 emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
7901 if (!AfterIP)
7902 return AfterIP.takeError();
7903
7904 // If we are skipping the region of a non conditional, remove the exit
7905 // block, and clear the builder's insertion point.
7906 assert(SplitPos->getParent() == ExitBB &&
7907 "Unexpected Insertion point location!");
7908 auto merged = MergeBlockIntoPredecessor(BB: ExitBB);
7909 BasicBlock *ExitPredBB = SplitPos->getParent();
7910 auto InsertBB = merged ? ExitPredBB : ExitBB;
7911 if (!isa_and_nonnull<UncondBrInst, CondBrInst>(Val: SplitPos))
7912 SplitPos->eraseFromParent();
7913 Builder.SetInsertPoint(InsertBB);
7914
7915 return Builder.saveIP();
7916}
7917
7918OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
7919 Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
7920 // if nothing to do, Return current insertion point.
7921 if (!Conditional || !EntryCall)
7922 return Builder.saveIP();
7923
7924 BasicBlock *EntryBB = Builder.GetInsertBlock();
7925 Value *CallBool = Builder.CreateIsNotNull(Arg: EntryCall);
7926 auto *ThenBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp_region.body");
7927 auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
7928
7929 // Emit thenBB and set the Builder's insertion point there for
7930 // body generation next. Place the block after the current block.
7931 Function *CurFn = EntryBB->getParent();
7932 CurFn->insert(Position: std::next(x: EntryBB->getIterator()), BB: ThenBB);
7933
7934 // Move Entry branch to end of ThenBB, and replace with conditional
7935 // branch (If-stmt)
7936 Instruction *EntryBBTI = EntryBB->getTerminator();
7937 Builder.CreateCondBr(Cond: CallBool, True: ThenBB, False: ExitBB);
7938 EntryBBTI->removeFromParent();
7939 Builder.SetInsertPoint(UI);
7940 Builder.Insert(I: EntryBBTI);
7941 UI->eraseFromParent();
7942 Builder.SetInsertPoint(ThenBB->getTerminator());
7943
7944 // return an insertion point to ExitBB.
7945 return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
7946}
7947
7948OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
7949 omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
7950 bool HasFinalize) {
7951
7952 Builder.restoreIP(IP: FinIP);
7953
7954 // If there is finalization to do, emit it before the exit call
7955 if (HasFinalize) {
7956 assert(!FinalizationStack.empty() &&
7957 "Unexpected finalization stack state!");
7958
7959 FinalizationInfo Fi = FinalizationStack.pop_back_val();
7960 assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
7961
7962 if (Error Err = Fi.mergeFiniBB(Builder, OtherFiniBB: FinIP.getBlock()))
7963 return std::move(Err);
7964
7965 // Exit condition: insertion point is before the terminator of the new Fini
7966 // block
7967 Builder.SetInsertPoint(FinIP.getBlock()->getTerminator());
7968 }
7969
7970 if (!ExitCall)
7971 return Builder.saveIP();
7972
7973 // place the Exitcall as last instruction before Finalization block terminator
7974 ExitCall->removeFromParent();
7975 Builder.Insert(I: ExitCall);
7976
7977 return IRBuilder<>::InsertPoint(ExitCall->getParent(),
7978 ExitCall->getIterator());
7979}
7980
7981OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
7982 InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
7983 llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
7984 if (!IP.isSet())
7985 return IP;
7986
7987 IRBuilder<>::InsertPointGuard IPG(Builder);
7988
7989 // creates the following CFG structure
7990 // OMP_Entry : (MasterAddr != PrivateAddr)?
7991 // F T
7992 // | \
7993 // | copin.not.master
7994 // | /
7995 // v /
7996 // copyin.not.master.end
7997 // |
7998 // v
7999 // OMP.Entry.Next
8000
8001 BasicBlock *OMP_Entry = IP.getBlock();
8002 Function *CurFn = OMP_Entry->getParent();
8003 BasicBlock *CopyBegin =
8004 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master", Parent: CurFn);
8005 BasicBlock *CopyEnd = nullptr;
8006
8007 // If entry block is terminated, split to preserve the branch to following
8008 // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
8009 if (isa_and_nonnull<CondBrInst>(Val: OMP_Entry->getTerminatorOrNull())) {
8010 CopyEnd = OMP_Entry->splitBasicBlock(I: OMP_Entry->getTerminator(),
8011 BBName: "copyin.not.master.end");
8012 OMP_Entry->getTerminator()->eraseFromParent();
8013 } else {
8014 CopyEnd =
8015 BasicBlock::Create(Context&: M.getContext(), Name: "copyin.not.master.end", Parent: CurFn);
8016 }
8017
8018 Builder.SetInsertPoint(OMP_Entry);
8019 Value *MasterPtr = Builder.CreatePtrToInt(V: MasterAddr, DestTy: IntPtrTy);
8020 Value *PrivatePtr = Builder.CreatePtrToInt(V: PrivateAddr, DestTy: IntPtrTy);
8021 Value *cmp = Builder.CreateICmpNE(LHS: MasterPtr, RHS: PrivatePtr);
8022 Builder.CreateCondBr(Cond: cmp, True: CopyBegin, False: CopyEnd);
8023
8024 Builder.SetInsertPoint(CopyBegin);
8025 if (BranchtoEnd)
8026 Builder.SetInsertPoint(Builder.CreateBr(Dest: CopyEnd));
8027
8028 return Builder.saveIP();
8029}
8030
8031CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
8032 Value *Size, Value *Allocator,
8033 std::string Name) {
8034 IRBuilder<>::InsertPointGuard IPG(Builder);
8035 if (!updateToLocation(Loc))
8036 return nullptr;
8037
8038 uint32_t SrcLocStrSize;
8039 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8040 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8041 Value *ThreadId = getOrCreateThreadID(Ident);
8042 Value *Args[] = {ThreadId, Size, Allocator};
8043
8044 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc);
8045
8046 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
8047}
8048
8049CallInst *OpenMPIRBuilder::createOMPAlignedAlloc(const LocationDescription &Loc,
8050 Value *Align, Value *Size,
8051 Value *Allocator,
8052 std::string Name) {
8053 IRBuilder<>::InsertPointGuard IPG(Builder);
8054 if (!updateToLocation(Loc))
8055 return nullptr;
8056
8057 uint32_t SrcLocStrSize;
8058 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8059 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8060 Value *ThreadId = getOrCreateThreadID(Ident);
8061 Value *Args[] = {ThreadId, Align, Size, Allocator};
8062
8063 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_aligned_alloc);
8064
8065 return Builder.CreateCall(Callee: Fn, Args, Name);
8066}
8067
8068CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
8069 Value *Addr, Value *Allocator,
8070 std::string Name) {
8071 IRBuilder<>::InsertPointGuard IPG(Builder);
8072 if (!updateToLocation(Loc))
8073 return nullptr;
8074
8075 uint32_t SrcLocStrSize;
8076 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8077 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8078 Value *ThreadId = getOrCreateThreadID(Ident);
8079 Value *Args[] = {ThreadId, Addr, Allocator};
8080 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free);
8081 return createRuntimeFunctionCall(Callee: Fn, Args, Name);
8082}
8083
8084CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
8085 Value *Size,
8086 const Twine &Name) {
8087 IRBuilder<>::InsertPointGuard IPG(Builder);
8088 updateToLocation(Loc);
8089
8090 Value *Args[] = {Size};
8091 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_alloc_shared);
8092 CallInst *Call = Builder.CreateCall(Callee: Fn, Args, Name);
8093 Call->addRetAttr(Attr: Attribute::getWithAlignment(
8094 Context&: M.getContext(), Alignment: M.getDataLayout().getPrefTypeAlign(Ty: Int64)));
8095 return Call;
8096}
8097
8098CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
8099 Type *VarType,
8100 const Twine &Name) {
8101 return createOMPAllocShared(
8102 Loc, Size: Builder.getInt64(C: M.getDataLayout().getTypeAllocSize(Ty: VarType)), Name);
8103}
8104
8105CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
8106 Value *Addr, Value *Size,
8107 const Twine &Name) {
8108 IRBuilder<>::InsertPointGuard IPG(Builder);
8109 updateToLocation(Loc);
8110
8111 Value *Args[] = {Addr, Size};
8112 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_free_shared);
8113 return Builder.CreateCall(Callee: Fn, Args, Name);
8114}
8115
8116CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
8117 Value *Addr, Type *VarType,
8118 const Twine &Name) {
8119 return createOMPFreeShared(
8120 Loc, Addr, Size: Builder.getInt64(C: M.getDataLayout().getTypeAllocSize(Ty: VarType)),
8121 Name);
8122}
8123
8124CallInst *OpenMPIRBuilder::createOMPInteropInit(
8125 const LocationDescription &Loc, Value *InteropVar,
8126 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
8127 Value *DependenceAddress, bool HaveNowaitClause) {
8128 IRBuilder<>::InsertPointGuard IPG(Builder);
8129 updateToLocation(Loc);
8130
8131 uint32_t SrcLocStrSize;
8132 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8133 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8134 Value *ThreadId = getOrCreateThreadID(Ident);
8135 if (Device == nullptr)
8136 Device = Constant::getAllOnesValue(Ty: Int32);
8137 Constant *InteropTypeVal = ConstantInt::get(Ty: Int32, V: (int)InteropType);
8138 if (NumDependences == nullptr) {
8139 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
8140 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
8141 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
8142 }
8143 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
8144 Value *Args[] = {
8145 Ident, ThreadId, InteropVar, InteropTypeVal,
8146 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
8147
8148 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_init);
8149
8150 return createRuntimeFunctionCall(Callee: Fn, Args);
8151}
8152
8153CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
8154 const LocationDescription &Loc, Value *InteropVar, Value *Device,
8155 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
8156 IRBuilder<>::InsertPointGuard IPG(Builder);
8157 updateToLocation(Loc);
8158
8159 uint32_t SrcLocStrSize;
8160 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8161 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8162 Value *ThreadId = getOrCreateThreadID(Ident);
8163 if (Device == nullptr)
8164 Device = Constant::getAllOnesValue(Ty: Int32);
8165 if (NumDependences == nullptr) {
8166 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
8167 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
8168 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
8169 }
8170 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
8171 Value *Args[] = {
8172 Ident, ThreadId, InteropVar, Device,
8173 NumDependences, DependenceAddress, HaveNowaitClauseVal};
8174
8175 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_destroy);
8176
8177 return createRuntimeFunctionCall(Callee: Fn, Args);
8178}
8179
8180CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
8181 Value *InteropVar, Value *Device,
8182 Value *NumDependences,
8183 Value *DependenceAddress,
8184 bool HaveNowaitClause) {
8185 IRBuilder<>::InsertPointGuard IPG(Builder);
8186 updateToLocation(Loc);
8187 uint32_t SrcLocStrSize;
8188 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8189 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8190 Value *ThreadId = getOrCreateThreadID(Ident);
8191 if (Device == nullptr)
8192 Device = Constant::getAllOnesValue(Ty: Int32);
8193 if (NumDependences == nullptr) {
8194 NumDependences = ConstantInt::get(Ty: Int32, V: 0);
8195 PointerType *PointerTypeVar = PointerType::getUnqual(C&: M.getContext());
8196 DependenceAddress = ConstantPointerNull::get(T: PointerTypeVar);
8197 }
8198 Value *HaveNowaitClauseVal = ConstantInt::get(Ty: Int32, V: HaveNowaitClause);
8199 Value *Args[] = {
8200 Ident, ThreadId, InteropVar, Device,
8201 NumDependences, DependenceAddress, HaveNowaitClauseVal};
8202
8203 Function *Fn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___tgt_interop_use);
8204
8205 return createRuntimeFunctionCall(Callee: Fn, Args);
8206}
8207
8208CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
8209 const LocationDescription &Loc, llvm::Value *Pointer,
8210 llvm::ConstantInt *Size, const llvm::Twine &Name) {
8211 IRBuilder<>::InsertPointGuard IPG(Builder);
8212 updateToLocation(Loc);
8213
8214 uint32_t SrcLocStrSize;
8215 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8216 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8217 Value *ThreadId = getOrCreateThreadID(Ident);
8218 Constant *ThreadPrivateCache =
8219 getOrCreateInternalVariable(Ty: Int8PtrPtr, Name: Name.str());
8220 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
8221
8222 Function *Fn =
8223 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_threadprivate_cached);
8224
8225 return createRuntimeFunctionCall(Callee: Fn, Args);
8226}
8227
// Emit the device-kernel prologue.
//
// Builds the "<kernel>_dynamic_environment" and "<kernel>_kernel_environment"
// globals for the current kernel, then calls __kmpc_target_init with the
// kernel environment and the wrapper's last argument. The result selects
// between two blocks:
//   - "user_code.entry", when the call returns -1;
//   - "worker.exit", which returns void, otherwise.
//
// Also records team/thread bounds on the kernel function through
// writeTeamsForKernel / writeThreadBoundsForKernel.
//
// Returns the first insertion point of "user_code.entry", or Loc.IP unchanged
// when Loc has no valid insertion point.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
    const LocationDescription &Loc,
    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
         "expected num_threads and num_teams to be specified");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The raw execution-mode flags are stored in the "IsSPMD" field.
  Constant *IsSPMDVal = ConstantInt::getSigned(Ty: Int8, V: Attrs.ExecFlags);
  // The generic state machine is requested for every mode except the two SPMD
  // variants.
  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
      Ty: Int8, V: Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD &&
                 Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP);
  Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Ty: Int8, V: true);
  Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Ty: Int16, V: 0);

  // The function we are emitting into may be a "<kernel>_debug__" wrapper; its
  // arguments are used below, but names and attributes refer to the real
  // kernel.
  Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
  Function *Kernel = DebugKernelWrapper;

  // We need to strip the debug prefix to get the correct kernel name.
  // (Despite the variable name, the marker is matched as a suffix.)
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(Suffix: DebugPrefix)) {
    KernelName = KernelName.drop_back(N: DebugPrefix.length());
    Kernel = M.getFunction(Name: KernelName);
    assert(Kernel && "Expected the real kernel to exist");
  }

  // Manifest the launch configuration in the metadata matching the kernel
  // environment.
  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
    writeTeamsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinTeams, UB: Attrs.MaxTeams.front());

  // If MaxThreads is not set and needs adjustment, select the maximum between
  // the default workgroup size and the MinThreads value.
  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
  if (MaxThreadsVal < 0 && UseDefaultMaxThreads) {
    if (hasGridValue(T)) {
      MaxThreadsVal =
          std::max(a: int32_t(getGridValue(T, Kernel).GV_Default_WG_Size),
                   b: Attrs.MinThreads);
    } else {
      MaxThreadsVal = Attrs.MinThreads;
    }
  }

  // Only a positive thread bound is written to the kernel.
  if (MaxThreadsVal > 0)
    writeThreadBoundsForKernel(T, Kernel&: *Kernel, LB: Attrs.MinThreads, UB: MaxThreadsVal);

  Constant *MinThreads = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinThreads);
  Constant *MaxThreads = ConstantInt::getSigned(Ty: Int32, V: MaxThreadsVal);
  Constant *MinTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MinTeams);
  Constant *MaxTeams = ConstantInt::getSigned(Ty: Int32, V: Attrs.MaxTeams.front());
  Constant *ReductionDataSize =
      ConstantInt::getSigned(Ty: Int32, V: Attrs.ReductionDataSize);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_init);
  const DataLayout &DL = Fn->getDataLayout();

  // Mutable, weak_odr, protected global holding the dynamic environment
  // (currently only the debug indentation level).
  Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
  Constant *DynamicEnvironmentInitializer =
      ConstantStruct::get(T: DynamicEnvironment, V: {DebugIndentionLevelVal});
  GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
      M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
      DynamicEnvironmentInitializer, DynamicEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  // Address-space cast only when the global's pointer type differs from the
  // expected DynamicEnvironmentPtr type. (This local shadows the struct type
  // member of the same name used above.)
  Constant *DynamicEnvironment =
      DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
          ? DynamicEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: DynamicEnvironmentGV,
                                           Ty: DynamicEnvironmentPtr);

  // Field order here defines the ConfigurationEnvironment layout; index 7
  // (ReductionDataSize) is later patched by createTargetDeinit.
  Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
      T: ConfigurationEnvironment, V: {
                                    UseGenericStateMachineVal,
                                    MayUseNestedParallelismVal,
                                    IsSPMDVal,
                                    MinThreads,
                                    MaxThreads,
                                    MinTeams,
                                    MaxTeams,
                                    ReductionDataSize,
                                });
  Constant *KernelEnvironmentInitializer = ConstantStruct::get(
      T: KernelEnvironment, V: {
                             ConfigurationEnvironmentInitializer,
                             Ident,
                             DynamicEnvironment,
                         });
  std::string KernelEnvironmentName =
      (KernelName + "_kernel_environment").str();
  // Constant, weak_odr, protected global; createTargetDeinit looks it up by
  // this name.
  GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
      M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
      KernelEnvironmentInitializer, KernelEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *KernelEnvironment =
      KernelEnvironmentGV->getType() == KernelEnvironmentPtr
          ? KernelEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(C: KernelEnvironmentGV,
                                           Ty: KernelEnvironmentPtr);
  // The launch environment is the last argument of the function being emitted
  // into; cast it to __kmpc_target_init's second parameter type if needed.
  Value *KernelLaunchEnvironment =
      DebugKernelWrapper->getArg(i: DebugKernelWrapper->arg_size() - 1);
  Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(i: 1);
  KernelLaunchEnvironment =
      KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
          ? KernelLaunchEnvironment
          : Builder.CreateAddrSpaceCast(V: KernelLaunchEnvironment,
                                        DestTy: KernelLaunchEnvParamTy);
  CallInst *ThreadKind = createRuntimeFunctionCall(
      Callee: Fn, Args: {KernelEnvironment, KernelLaunchEnvironment});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      LHS: ThreadKind, RHS: Constant::getAllOnesValue(Ty: ThreadKind->getType()),
      Name: "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  // A temporary `unreachable` marks the split point: everything from it onward
  // moves into "user_code.entry". Both it and the branch that splitBasicBlock
  // adds to CheckBB are erased once the conditional branch is in place.
  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(I: UI, BBName: "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      Context&: CheckBB->getContext(), Name: "worker.exit", Parent: CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  // After the split, UI->getParent() is UserCodeEntryBB.
  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(Cond: ExecUserCode, True: UI->getParent(), False: WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
8379
8380void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
8381 int32_t TeamsReductionDataSize) {
8382 if (!updateToLocation(Loc))
8383 return;
8384
8385 Function *Fn = getOrCreateRuntimeFunctionPtr(
8386 FnID: omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
8387
8388 createRuntimeFunctionCall(Callee: Fn, Args: {});
8389
8390 if (!TeamsReductionDataSize)
8391 return;
8392
8393 Function *Kernel = Builder.GetInsertBlock()->getParent();
8394 // We need to strip the debug prefix to get the correct kernel name.
8395 StringRef KernelName = Kernel->getName();
8396 const std::string DebugPrefix = "_debug__";
8397 if (KernelName.ends_with(Suffix: DebugPrefix))
8398 KernelName = KernelName.drop_back(N: DebugPrefix.length());
8399 auto *KernelEnvironmentGV =
8400 M.getNamedGlobal(Name: (KernelName + "_kernel_environment").str());
8401 assert(KernelEnvironmentGV && "Expected kernel environment global\n");
8402 auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
8403 auto *NewInitializer = ConstantFoldInsertValueInstruction(
8404 Agg: KernelEnvironmentInitializer,
8405 Val: ConstantInt::get(Ty: Int32, V: TeamsReductionDataSize), Idxs: {0, 7});
8406 KernelEnvironmentGV->setInitializer(NewInitializer);
8407}
8408
8409static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
8410 bool Min) {
8411 if (Kernel.hasFnAttribute(Kind: Name)) {
8412 int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Kind: Name);
8413 Value = Min ? std::min(a: OldLimit, b: Value) : std::max(a: OldLimit, b: Value);
8414 }
8415 Kernel.addFnAttr(Kind: Name, Val: llvm::utostr(X: Value));
8416}
8417
8418std::pair<int32_t, int32_t>
8419OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
8420 int32_t ThreadLimit =
8421 Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_thread_limit");
8422
8423 if (T.isAMDGPU()) {
8424 const auto &Attr = Kernel.getFnAttribute(Kind: "amdgpu-flat-work-group-size");
8425 if (!Attr.isValid() || !Attr.isStringAttribute())
8426 return {0, ThreadLimit};
8427 auto [LBStr, UBStr] = Attr.getValueAsString().split(Separator: ',');
8428 int32_t LB, UB;
8429 if (!llvm::to_integer(S: UBStr, Num&: UB, Base: 10))
8430 return {0, ThreadLimit};
8431 UB = ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB;
8432 if (!llvm::to_integer(S: LBStr, Num&: LB, Base: 10))
8433 return {0, UB};
8434 return {LB, UB};
8435 }
8436
8437 if (Kernel.hasFnAttribute(Kind: NVVMAttr::MaxNTID)) {
8438 int32_t UB = Kernel.getFnAttributeAsParsedInteger(Kind: NVVMAttr::MaxNTID);
8439 return {0, ThreadLimit ? std::min(a: ThreadLimit, b: UB) : UB};
8440 }
8441 return {0, ThreadLimit};
8442}
8443
8444void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
8445 Function &Kernel, int32_t LB,
8446 int32_t UB) {
8447 Kernel.addFnAttr(Kind: "omp_target_thread_limit", Val: std::to_string(val: UB));
8448
8449 if (T.isAMDGPU()) {
8450 Kernel.addFnAttr(Kind: "amdgpu-flat-work-group-size",
8451 Val: llvm::utostr(X: LB) + "," + llvm::utostr(X: UB));
8452 return;
8453 }
8454
8455 updateNVPTXAttr(Kernel, Name: NVVMAttr::MaxNTID, Value: UB, Min: true);
8456}
8457
8458std::pair<int32_t, int32_t>
8459OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
8460 // TODO: Read from backend annotations if available.
8461 return {0, Kernel.getFnAttributeAsParsedInteger(Kind: "omp_target_num_teams")};
8462}
8463
8464void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
8465 int32_t LB, int32_t UB) {
8466 if (UB > 0) {
8467 if (T.isNVPTX())
8468 Kernel.addFnAttr(Kind: NVVMAttr::MaxClusterRank, Val: llvm::utostr(X: UB));
8469 if (T.isAMDGPU())
8470 Kernel.addFnAttr(Kind: "amdgpu-max-num-workgroups", Val: llvm::utostr(X: UB) + ",1,1");
8471 }
8472
8473 Kernel.addFnAttr(Kind: "omp_target_num_teams", Val: std::to_string(val: LB));
8474}
8475
8476void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
8477 Function *OutlinedFn) {
8478 if (Config.isTargetDevice()) {
8479 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
8480 // TODO: Determine if DSO local can be set to true.
8481 OutlinedFn->setDSOLocal(false);
8482 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
8483 if (T.isAMDGCN())
8484 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
8485 else if (T.isNVPTX())
8486 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
8487 else if (T.isSPIRV())
8488 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
8489 }
8490}
8491
8492Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
8493 StringRef EntryFnIDName) {
8494 if (Config.isTargetDevice()) {
8495 assert(OutlinedFn && "The outlined function must exist if embedded");
8496 return OutlinedFn;
8497 }
8498
8499 return new GlobalVariable(
8500 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
8501 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnIDName);
8502}
8503
8504Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
8505 StringRef EntryFnName) {
8506 if (OutlinedFn)
8507 return OutlinedFn;
8508
8509 assert(!M.getGlobalVariable(EntryFnName, true) &&
8510 "Named kernel already exists?");
8511 return new GlobalVariable(
8512 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
8513 Constant::getNullValue(Ty: Builder.getInt8Ty()), EntryFnName);
8514}
8515
8516Error OpenMPIRBuilder::emitTargetRegionFunction(
8517 TargetRegionEntryInfo &EntryInfo,
8518 FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
8519 Function *&OutlinedFn, Constant *&OutlinedFnID) {
8520
8521 SmallString<64> EntryFnName;
8522 OffloadInfoManager.getTargetRegionEntryFnName(Name&: EntryFnName, EntryInfo);
8523
8524 if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
8525 Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
8526 if (!CBResult)
8527 return CBResult.takeError();
8528 OutlinedFn = *CBResult;
8529 } else {
8530 OutlinedFn = nullptr;
8531 }
8532
8533 // If this target outline function is not an offload entry, we don't need to
8534 // register it. This may be in the case of a false if clause, or if there are
8535 // no OpenMP targets.
8536 if (!IsOffloadEntry)
8537 return Error::success();
8538
8539 std::string EntryFnIDName =
8540 Config.isTargetDevice()
8541 ? std::string(EntryFnName)
8542 : createPlatformSpecificName(Parts: {EntryFnName, "region_id"});
8543
8544 OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFunction: OutlinedFn,
8545 EntryFnName, EntryFnIDName);
8546 return Error::success();
8547}
8548
8549Constant *OpenMPIRBuilder::registerTargetRegionFunction(
8550 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
8551 StringRef EntryFnName, StringRef EntryFnIDName) {
8552 if (OutlinedFn)
8553 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
8554 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
8555 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
8556 OffloadInfoManager.registerTargetRegionEntryInfo(
8557 EntryInfo, Addr: EntryAddr, ID: OutlinedFnID,
8558 Flags: OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
8559 return OutlinedFnID;
8560}
8561
// Emit an OpenMP data-mapping construct.
//
// With a BodyGenCB, this is a `target data` region:
//   1. __tgt_target_data_begin_mapper;
//   2. the body;
//   3. __tgt_target_data_end_mapper.
// If IfCond is given, the begin and end calls are each routed through
// emitIfClause. BodyGenCB is invoked with BodyGenTy::Priv (after the begin
// call), BodyGenTy::DupNoPriv (else branch of IfCond) and BodyGenTy::NoPriv
// (between begin and end).
//
// Without a BodyGenCB, this is a standalone directive: only the runtime
// function *MapperFunc is called. When Info.HasNoWait is set, that call is
// wrapped in a target task via emitTargetTask.
//
// Returns the builder's insertion point after emission, an empty insertion
// point if Loc is invalid, or an error propagated from a callback.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, ArrayRef<BasicBlock *> DeallocBlocks,
    Value *DeviceID, Value *IfCond, TargetDataInfo &Info,
    GenMapInfoCallbackTy GenMapInfoCB, CustomMapperCallbackTy CustomMapperCB,
    omp::RuntimeFunction *MapperFunc,
    function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                      BodyGenTy BodyGenType)>
        BodyGenCB,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(IP: CodeGenIP);

  bool IsStandAlone = !BodyGenCB;
  // Set by BeginThenGen and read by EndThenGen; EndThenGen is only emitted
  // after BeginThenGen has run.
  MapInfosTy *MapInfo;
  // Generate the code for the opening of the data environment. Capture all the
  // arguments of the runtime call by reference because they are used in the
  // closing of the region.
  auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                          ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
    MapInfo = &GenMapInfoCB(Builder.saveIP());
    if (Error Err = emitOffloadingArrays(
            AllocaIP, CodeGenIP: Builder.saveIP(), CombinedInfo&: *MapInfo, Info, CustomMapperCB,
            /*IsNonContiguous=*/true, DeviceAddrCB))
      return Err;

    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    SmallVector<llvm::Value *, 13> OffloadingArgs = {
        SrcLocInfo,           DeviceID,
        PointerNum,           RTArgs.BasePointersArray,
        RTArgs.PointersArray, RTArgs.SizesArray,
        RTArgs.MapTypesArray, RTArgs.MapNamesArray,
        RTArgs.MappersArray};

    if (IsStandAlone) {
      assert(MapperFunc && "MapperFunc missing for standalone target data");

      // Emits the *MapperFunc call; with nowait, four extra null trailing
      // arguments (i32, ptr, i32, ptr) are appended and the code continues in
      // a fresh "omp_offload.cont" block.
      auto TaskBodyCB = [&](Value *, Value *,
                            IRBuilderBase::InsertPoint) -> Error {
        if (Info.HasNoWait) {
          OffloadingArgs.append(IL: {llvm::Constant::getNullValue(Ty: Int32),
                                 llvm::Constant::getNullValue(Ty: VoidPtr),
                                 llvm::Constant::getNullValue(Ty: Int32),
                                 llvm::Constant::getNullValue(Ty: VoidPtr)});
        }

        createRuntimeFunctionCall(Callee: getOrCreateRuntimeFunctionPtr(FnID: *MapperFunc),
                                  Args: OffloadingArgs);

        if (Info.HasNoWait) {
          BasicBlock *OffloadContBlock =
              BasicBlock::Create(Context&: Builder.getContext(), Name: "omp_offload.cont");
          Function *CurFn = Builder.GetInsertBlock()->getParent();
          emitBlock(BB: OffloadContBlock, CurFn, /*IsFinished=*/true);
          Builder.restoreIP(IP: Builder.saveIP());
        }
        return Error::success();
      };

      // Only nowait needs the outer target task; otherwise the body runs
      // inline. TaskBodyCB itself never returns an error.
      bool RequiresOuterTargetTask = Info.HasNoWait;
      if (!RequiresOuterTargetTask)
        cantFail(Err: TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
                            /*TargetTaskAllocaIP=*/{}));
      else
        cantFail(ValOrErr: emitTargetTask(TaskBodyCB, DeviceID, RTLoc: SrcLocInfo, AllocaIP,
                                /*Dependencies=*/{}, RTArgs, HasNoWait: Info.HasNoWait));
    } else {
      Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
          FnID: omp::OMPRTL___tgt_target_data_begin_mapper);

      createRuntimeFunctionCall(Callee: BeginMapperFunc, Args: OffloadingArgs);

      // For device-pointer entries whose destination is an alloca, copy the
      // pointer loaded from the first member into that alloca.
      for (auto DeviceMap : Info.DevicePtrInfoMap) {
        if (isa<AllocaInst>(Val: DeviceMap.second.second)) {
          auto *LI =
              Builder.CreateLoad(Ty: Builder.getPtrTy(), Ptr: DeviceMap.second.first);
          Builder.CreateStore(Val: LI, Ptr: DeviceMap.second.second);
        }
      }

      // If device pointer privatization is required, emit the body of the
      // region here. It will have to be duplicated: with and without
      // privatization.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(IP: *AfterIP);
    }
    return Error::success();
  };

  // If we need device pointer privatization, we need to emit the body of the
  // region with no privatization in the 'else' branch of the conditional.
  // Otherwise, we don't have to do anything.
  auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                          ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
    InsertPointOrErrorTy AfterIP =
        BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(IP: *AfterIP);
    return Error::success();
  };

  // Generate code for the closing of the data region.
  auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                        ArrayRef<BasicBlock *> DeallocBlocks) {
    TargetDataRTArgs RTArgs;
    // Relies on MapInfo having been set by BeginThenGen.
    Info.EmitDebug = !MapInfo->Names.empty();
    emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(C: Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};
    Function *EndMapperFunc =
        getOrCreateRuntimeFunctionPtr(FnID: omp::OMPRTL___tgt_target_data_end_mapper);

    createRuntimeFunctionCall(Callee: EndMapperFunc, Args: OffloadingArgs);
    return Error::success();
  };

  // We don't have to do anything to close the region if the if clause evaluates
  // to false.
  auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                        ArrayRef<BasicBlock *> DeallocBlocks) {
    return Error::success();
  };

  Error Err = [&]() -> Error {
    if (BodyGenCB) {
      // Region form: begin (possibly under IfCond), body, end (possibly under
      // IfCond).
      Error Err = [&]() {
        if (IfCond)
          return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: BeginElseGen, AllocaIP);
        return BeginThenGen(AllocaIP, Builder.saveIP(), DeallocBlocks);
      }();

      if (Err)
        return Err;

      // If we don't require privatization of device pointers, we emit the body
      // in between the runtime calls. This avoids duplicating the body code.
      InsertPointOrErrorTy AfterIP =
          BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
      if (!AfterIP)
        return AfterIP.takeError();
      restoreIPandDebugLoc(Builder, IP: *AfterIP);

      if (IfCond)
        return emitIfClause(Cond: IfCond, ThenGen: EndThenGen, ElseGen: EndElseGen, AllocaIP);
      return EndThenGen(AllocaIP, Builder.saveIP(), DeallocBlocks);
    }
    // Standalone form: only the "begin" side exists; a false IfCond does
    // nothing.
    if (IfCond)
      return emitIfClause(Cond: IfCond, ThenGen: BeginThenGen, ElseGen: EndElseGen, AllocaIP);
    return BeginThenGen(AllocaIP, Builder.saveIP(), DeallocBlocks);
  }();

  if (Err)
    return Err;

  return Builder.saveIP();
}
8750
8751FunctionCallee
8752OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
8753 bool IsGPUDistribute) {
8754 assert((IVSize == 32 || IVSize == 64) &&
8755 "IV size is not compatible with the omp runtime");
8756 RuntimeFunction Name;
8757 if (IsGPUDistribute)
8758 Name = IVSize == 32
8759 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
8760 : omp::OMPRTL___kmpc_distribute_static_init_4u)
8761 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
8762 : omp::OMPRTL___kmpc_distribute_static_init_8u);
8763 else
8764 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
8765 : omp::OMPRTL___kmpc_for_static_init_4u)
8766 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
8767 : omp::OMPRTL___kmpc_for_static_init_8u);
8768
8769 return getOrCreateRuntimeFunction(M, FnID: Name);
8770}
8771
8772FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
8773 bool IVSigned) {
8774 assert((IVSize == 32 || IVSize == 64) &&
8775 "IV size is not compatible with the omp runtime");
8776 RuntimeFunction Name = IVSize == 32
8777 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
8778 : omp::OMPRTL___kmpc_dispatch_init_4u)
8779 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
8780 : omp::OMPRTL___kmpc_dispatch_init_8u);
8781
8782 return getOrCreateRuntimeFunction(M, FnID: Name);
8783}
8784
8785FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
8786 bool IVSigned) {
8787 assert((IVSize == 32 || IVSize == 64) &&
8788 "IV size is not compatible with the omp runtime");
8789 RuntimeFunction Name = IVSize == 32
8790 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
8791 : omp::OMPRTL___kmpc_dispatch_next_4u)
8792 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
8793 : omp::OMPRTL___kmpc_dispatch_next_8u);
8794
8795 return getOrCreateRuntimeFunction(M, FnID: Name);
8796}
8797
8798FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
8799 bool IVSigned) {
8800 assert((IVSize == 32 || IVSize == 64) &&
8801 "IV size is not compatible with the omp runtime");
8802 RuntimeFunction Name = IVSize == 32
8803 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
8804 : omp::OMPRTL___kmpc_dispatch_fini_4u)
8805 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
8806 : omp::OMPRTL___kmpc_dispatch_fini_8u);
8807
8808 return getOrCreateRuntimeFunction(M, FnID: Name);
8809}
8810
8811FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
8812 return getOrCreateRuntimeFunction(M, FnID: omp::OMPRTL___kmpc_dispatch_deinit);
8813}
8814
/// Rewrites the debug variable records of the freshly outlined function
/// \p Func so they refer to values of \p Func rather than of the function the
/// region was outlined from. Does nothing if \p Func has no DISubprogram.
///
/// For every debug variable record in \p Func:
///  - each location operand that is a key of \p ValueReplacementMap is
///    replaced by the mapped value; if any operand was replaced, the record's
///    variable is swapped for a copy whose (1-based) argument number is the
///    mapped 0-based argument index + 1;
///  - a record whose number of location operands is not exactly one is turned
///    into a kill location;
///  - a record with a single location operand that is an Instruction in
///    another block (or an Argument, while the record is not in the entry
///    block) is cloned before the last instruction of the required block and
///    the original record is dropped.
/// When compiling for the target device, a "dyn_ptr" parameter variable is
/// also created and declared for the last argument of \p Func.
///
/// \param ValueReplacementMap maps an original value to (replacement value,
///        0-based argument index of the outlined function).
static void FixupDebugInfoForOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
    DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {

  DISubprogram *NewSP = Func->getSubprogram();
  if (!NewSP)
    return;

  // Per original variable, the most recently created argument-tagged copy.
  SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;

  // Returns a DILocalVariable identical to OldVar except that its argument
  // number is `arg` (1-based).
  auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
    DILocalVariable *&NewVar = RemappedVariables[OldVar];
    // Only use cached variable if the arg number matches. This is important
    // so that DIVariable created for privatized variables are not discarded.
    if (NewVar && (arg == NewVar->getArg()))
      return NewVar;

    NewVar = llvm::DILocalVariable::get(
        Context&: Builder.getContext(), Scope: OldVar->getScope(), Name: OldVar->getName(),
        File: OldVar->getFile(), Line: OldVar->getLine(), Type: OldVar->getType(), Arg: arg,
        Flags: OldVar->getFlags(), AlignInBits: OldVar->getAlignInBits(), Annotations: OldVar->getAnnotations());
    return NewVar;
  };

  // Redirects the record's location operands to their replacements. If
  // several operands are replaced, the argument number of the last one wins.
  auto UpdateDebugRecord = [&](auto *DR) {
    DILocalVariable *OldVar = DR->getVariable();
    // 0 means "no operand was replaced"; otherwise the 1-based arg number.
    unsigned ArgNo = 0;
    for (auto Loc : DR->location_ops()) {
      auto Iter = ValueReplacementMap.find(Loc);
      if (Iter != ValueReplacementMap.end()) {
        DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
        ArgNo = std::get<1>(Iter->second) + 1;
      }
    }
    if (ArgNo != 0)
      DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
  };

  // Records collected here are dropped after the walk below so that the
  // record ranges being iterated are not modified mid-iteration.
  SmallVector<DbgVariableRecord *, 4> DVRsToDelete;
  // Ensures a single-operand record sits in the block that defines its
  // location (the entry block for Arguments); multi-operand records are
  // killed instead of moved.
  auto MoveDebugRecordToCorrectBlock = [&](DbgVariableRecord *DVR) {
    if (DVR->getNumVariableLocationOps() != 1u) {
      DVR->setKillLocation();
      return;
    }
    Value *Loc = DVR->getVariableLocationOp(OpIdx: 0u);
    BasicBlock *CurBB = DVR->getParent();
    BasicBlock *RequiredBB = nullptr;

    if (Instruction *LocInst = dyn_cast<Instruction>(Val: Loc))
      RequiredBB = LocInst->getParent();
    else if (isa<llvm::Argument>(Val: Loc))
      RequiredBB = &DVR->getFunction()->getEntryBlock();

    // Locations that are neither instructions nor arguments (e.g. constants)
    // leave RequiredBB null and the record stays where it is.
    if (RequiredBB && RequiredBB != CurBB) {
      assert(!RequiredBB->empty());
      // Insert the copy before the last instruction of the required block.
      RequiredBB->insertDbgRecordBefore(DR: DVR->clone(),
                                        Here: RequiredBB->back().getIterator());
      DVRsToDelete.push_back(Elt: DVR);
    }
  };

  // The variables and location operands of debug records still refer to the
  // parent function of the target region. Update them, then relocate each
  // record to the block its location requires.
  for (Instruction &I : instructions(F: Func)) {
    assert(!isa<llvm::DbgVariableIntrinsic>(&I) &&
           "Unexpected debug intrinsic");
    for (DbgVariableRecord &DVR : filterDbgVars(R: I.getDbgRecordRange())) {
      UpdateDebugRecord(&DVR);
      MoveDebugRecordToCorrectBlock(&DVR);
    }
  }
  for (auto *DVR : DVRsToDelete)
    DVR->getMarker()->MarkedInstr->dropOneDbgRecord(I: DVR);
  // An extra argument is passed to the device. Create the debug data for it.
  // It is the trailing parameter of Func, described as an artificial
  // pointer-typed "dyn_ptr" parameter variable.
  if (OMPBuilder.Config.isTargetDevice()) {
    DICompileUnit *CU = NewSP->getUnit();
    Module *M = Func->getParent();
    DIBuilder DB(*M, true, CU);
    // Pointer type with no base type.
    DIType *VoidPtrTy =
        DB.createQualifiedType(Tag: dwarf::DW_TAG_pointer_type, FromTy: nullptr);
    // Parameter numbers are 1-based, so arg_size() names the last argument.
    unsigned ArgNo = Func->arg_size();
    DILocalVariable *Var = DB.createParameterVariable(
        Scope: NewSP, Name: "dyn_ptr", ArgNo, File: NewSP->getFile(), /*LineNo=*/0, Ty: VoidPtrTy,
        /*AlwaysPreserve=*/false, Flags: DINode::DIFlags::FlagArtificial);
    auto Loc = DILocation::get(Context&: Func->getContext(), Line: 0, Column: 0, Scope: NewSP, InlinedAt: 0);
    Argument *LastArg = Func->getArg(i: Func->arg_size() - 1);
    DB.insertDeclare(Storage: LastArg, VarInfo: Var, Expr: DB.createExpression(), DL: Loc,
                     InsertAtEnd: &(*Func->begin()));
  }
}
8905
8906static Value *removeASCastIfPresent(Value *V) {
8907 if (Operator::getOpcode(V) == Instruction::AddrSpaceCast)
8908 return cast<Operator>(Val: V)->getOperand(i: 0);
8909 return V;
8910}
8911
8912static Expected<Function *> createOutlinedFunction(
8913 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
8914 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
8915 StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
8916 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
8917 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
8918 SmallVector<Type *> ParameterTypes;
8919 if (OMPBuilder.Config.isTargetDevice()) {
8920 // All parameters to target devices are passed as pointers
8921 // or i64. This assumes 64-bit address spaces/pointers.
8922 for (auto &Arg : Inputs)
8923 ParameterTypes.push_back(Elt: Arg->getType()->isPointerTy()
8924 ? Arg->getType()
8925 : Type::getInt64Ty(C&: Builder.getContext()));
8926 } else {
8927 for (auto &Arg : Inputs)
8928 ParameterTypes.push_back(Elt: Arg->getType());
8929 }
8930
8931 // The implicit dyn_ptr argument is always the last parameter on both host
8932 // and device so the argument counts match without runtime manipulation.
8933 auto *PtrTy = PointerType::getUnqual(C&: Builder.getContext());
8934 ParameterTypes.push_back(Elt: PtrTy);
8935
8936 auto BB = Builder.GetInsertBlock();
8937 auto M = BB->getModule();
8938 auto FuncType = FunctionType::get(Result: Builder.getVoidTy(), Params: ParameterTypes,
8939 /*isVarArg*/ false);
8940 auto Func =
8941 Function::Create(Ty: FuncType, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
8942
8943 // Forward target-cpu and target-features function attributes from the
8944 // original function to the new outlined function.
8945 Function *ParentFn = Builder.GetInsertBlock()->getParent();
8946
8947 auto TargetCpuAttr = ParentFn->getFnAttribute(Kind: "target-cpu");
8948 if (TargetCpuAttr.isStringAttribute())
8949 Func->addFnAttr(Attr: TargetCpuAttr);
8950
8951 auto TargetFeaturesAttr = ParentFn->getFnAttribute(Kind: "target-features");
8952 if (TargetFeaturesAttr.isStringAttribute())
8953 Func->addFnAttr(Attr: TargetFeaturesAttr);
8954
8955 if (OMPBuilder.Config.isTargetDevice()) {
8956 Value *ExecMode =
8957 OMPBuilder.emitKernelExecutionMode(KernelName: FuncName, Mode: DefaultAttrs.ExecFlags);
8958 OMPBuilder.emitUsed(Name: "llvm.compiler.used", List: {ExecMode});
8959 }
8960
8961 // Save insert point.
8962 IRBuilder<>::InsertPointGuard IPG(Builder);
8963 // We will generate the entries in the outlined function but the debug
8964 // location may still be pointing to the parent function. Reset it now.
8965 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
8966
8967 // Generate the region into the function.
8968 BasicBlock *EntryBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: Func);
8969 Builder.SetInsertPoint(EntryBB);
8970
8971 // Insert target init call in the device compilation pass.
8972 if (OMPBuilder.Config.isTargetDevice())
8973 Builder.restoreIP(IP: OMPBuilder.createTargetInit(Loc: Builder, Attrs: DefaultAttrs));
8974
8975 BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
8976
8977 // As we embed the user code in the middle of our target region after we
8978 // generate entry code, we must move what allocas we can into the entry
8979 // block to avoid possible breaking optimisations for device
8980 if (OMPBuilder.Config.isTargetDevice())
8981 OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Args&: Func);
8982
8983 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "target.exit");
8984 BasicBlock *OutlinedBodyBB =
8985 splitBB(Builder, /*CreateBranch=*/true, Name: "outlined.body");
8986 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
8987 Builder.saveIP(),
8988 OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()),
8989 ExitBB);
8990 if (!AfterIP)
8991 return AfterIP.takeError();
8992 Builder.SetInsertPoint(ExitBB);
8993
8994 // Insert target deinit call in the device compilation pass.
8995 if (OMPBuilder.Config.isTargetDevice())
8996 OMPBuilder.createTargetDeinit(Loc: Builder);
8997
8998 // Insert return instruction.
8999 Builder.CreateRetVoid();
9000
9001 // New Alloca IP at entry point of created device function.
9002 Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
9003 auto AllocaIP = Builder.saveIP();
9004
9005 Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
9006
9007 // Do not include the artificial dyn_ptr argument.
9008 const auto &ArgRange = make_range(x: Func->arg_begin(), y: Func->arg_end() - 1);
9009
9010 DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;
9011
9012 auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
9013 // Things like GEP's can come in the form of Constants. Constants and
9014 // ConstantExpr's do not have access to the knowledge of what they're
9015 // contained in, so we must dig a little to find an instruction so we
9016 // can tell if they're used inside of the function we're outlining. We
9017 // also replace the original constant expression with a new instruction
9018 // equivalent; an instruction as it allows easy modification in the
9019 // following loop, as we can now know the constant (instruction) is
9020 // owned by our target function and replaceUsesOfWith can now be invoked
9021 // on it (cannot do this with constants it seems). A brand new one also
9022 // allows us to be cautious as it is perhaps possible the old expression
9023 // was used inside of the function but exists and is used externally
9024 // (unlikely by the nature of a Constant, but still).
9025 // NOTE: We cannot remove dead constants that have been rewritten to
9026 // instructions at this stage, we run the risk of breaking later lowering
9027 // by doing so as we could still be in the process of lowering the module
9028 // from MLIR to LLVM-IR and the MLIR lowering may still require the original
9029 // constants we have created rewritten versions of.
9030 if (auto *Const = dyn_cast<Constant>(Val: Input))
9031 convertUsersOfConstantsToInstructions(Consts: Const, RestrictToFunc: Func, RemoveDeadConstants: false);
9032
9033 // Collect users before iterating over them to avoid invalidating the
9034 // iteration in case a user uses Input more than once (e.g. a call
9035 // instruction).
9036 SetVector<User *> Users(Input->users().begin(), Input->users().end());
9037 // Collect all the instructions
9038 for (User *User : make_early_inc_range(Range&: Users))
9039 if (auto *Instr = dyn_cast<Instruction>(Val: User))
9040 if (Instr->getFunction() == Func)
9041 Instr->replaceUsesOfWith(From: Input, To: InputCopy);
9042 };
9043
9044 SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
9045
9046 // Rewrite uses of input valus to parameters.
9047 for (auto InArg : zip(t&: Inputs, u: ArgRange)) {
9048 Value *Input = std::get<0>(t&: InArg);
9049 Argument &Arg = std::get<1>(t&: InArg);
9050 Value *InputCopy = nullptr;
9051
9052 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = ArgAccessorFuncCB(
9053 Arg, Input, InputCopy, AllocaIP, Builder.saveIP(),
9054 OpenMPIRBuilder::InsertPointTy(ExitBB, ExitBB->begin()));
9055 if (!AfterIP)
9056 return AfterIP.takeError();
9057 Builder.restoreIP(IP: *AfterIP);
9058 ValueReplacementMap[Input] = std::make_tuple(args&: InputCopy, args: Arg.getArgNo());
9059
9060 // In certain cases a Global may be set up for replacement, however, this
9061 // Global may be used in multiple arguments to the kernel, just segmented
9062 // apart, for example, if we have a global array, that is sectioned into
9063 // multiple mappings (technically not legal in OpenMP, but there is a case
9064 // in Fortran for Common Blocks where this is neccesary), we will end up
9065 // with GEP's into this array inside the kernel, that refer to the Global
9066 // but are technically separate arguments to the kernel for all intents and
9067 // purposes. If we have mapped a segment that requires a GEP into the 0-th
9068 // index, it will fold into an referal to the Global, if we then encounter
9069 // this folded GEP during replacement all of the references to the
9070 // Global in the kernel will be replaced with the argument we have generated
9071 // that corresponds to it, including any other GEP's that refer to the
9072 // Global that may be other arguments. This will invalidate all of the other
9073 // preceding mapped arguments that refer to the same global that may be
9074 // separate segments. To prevent this, we defer global processing until all
9075 // other processing has been performed.
9076 if (llvm::isa<llvm::GlobalValue, llvm::GlobalObject, llvm::GlobalVariable>(
9077 Val: removeASCastIfPresent(V: Input))) {
9078 DeferredReplacement.push_back(Elt: std::make_pair(x&: Input, y&: InputCopy));
9079 continue;
9080 }
9081
9082 if (isa<ConstantData>(Val: Input))
9083 continue;
9084
9085 ReplaceValue(Input, InputCopy, Func);
9086 }
9087
9088 // Replace all of our deferred Input values, currently just Globals.
9089 for (auto Deferred : DeferredReplacement)
9090 ReplaceValue(std::get<0>(in&: Deferred), std::get<1>(in&: Deferred), Func);
9091
9092 FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
9093 ValueReplacementMap);
9094 return Func;
9095}
9096/// Given a task descriptor, TaskWithPrivates, return the pointer to the block
9097/// of pointers containing shared data between the parent task and the created
9098/// task.
9099static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
9100 IRBuilderBase &Builder,
9101 Value *TaskWithPrivates,
9102 Type *TaskWithPrivatesTy) {
9103
9104 Type *TaskTy = OMPIRBuilder.Task;
9105 LLVMContext &Ctx = Builder.getContext();
9106 Value *TaskT =
9107 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 0);
9108 Value *Shareds = TaskT;
9109 // TaskWithPrivatesTy can be one of the following
9110 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
9111 // %struct.privates }
9112 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
9113 //
9114 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
9115 // its first member has to be the task descriptor. TaskTy is the type of the
9116 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
9117 // first member of TaskT, gives us the pointer to shared data.
9118 if (TaskWithPrivatesTy != TaskTy)
9119 Shareds = Builder.CreateStructGEP(Ty: TaskTy, Ptr: TaskT, Idx: 0);
9120 return Builder.CreateLoad(Ty: PointerType::getUnqual(C&: Ctx), Ptr: Shareds);
9121}
9122/// Create an entry point for a target task with the following.
9123/// It'll have the following signature
9124/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
9125/// This function is called from emitTargetTask once the
9126/// code to launch the target kernel has been outlined already.
9127/// NumOffloadingArrays is the number of offloading arrays that we need to copy
9128/// into the task structure so that the deferred target task can access this
9129/// data even after the stack frame of the generating task has been rolled
9130/// back. Offloading arrays contain base pointers, pointers, sizes etc
9131/// of the data that the target kernel will access. These in effect are the
9132/// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
9133static Function *emitTargetTaskProxyFunction(
9134 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
9135 StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
9136 const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {
9137
9138 // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
9139 // This is because PrivatesTy is the type of the structure in which
9140 // we pass the offloading arrays to the deferred target task.
9141 assert((!NumOffloadingArrays || PrivatesTy) &&
9142 "PrivatesTy cannot be nullptr when there are offloadingArrays"
9143 "to privatize");
9144
9145 Module &M = OMPBuilder.M;
9146 // KernelLaunchFunction is the target launch function, i.e.
9147 // the function that sets up kernel arguments and calls
9148 // __tgt_target_kernel to launch the kernel on the device.
9149 //
9150 Function *KernelLaunchFunction = StaleCI->getCalledFunction();
9151
9152 // StaleCI is the CallInst which is the call to the outlined
9153 // target kernel launch function. If there are local live-in values
9154 // that the outlined function uses then these are aggregated into a structure
9155 // which is passed as the second argument. If there are no local live-in
9156 // values or if all values used by the outlined kernel are global variables,
9157 // then there's only one argument, the threadID. So, StaleCI can be
9158 //
9159 // %structArg = alloca { ptr, ptr }, align 8
9160 // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
9161 // store ptr %20, ptr %gep_, align 8
9162 // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
9163 // store ptr %21, ptr %gep_8, align 8
9164 // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
9165 //
9166 // OR
9167 //
9168 // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
9169 OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
9170 StaleCI->getIterator());
9171
9172 LLVMContext &Ctx = StaleCI->getParent()->getContext();
9173
9174 Type *ThreadIDTy = Type::getInt32Ty(C&: Ctx);
9175 Type *TaskPtrTy = OMPBuilder.TaskPtr;
9176 [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;
9177
9178 auto ProxyFnTy =
9179 FunctionType::get(Result: Builder.getVoidTy(), Params: {ThreadIDTy, TaskPtrTy},
9180 /* isVarArg */ false);
9181 auto ProxyFn = Function::Create(Ty: ProxyFnTy, Linkage: GlobalValue::InternalLinkage,
9182 N: ".omp_target_task_proxy_func",
9183 M: Builder.GetInsertBlock()->getModule());
9184 Value *ThreadId = ProxyFn->getArg(i: 0);
9185 Value *TaskWithPrivates = ProxyFn->getArg(i: 1);
9186 ThreadId->setName("thread.id");
9187 TaskWithPrivates->setName("task");
9188
9189 bool HasShareds = SharedArgsOperandNo > 0;
9190 bool HasOffloadingArrays = NumOffloadingArrays > 0;
9191 BasicBlock *EntryBB =
9192 BasicBlock::Create(Context&: Builder.getContext(), Name: "entry", Parent: ProxyFn);
9193 Builder.SetInsertPoint(EntryBB);
9194
9195 SmallVector<Value *> KernelLaunchArgs;
9196 KernelLaunchArgs.reserve(N: StaleCI->arg_size());
9197 KernelLaunchArgs.push_back(Elt: ThreadId);
9198
9199 if (HasOffloadingArrays) {
9200 assert(TaskTy != TaskWithPrivatesTy &&
9201 "If there are offloading arrays to pass to the target"
9202 "TaskTy cannot be the same as TaskWithPrivatesTy");
9203 (void)TaskTy;
9204 Value *Privates =
9205 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskWithPrivates, Idx: 1);
9206 for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
9207 KernelLaunchArgs.push_back(
9208 Elt: Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i));
9209 }
9210
9211 if (HasShareds) {
9212 auto *ArgStructAlloca =
9213 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgsOperandNo));
9214 assert(ArgStructAlloca &&
9215 "Unable to find the alloca instruction corresponding to arguments "
9216 "for extracted function");
9217 auto *ArgStructType = cast<StructType>(Val: ArgStructAlloca->getAllocatedType());
9218 std::optional<TypeSize> ArgAllocSize =
9219 ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
9220 assert(ArgStructType && ArgAllocSize &&
9221 "Unable to determine size of arguments for extracted function");
9222 uint64_t StructSize = ArgAllocSize->getFixedValue();
9223
9224 AllocaInst *NewArgStructAlloca =
9225 Builder.CreateAlloca(Ty: ArgStructType, ArraySize: nullptr, Name: "structArg");
9226
9227 Value *SharedsSize = Builder.getInt64(C: StructSize);
9228
9229 LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
9230 OMPIRBuilder&: OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);
9231
9232 Builder.CreateMemCpy(
9233 Dst: NewArgStructAlloca, DstAlign: NewArgStructAlloca->getAlign(), Src: LoadShared,
9234 SrcAlign: LoadShared->getPointerAlignment(DL: M.getDataLayout()), Size: SharedsSize);
9235 KernelLaunchArgs.push_back(Elt: NewArgStructAlloca);
9236 }
9237 OMPBuilder.createRuntimeFunctionCall(Callee: KernelLaunchFunction, Args: KernelLaunchArgs);
9238 Builder.CreateRetVoid();
9239 return ProxyFn;
9240}
9241static Type *getOffloadingArrayType(Value *V) {
9242
9243 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V))
9244 return GEP->getSourceElementType();
9245 if (auto *Alloca = dyn_cast<AllocaInst>(Val: V))
9246 return Alloca->getAllocatedType();
9247
9248 llvm_unreachable("Unhandled Instruction type");
9249 return nullptr;
9250}
9251// This function returns a struct that has at most two members.
9252// The first member is always %struct.kmp_task_ompbuilder_t, that is the task
9253// descriptor. The second member, if needed, is a struct containing arrays
9254// that need to be passed to the offloaded target kernel. For example,
9255// if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
9256// the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
9257// respectively, then the types created by this function are
9258//
9259// %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
9260// %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
9261// %struct.privates }
9262// %struct.task_with_privates is returned by this function.
9263// If there aren't any offloading arrays to pass to the target kernel,
9264// %struct.kmp_task_ompbuilder_t is returned.
9265static StructType *
9266createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
9267 ArrayRef<Value *> OffloadingArraysToPrivatize) {
9268
9269 if (OffloadingArraysToPrivatize.empty())
9270 return OMPIRBuilder.Task;
9271
9272 SmallVector<Type *, 4> StructFieldTypes;
9273 for (Value *V : OffloadingArraysToPrivatize) {
9274 assert(V->getType()->isPointerTy() &&
9275 "Expected pointer to array to privatize. Got a non-pointer value "
9276 "instead");
9277 Type *ArrayTy = getOffloadingArrayType(V);
9278 assert(ArrayTy && "ArrayType cannot be nullptr");
9279 StructFieldTypes.push_back(Elt: ArrayTy);
9280 }
9281 StructType *PrivatesStructTy =
9282 StructType::create(Elements: StructFieldTypes, Name: "struct.privates");
9283 return StructType::create(Elements: {OMPIRBuilder.Task, PrivatesStructTy},
9284 Name: "struct.task_with_privates");
9285}
9286static Error emitTargetOutlinedFunction(
9287 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
9288 TargetRegionEntryInfo &EntryInfo,
9289 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
9290 Function *&OutlinedFn, Constant *&OutlinedFnID,
9291 SmallVectorImpl<Value *> &Inputs,
9292 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
9293 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
9294
9295 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
9296 [&](StringRef EntryFnName) {
9297 return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
9298 FuncName: EntryFnName, Inputs, CBFunc,
9299 ArgAccessorFuncCB);
9300 };
9301
9302 return OMPBuilder.emitTargetRegionFunction(
9303 EntryInfo, GenerateFunctionCallback&: GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
9304 OutlinedFnID);
9305}
9306
9307OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
9308 TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
9309 OpenMPIRBuilder::InsertPointTy AllocaIP,
9310 const DependenciesInfo &Dependencies, const TargetDataRTArgs &RTArgs,
9311 bool HasNoWait) {
9312
9313 // The following explains the code-gen scenario for the `target` directive. A
9314 // similar scneario is followed for other device-related directives (e.g.
9315 // `target enter data`) but in similar fashion since we only need to emit task
9316 // that encapsulates the proper runtime call.
9317 //
9318 // When we arrive at this function, the target region itself has been
9319 // outlined into the function OutlinedFn.
9320 // So at ths point, for
9321 // --------------------------------------------------------------
9322 // void user_code_that_offloads(...) {
9323 // omp target depend(..) map(from:a) map(to:b) private(i)
9324 // do i = 1, 10
9325 // a(i) = b(i) + n
9326 // }
9327 //
9328 // --------------------------------------------------------------
9329 //
9330 // we have
9331 //
9332 // --------------------------------------------------------------
9333 //
9334 // void user_code_that_offloads(...) {
9335 // %.offload_baseptrs = alloca [2 x ptr], align 8
9336 // %.offload_ptrs = alloca [2 x ptr], align 8
9337 // %.offload_mappers = alloca [2 x ptr], align 8
9338 // ;; target region has been outlined and now we need to
9339 // ;; offload to it via a target task.
9340 // }
9341 // void outlined_device_function(ptr a, ptr b, ptr n) {
9342 // n = *n_ptr;
9343 // do i = 1, 10
9344 // a(i) = b(i) + n
9345 // }
9346 //
9347 // We have to now do the following
9348 // (i) Make an offloading call to outlined_device_function using the OpenMP
9349 // RTL. See 'kernel_launch_function' in the pseudo code below. This is
9350 // emitted by emitKernelLaunch
9351 // (ii) Create a task entry point function that calls kernel_launch_function
9352 // and is the entry point for the target task. See
9353 // '@.omp_target_task_proxy_func in the pseudocode below.
9354 // (iii) Create a task with the task entry point created in (ii)
9355 //
9356 // That is we create the following
9357 // struct task_with_privates {
9358 // struct kmp_task_ompbuilder_t task_struct;
9359 // struct privates {
9360 // [2 x ptr] ; baseptrs
9361 // [2 x ptr] ; ptrs
9362 // [2 x i64] ; sizes
9363 // }
9364 // }
9365 // void user_code_that_offloads(...) {
9366 // %.offload_baseptrs = alloca [2 x ptr], align 8
9367 // %.offload_ptrs = alloca [2 x ptr], align 8
9368 // %.offload_sizes = alloca [2 x i64], align 8
9369 //
9370 // %structArg = alloca { ptr, ptr, ptr }, align 8
9371 // %strucArg[0] = a
9372 // %strucArg[1] = b
9373 // %strucArg[2] = &n
9374 //
9375 // target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
9376 // sizeof(kmp_task_ompbuilder_t),
9377 // sizeof(structArg),
9378 // @.omp_target_task_proxy_func,
9379 // ...)
9380 // memcpy(target_task_with_privates->task_struct->shareds, %structArg,
9381 // sizeof(structArg))
9382 // memcpy(target_task_with_privates->privates->baseptrs,
9383 // offload_baseptrs, sizeof(offload_baseptrs)
9384 // memcpy(target_task_with_privates->privates->ptrs,
9385 // offload_ptrs, sizeof(offload_ptrs)
9386 // memcpy(target_task_with_privates->privates->sizes,
9387 // offload_sizes, sizeof(offload_sizes)
9388 // dependencies_array = ...
9389 // ;; if nowait not present
9390 // call @__kmpc_omp_wait_deps(..., dependencies_array)
9391 // call @__kmpc_omp_task_begin_if0(...)
9392 // call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
9393 // %target_task_with_privates)
9394 // call @__kmpc_omp_task_complete_if0(...)
9395 // }
9396 //
9397 // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
9398 // ptr %task) {
9399 // %structArg = alloca {ptr, ptr, ptr}
9400 // %task_ptr = getelementptr(%task, 0, 0)
9401 // %shared_data = load (getelementptr %task_ptr, 0, 0)
9402 // mempcy(%structArg, %shared_data, sizeof(%structArg))
9403 //
9404 // %offloading_arrays = getelementptr(%task, 0, 1)
9405 // %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
9406 // %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
9407 // %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
9408 // kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
9409 // %offload_sizes, %structArg)
9410 // }
9411 //
9412 // We need the proxy function because the signature of the task entry point
9413 // expected by kmpc_omp_task is always the same and will be different from
9414 // that of the kernel_launch function.
9415 //
9416 // kernel_launch_function is generated by emitKernelLaunch and has the
9417 // always_inline attribute. For this example, it'll look like so:
9418 // void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
9419 // %offload_sizes, %structArg) alwaysinline {
9420 // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
9421 // ; load aggregated data from %structArg
9422 // ; setup kernel_args using offload_baseptrs, offload_ptrs and
9423 // ; offload_sizes
9424 // call i32 @__tgt_target_kernel(...,
9425 // outlined_device_function,
9426 // ptr %kernel_args)
9427 // }
9428 // void outlined_device_function(ptr a, ptr b, ptr n) {
9429 // n = *n_ptr;
9430 // do i = 1, 10
9431 // a(i) = b(i) + n
9432 // }
9433 //
9434 BasicBlock *TargetTaskBodyBB =
9435 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.body");
9436 BasicBlock *TargetTaskAllocaBB =
9437 splitBB(Builder, /*CreateBranch=*/true, Name: "target.task.alloca");
9438
9439 InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
9440 TargetTaskAllocaBB->begin());
9441 InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
9442
9443 auto OI = std::make_unique<OutlineInfo>();
9444 OI->EntryBB = TargetTaskAllocaBB;
9445 OI->OuterAllocBB = AllocaIP.getBlock();
9446
9447 // Add the thread ID argument.
9448 SmallVector<Instruction *, 4> ToBeDeleted;
9449 OI->ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
9450 Builder, OuterAllocaIP: AllocaIP, ToBeDeleted, InnerAllocaIP: TargetTaskAllocaIP, Name: "global.tid", AsPtr: false));
9451
9452 // Generate the task body which will subsequently be outlined.
9453 Builder.restoreIP(IP: TargetTaskBodyIP);
9454 if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
9455 return Err;
9456
9457 // The outliner (CodeExtractor) extract a sequence or vector of blocks that
9458 // it is given. These blocks are enumerated by
9459 // OpenMPIRBuilder::OutlineInfo::collectBlocks which expects the OI.ExitBlock
9460 // to be outside the region. In other words, OI.ExitBlock is expected to be
9461 // the start of the region after the outlining. We used to set OI.ExitBlock
9462 // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
9463 // except when the task body is a single basic block. In that case,
9464 // OI.ExitBlock is set to the single task body block and will get left out of
9465 // the outlining process. So, simply create a new empty block to which we
9466 // uncoditionally branch from where TaskBodyCB left off
9467 OI->ExitBB = BasicBlock::Create(Context&: Builder.getContext(), Name: "target.task.cont");
9468 emitBlock(BB: OI->ExitBB, CurFn: Builder.GetInsertBlock()->getParent(),
9469 /*IsFinished=*/true);
9470
9471 SmallVector<Value *, 2> OffloadingArraysToPrivatize;
9472 bool NeedsTargetTask = HasNoWait && DeviceID;
9473 if (NeedsTargetTask) {
9474 for (auto *V :
9475 {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
9476 RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
9477 RTArgs.SizesArray}) {
9478 if (V && !isa<ConstantPointerNull, GlobalVariable>(Val: V)) {
9479 OffloadingArraysToPrivatize.push_back(Elt: V);
9480 OI->ExcludeArgsFromAggregate.push_back(Elt: V);
9481 }
9482 }
9483 }
9484 OI->PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
9485 DeviceID, OffloadingArraysToPrivatize](
9486 Function &OutlinedFn) mutable {
9487 assert(OutlinedFn.hasOneUse() &&
9488 "there must be a single user for the outlined function");
9489
9490 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
9491
9492 // The first argument of StaleCI is always the thread id.
9493 // The next few arguments are the pointers to offloading arrays
9494 // if any. (see OffloadingArraysToPrivatize)
9495 // Finally, all other local values that are live-in into the outlined region
9496 // end up in a structure whose pointer is passed as the last argument. This
9497 // piece of data is passed in the "shared" field of the task structure. So,
9498 // we know we have to pass shareds to the task if the number of arguments is
9499 // greater than OffloadingArraysToPrivatize.size() + 1 The 1 is for the
9500 // thread id. Further, for safety, we assert that the number of arguments of
9501 // StaleCI is exactly OffloadingArraysToPrivatize.size() + 2
9502 const unsigned int NumStaleCIArgs = StaleCI->arg_size();
9503 bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
9504 assert((!HasShareds ||
9505 NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
9506 "Wrong number of arguments for StaleCI when shareds are present");
9507 int SharedArgOperandNo =
9508 HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;
9509
9510 StructType *TaskWithPrivatesTy =
9511 createTaskWithPrivatesTy(OMPIRBuilder&: *this, OffloadingArraysToPrivatize);
9512 StructType *PrivatesTy = nullptr;
9513
9514 if (!OffloadingArraysToPrivatize.empty())
9515 PrivatesTy =
9516 static_cast<StructType *>(TaskWithPrivatesTy->getElementType(N: 1));
9517
9518 Function *ProxyFn = emitTargetTaskProxyFunction(
9519 OMPBuilder&: *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
9520 NumOffloadingArrays: OffloadingArraysToPrivatize.size(), SharedArgsOperandNo: SharedArgOperandNo);
9521
9522 LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
9523 << "\n");
9524
9525 Builder.SetInsertPoint(StaleCI);
9526
9527 // Gather the arguments for emitting the runtime call.
9528 uint32_t SrcLocStrSize;
9529 Constant *SrcLocStr =
9530 getOrCreateSrcLocStr(Loc: LocationDescription(Builder), SrcLocStrSize);
9531 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
9532
9533 // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
9534 //
9535 // If `HasNoWait == true`, we call @__kmpc_omp_target_task_alloc to provide
9536 // the DeviceID to the deferred task and also since
9537 // @__kmpc_omp_target_task_alloc creates an untied/async task.
9538 Function *TaskAllocFn =
9539 !NeedsTargetTask
9540 ? getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_alloc)
9541 : getOrCreateRuntimeFunctionPtr(
9542 FnID: OMPRTL___kmpc_omp_target_task_alloc);
9543
9544 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
9545 // call.
9546 Value *ThreadID = getOrCreateThreadID(Ident);
9547
9548 // Argument - `sizeof_kmp_task_t` (TaskSize)
9549 // Tasksize refers to the size in bytes of kmp_task_t data structure
9550 // plus any other data to be passed to the target task, if any, which
9551 // is packed into a struct. kmp_task_t and the struct so created are
9552 // packed into a wrapper struct whose type is TaskWithPrivatesTy.
9553 Value *TaskSize = Builder.getInt64(
9554 C: M.getDataLayout().getTypeStoreSize(Ty: TaskWithPrivatesTy));
9555
9556 // Argument - `sizeof_shareds` (SharedsSize)
9557 // SharedsSize refers to the shareds array size in the kmp_task_t data
9558 // structure.
9559 Value *SharedsSize = Builder.getInt64(C: 0);
9560 if (HasShareds) {
9561 auto *ArgStructAlloca =
9562 dyn_cast<AllocaInst>(Val: StaleCI->getArgOperand(i: SharedArgOperandNo));
9563 assert(ArgStructAlloca &&
9564 "Unable to find the alloca instruction corresponding to arguments "
9565 "for extracted function");
9566 std::optional<TypeSize> ArgAllocSize =
9567 ArgStructAlloca->getAllocationSize(DL: M.getDataLayout());
9568 assert(ArgAllocSize &&
9569 "Unable to determine size of arguments for extracted function");
9570 SharedsSize = Builder.getInt64(C: ArgAllocSize->getFixedValue());
9571 }
9572
9573 // Argument - `flags`
9574 // Task is tied iff (Flags & 1) == 1.
9575 // Task is untied iff (Flags & 1) == 0.
9576 // Task is final iff (Flags & 2) == 2.
9577 // Task is not final iff (Flags & 2) == 0.
9578 // A target task is not final and is untied.
9579 Value *Flags = Builder.getInt32(C: 0);
9580
9581 // Emit the @__kmpc_omp_task_alloc runtime call
9582 // The runtime call returns a pointer to an area where the task captured
9583 // variables must be copied before the task is run (TaskData)
9584 CallInst *TaskData = nullptr;
9585
9586 SmallVector<llvm::Value *> TaskAllocArgs = {
9587 /*loc_ref=*/Ident, /*gtid=*/ThreadID,
9588 /*flags=*/Flags,
9589 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
9590 /*task_func=*/ProxyFn};
9591
9592 if (NeedsTargetTask) {
9593 assert(DeviceID && "Expected non-empty device ID.");
9594 TaskAllocArgs.push_back(Elt: DeviceID);
9595 }
9596
9597 TaskData = createRuntimeFunctionCall(Callee: TaskAllocFn, Args: TaskAllocArgs);
9598
9599 Align Alignment = TaskData->getPointerAlignment(DL: M.getDataLayout());
9600 if (HasShareds) {
9601 Value *Shareds = StaleCI->getArgOperand(i: SharedArgOperandNo);
9602 Value *TaskShareds = loadSharedDataFromTaskDescriptor(
9603 OMPIRBuilder&: *this, Builder, TaskWithPrivates: TaskData, TaskWithPrivatesTy);
9604 Builder.CreateMemCpy(Dst: TaskShareds, DstAlign: Alignment, Src: Shareds, SrcAlign: Alignment,
9605 Size: SharedsSize);
9606 }
9607 if (!OffloadingArraysToPrivatize.empty()) {
9608 Value *Privates =
9609 Builder.CreateStructGEP(Ty: TaskWithPrivatesTy, Ptr: TaskData, Idx: 1);
9610 for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
9611 Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
9612 [[maybe_unused]] Type *ArrayType =
9613 getOffloadingArrayType(V: PtrToPrivatize);
9614 assert(ArrayType && "ArrayType cannot be nullptr");
9615
9616 Type *ElementType = PrivatesTy->getElementType(N: i);
9617 assert(ElementType == ArrayType &&
9618 "ElementType should match ArrayType");
9619 (void)ArrayType;
9620
9621 Value *Dst = Builder.CreateStructGEP(Ty: PrivatesTy, Ptr: Privates, Idx: i);
9622 Builder.CreateMemCpy(
9623 Dst, DstAlign: Alignment, Src: PtrToPrivatize, SrcAlign: Alignment,
9624 Size: Builder.getInt64(C: M.getDataLayout().getTypeStoreSize(Ty: ElementType)));
9625 }
9626 }
9627
9628 Value *DepArray = nullptr;
9629 Value *NumDeps = nullptr;
9630 if (Dependencies.DepArray) {
9631 DepArray = Dependencies.DepArray;
9632 NumDeps = Dependencies.NumDeps;
9633 } else if (!Dependencies.Deps.empty()) {
9634 DepArray = emitTaskDependencies(OMPBuilder&: *this, Dependencies: Dependencies.Deps);
9635 NumDeps = Builder.getInt32(C: Dependencies.Deps.size());
9636 }
9637
9638 // ---------------------------------------------------------------
9639 // V5.2 13.8 target construct
9640 // If the nowait clause is present, execution of the target task
9641 // may be deferred. If the nowait clause is not present, the target task is
9642 // an included task.
9643 // ---------------------------------------------------------------
9644 // The above means that the lack of a nowait on the target construct
9645 // translates to '#pragma omp task if(0)'
9646 if (!NeedsTargetTask) {
9647 if (DepArray) {
9648 Function *TaskWaitFn =
9649 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_wait_deps);
9650 createRuntimeFunctionCall(
9651 Callee: TaskWaitFn,
9652 Args: {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
9653 /*ndeps=*/NumDeps,
9654 /*dep_list=*/DepArray,
9655 /*ndeps_noalias=*/ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
9656 /*noalias_dep_list=*/
9657 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
9658 }
9659 // Included task.
9660 Function *TaskBeginFn =
9661 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_begin_if0);
9662 Function *TaskCompleteFn =
9663 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_complete_if0);
9664 createRuntimeFunctionCall(Callee: TaskBeginFn, Args: {Ident, ThreadID, TaskData});
9665 CallInst *CI = createRuntimeFunctionCall(Callee: ProxyFn, Args: {ThreadID, TaskData});
9666 CI->setDebugLoc(StaleCI->getDebugLoc());
9667 createRuntimeFunctionCall(Callee: TaskCompleteFn, Args: {Ident, ThreadID, TaskData});
9668 } else if (DepArray) {
9669 // HasNoWait - meaning the task may be deferred. Call
9670 // __kmpc_omp_task_with_deps if there are dependencies,
9671 // else call __kmpc_omp_task
9672 Function *TaskFn =
9673 getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task_with_deps);
9674 createRuntimeFunctionCall(
9675 Callee: TaskFn,
9676 Args: {Ident, ThreadID, TaskData, NumDeps, DepArray,
9677 ConstantInt::get(Ty: Builder.getInt32Ty(), V: 0),
9678 ConstantPointerNull::get(T: PointerType::getUnqual(C&: M.getContext()))});
9679 } else {
9680 // Emit the @__kmpc_omp_task runtime call to spawn the task
9681 Function *TaskFn = getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_omp_task);
9682 createRuntimeFunctionCall(Callee: TaskFn, Args: {Ident, ThreadID, TaskData});
9683 }
9684
9685 StaleCI->eraseFromParent();
9686 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
9687 I->eraseFromParent();
9688 };
9689 addOutlineInfo(OI: std::move(OI));
9690
9691 LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
9692 << *(Builder.GetInsertBlock()) << "\n");
9693 LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
9694 << *(Builder.GetInsertBlock()->getParent()->getParent())
9695 << "\n");
9696 return Builder.saveIP();
9697}
9698
9699Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
9700 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
9701 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
9702 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
9703 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
9704 if (Error Err =
9705 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
9706 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
9707 return Err;
9708 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
9709 return Error::success();
9710}
9711
9712static void emitTargetCall(
9713 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
9714 OpenMPIRBuilder::InsertPointTy AllocaIP,
9715 ArrayRef<BasicBlock *> DeallocBlocks, OpenMPIRBuilder::TargetDataInfo &Info,
9716 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
9717 const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
9718 Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
9719 SmallVectorImpl<Value *> &Args,
9720 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
9721 OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
9722 const OpenMPIRBuilder::DependenciesInfo &Dependencies, bool HasNoWait,
9723 Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
9724 // Generate a function call to the host fallback implementation of the target
9725 // region. This is called by the host when no offload entry was generated for
9726 // the target region and when the offloading call fails at runtime.
9727 auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
9728 -> OpenMPIRBuilder::InsertPointOrErrorTy {
9729 Builder.restoreIP(IP);
9730 // Ensure the host fallback has the same dyn_ptr ABI as the device.
9731 SmallVector<Value *> FallbackArgs(Args.begin(), Args.end());
9732 FallbackArgs.push_back(
9733 Elt: Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext())));
9734 OMPBuilder.createRuntimeFunctionCall(Callee: OutlinedFn, Args: FallbackArgs);
9735 return Builder.saveIP();
9736 };
9737
9738 bool HasDependencies = !Dependencies.empty();
9739 bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
9740
9741 OpenMPIRBuilder::TargetKernelArgs KArgs;
9742
9743 auto TaskBodyCB =
9744 [&](Value *DeviceID, Value *RTLoc,
9745 IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
9746 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
9747 // produce any.
9748 llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
9749 // emitKernelLaunch makes the necessary runtime call to offload the
9750 // kernel. We then outline all that code into a separate function
9751 // ('kernel_launch_function' in the pseudo code above). This function is
9752 // then called by the target task proxy function (see
9753 // '@.omp_target_task_proxy_func' in the pseudo code above)
9754 // "@.omp_target_task_proxy_func' is generated by
9755 // emitTargetTaskProxyFunction.
9756 if (OutlinedFnID && DeviceID)
9757 return OMPBuilder.emitKernelLaunch(Loc: Builder, OutlinedFnID,
9758 EmitTargetCallFallbackCB, Args&: KArgs,
9759 DeviceID, RTLoc, AllocaIP: TargetTaskAllocaIP);
9760
9761 // We only need to do the outlining if `DeviceID` is set to avoid calling
9762 // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
9763 // generating the `else` branch of an `if` clause.
9764 //
9765 // When OutlinedFnID is set to nullptr, then it's not an offloading call.
9766 // In this case, we execute the host implementation directly.
9767 return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
9768 }());
9769
9770 OMPBuilder.Builder.restoreIP(IP: AfterIP);
9771 return Error::success();
9772 };
9773
9774 auto &&EmitTargetCallElse =
9775 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
9776 OpenMPIRBuilder::InsertPointTy CodeGenIP,
9777 ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
9778 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
9779 // produce any.
9780 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
9781 if (RequiresOuterTargetTask) {
9782 // Arguments that are intended to be directly forwarded to an
9783 // emitKernelLaunch call are pased as nullptr, since
9784 // OutlinedFnID=nullptr results in that call not being done.
9785 OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
9786 return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
9787 /*RTLoc=*/nullptr, AllocaIP,
9788 Dependencies, RTArgs: EmptyRTArgs, HasNoWait);
9789 }
9790 return EmitTargetCallFallbackCB(Builder.saveIP());
9791 }());
9792
9793 Builder.restoreIP(IP: AfterIP);
9794 return Error::success();
9795 };
9796
9797 auto &&EmitTargetCallThen =
9798 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
9799 OpenMPIRBuilder::InsertPointTy CodeGenIP,
9800 ArrayRef<BasicBlock *> DeallocBlocks) -> Error {
9801 Info.HasNoWait = HasNoWait;
9802 OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
9803
9804 OpenMPIRBuilder::TargetDataRTArgs RTArgs;
9805 if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
9806 AllocaIP, CodeGenIP: Builder.saveIP(), Info, RTArgs, CombinedInfo&: MapInfo, CustomMapperCB,
9807 /*IsNonContiguous=*/true,
9808 /*ForEndCall=*/false))
9809 return Err;
9810
9811 SmallVector<Value *, 3> NumTeamsC;
9812 for (auto [DefaultVal, RuntimeVal] :
9813 zip_equal(t: DefaultAttrs.MaxTeams, u: RuntimeAttrs.MaxTeams))
9814 NumTeamsC.push_back(Elt: RuntimeVal ? RuntimeVal
9815 : Builder.getInt32(C: DefaultVal));
9816
9817 // Calculate number of threads: 0 if no clauses specified, otherwise it is
9818 // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
9819 auto InitMaxThreadsClause = [&Builder](Value *Clause) {
9820 if (Clause)
9821 Clause = Builder.CreateIntCast(V: Clause, DestTy: Builder.getInt32Ty(),
9822 /*isSigned=*/false);
9823 return Clause;
9824 };
9825 auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
9826 if (Clause)
9827 Result =
9828 Result ? Builder.CreateSelect(C: Builder.CreateICmpULT(LHS: Result, RHS: Clause),
9829 True: Result, False: Clause)
9830 : Clause;
9831 };
9832
9833 // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
9834 // the NUM_THREADS clause is overriden by THREAD_LIMIT.
9835 SmallVector<Value *, 3> NumThreadsC;
9836 Value *MaxThreadsClause =
9837 RuntimeAttrs.TeamsThreadLimit.size() == 1
9838 ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
9839 : nullptr;
9840
9841 for (auto [TeamsVal, TargetVal] : zip_equal(
9842 t: RuntimeAttrs.TeamsThreadLimit, u: RuntimeAttrs.TargetThreadLimit)) {
9843 Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
9844 Value *NumThreads = InitMaxThreadsClause(TargetVal);
9845
9846 CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
9847 CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
9848
9849 NumThreadsC.push_back(Elt: NumThreads ? NumThreads : Builder.getInt32(C: 0));
9850 }
9851
9852 unsigned NumTargetItems = Info.NumberOfPtrs;
9853 uint32_t SrcLocStrSize;
9854 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
9855 Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
9856 LocFlags: llvm::omp::IdentFlag(0), Reserve2Flags: 0);
9857
9858 Value *TripCount = RuntimeAttrs.LoopTripCount
9859 ? Builder.CreateIntCast(V: RuntimeAttrs.LoopTripCount,
9860 DestTy: Builder.getInt64Ty(),
9861 /*isSigned=*/false)
9862 : Builder.getInt64(C: 0);
9863
9864 // Request zero groupprivate bytes by default.
9865 if (!DynCGroupMem)
9866 DynCGroupMem = Builder.getInt32(C: 0);
9867
9868 KArgs = OpenMPIRBuilder::TargetKernelArgs(
9869 NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
9870 HasNoWait, /*StrictBlocksAndThreads=*/false, DynCGroupMemFallback);
9871
9872 // Assume no error was returned because TaskBodyCB and
9873 // EmitTargetCallFallbackCB don't produce any.
9874 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(ValOrErr: [&]() {
9875 // The presence of certain clauses on the target directive require the
9876 // explicit generation of the target task.
9877 if (RequiresOuterTargetTask)
9878 return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID: RuntimeAttrs.DeviceID,
9879 RTLoc, AllocaIP, Dependencies,
9880 RTArgs: KArgs.RTArgs, HasNoWait: Info.HasNoWait);
9881
9882 return OMPBuilder.emitKernelLaunch(
9883 Loc: Builder, OutlinedFnID, EmitTargetCallFallbackCB, Args&: KArgs,
9884 DeviceID: RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
9885 }());
9886
9887 Builder.restoreIP(IP: AfterIP);
9888 return Error::success();
9889 };
9890
9891 // If we don't have an ID for the target region, it means an offload entry
9892 // wasn't created. In this case we just run the host fallback directly and
9893 // ignore any potential 'if' clauses.
9894 if (!OutlinedFnID) {
9895 cantFail(Err: EmitTargetCallElse(AllocaIP, Builder.saveIP(), DeallocBlocks));
9896 return;
9897 }
9898
9899 // If there's no 'if' clause, only generate the kernel launch code path.
9900 if (!IfCond) {
9901 cantFail(Err: EmitTargetCallThen(AllocaIP, Builder.saveIP(), DeallocBlocks));
9902 return;
9903 }
9904
9905 cantFail(Err: OMPBuilder.emitIfClause(Cond: IfCond, ThenGen: EmitTargetCallThen,
9906 ElseGen: EmitTargetCallElse, AllocaIP));
9907}
9908
9909OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
9910 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
9911 InsertPointTy CodeGenIP, ArrayRef<BasicBlock *> DeallocBlocks,
9912 TargetDataInfo &Info, TargetRegionEntryInfo &EntryInfo,
9913 const TargetKernelDefaultAttrs &DefaultAttrs,
9914 const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
9915 SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
9916 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
9917 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
9918 CustomMapperCallbackTy CustomMapperCB, const DependenciesInfo &Dependencies,
9919 bool HasNowait, Value *DynCGroupMem,
9920 OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
9921
9922 if (!updateToLocation(Loc))
9923 return InsertPointTy();
9924
9925 Builder.restoreIP(IP: CodeGenIP);
9926
9927 Function *OutlinedFn;
9928 Constant *OutlinedFnID = nullptr;
9929 // The target region is outlined into its own function. The LLVM IR for
9930 // the target region itself is generated using the callbacks CBFunc
9931 // and ArgAccessorFuncCB
9932 if (Error Err = emitTargetOutlinedFunction(
9933 OMPBuilder&: *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
9934 OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
9935 return Err;
9936
9937 // If we are not on the target device, then we need to generate code
9938 // to make a remote call (offload) to the previously outlined function
9939 // that represents the target region. Do that now.
9940 if (!Config.isTargetDevice())
9941 emitTargetCall(OMPBuilder&: *this, Builder, AllocaIP, DeallocBlocks, Info, DefaultAttrs,
9942 RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Args&: Inputs,
9943 GenMapInfoCB, CustomMapperCB, Dependencies, HasNoWait: HasNowait,
9944 DynCGroupMem, DynCGroupMemFallback);
9945 return Builder.saveIP();
9946}
9947
9948std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
9949 StringRef FirstSeparator,
9950 StringRef Separator) {
9951 SmallString<128> Buffer;
9952 llvm::raw_svector_ostream OS(Buffer);
9953 StringRef Sep = FirstSeparator;
9954 for (StringRef Part : Parts) {
9955 OS << Sep << Part;
9956 Sep = Separator;
9957 }
9958 return OS.str().str();
9959}
9960
9961std::string
9962OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
9963 return OpenMPIRBuilder::getNameWithSeparators(Parts, FirstSeparator: Config.firstSeparator(),
9964 Separator: Config.separator());
9965}
9966
9967GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
9968 Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
9969 auto &Elem = *InternalVars.try_emplace(Key: Name, Args: nullptr).first;
9970 if (Elem.second) {
9971 assert(Elem.second->getValueType() == Ty &&
9972 "OMP internal variable has different type than requested");
9973 } else {
9974 // TODO: investigate the appropriate linkage type used for the global
9975 // variable for possibly changing that to internal or private, or maybe
9976 // create different versions of the function for different OMP internal
9977 // variables.
9978 const DataLayout &DL = M.getDataLayout();
9979 // TODO: Investigate why AMDGPU expects AS 0 for globals even though the
9980 // default global AS is 1.
9981 // See double-target-call-with-declare-target.f90 and
9982 // declare-target-vars-in-target-region.f90 libomptarget
9983 // tests.
9984 unsigned AddressSpaceVal = AddressSpace ? *AddressSpace
9985 : M.getTargetTriple().isAMDGPU()
9986 ? 0
9987 : DL.getDefaultGlobalsAddressSpace();
9988 auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
9989 ? GlobalValue::InternalLinkage
9990 : GlobalValue::CommonLinkage;
9991 auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
9992 Constant::getNullValue(Ty), Elem.first(),
9993 /*InsertBefore=*/nullptr,
9994 GlobalValue::NotThreadLocal, AddressSpaceVal);
9995 const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
9996 const llvm::Align PtrAlign = DL.getPointerABIAlignment(AS: AddressSpaceVal);
9997 GV->setAlignment(std::max(a: TypeAlign, b: PtrAlign));
9998 Elem.second = GV;
9999 }
10000
10001 return Elem.second;
10002}
10003
10004Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
10005 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
10006 std::string Name = getNameWithSeparators(Parts: {Prefix, "var"}, FirstSeparator: ".", Separator: ".");
10007 return getOrCreateInternalVariable(Ty: KmpCriticalNameTy, Name);
10008}
10009
10010Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
10011 LLVMContext &Ctx = Builder.getContext();
10012 Value *Null =
10013 Constant::getNullValue(Ty: PointerType::getUnqual(C&: BasePtr->getContext()));
10014 Value *SizeGep =
10015 Builder.CreateGEP(Ty: BasePtr->getType(), Ptr: Null, IdxList: Builder.getInt32(C: 1));
10016 Value *SizePtrToInt = Builder.CreatePtrToInt(V: SizeGep, DestTy: Type::getInt64Ty(C&: Ctx));
10017 return SizePtrToInt;
10018}
10019
10020GlobalVariable *
10021OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
10022 std::string VarName) {
10023 llvm::Constant *MaptypesArrayInit =
10024 llvm::ConstantDataArray::get(Context&: M.getContext(), Elts&: Mappings);
10025 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
10026 M, MaptypesArrayInit->getType(),
10027 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
10028 VarName);
10029 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
10030 return MaptypesArrayGlobal;
10031}
10032
10033void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
10034 InsertPointTy AllocaIP,
10035 unsigned NumOperands,
10036 struct MapperAllocas &MapperAllocas) {
10037 if (!updateToLocation(Loc))
10038 return;
10039
10040 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
10041 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
10042 Builder.restoreIP(IP: AllocaIP);
10043 AllocaInst *ArgsBase = Builder.CreateAlloca(
10044 Ty: ArrI8PtrTy, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
10045 AllocaInst *Args = Builder.CreateAlloca(Ty: ArrI8PtrTy, /* ArraySize = */ nullptr,
10046 Name: ".offload_ptrs");
10047 AllocaInst *ArgSizes = Builder.CreateAlloca(
10048 Ty: ArrI64Ty, /* ArraySize = */ nullptr, Name: ".offload_sizes");
10049 updateToLocation(Loc);
10050 MapperAllocas.ArgsBase = ArgsBase;
10051 MapperAllocas.Args = Args;
10052 MapperAllocas.ArgSizes = ArgSizes;
10053}
10054
10055void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
10056 Function *MapperFunc, Value *SrcLocInfo,
10057 Value *MaptypesArg, Value *MapnamesArg,
10058 struct MapperAllocas &MapperAllocas,
10059 int64_t DeviceID, unsigned NumOperands) {
10060 if (!updateToLocation(Loc))
10061 return;
10062
10063 auto *ArrI8PtrTy = ArrayType::get(ElementType: Int8Ptr, NumElements: NumOperands);
10064 auto *ArrI64Ty = ArrayType::get(ElementType: Int64, NumElements: NumOperands);
10065 Value *ArgsBaseGEP =
10066 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.ArgsBase,
10067 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
10068 Value *ArgsGEP =
10069 Builder.CreateInBoundsGEP(Ty: ArrI8PtrTy, Ptr: MapperAllocas.Args,
10070 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
10071 Value *ArgSizesGEP =
10072 Builder.CreateInBoundsGEP(Ty: ArrI64Ty, Ptr: MapperAllocas.ArgSizes,
10073 IdxList: {Builder.getInt32(C: 0), Builder.getInt32(C: 0)});
10074 Value *NullPtr =
10075 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Int8Ptr->getContext()));
10076 createRuntimeFunctionCall(Callee: MapperFunc, Args: {SrcLocInfo, Builder.getInt64(C: DeviceID),
10077 Builder.getInt32(C: NumOperands),
10078 ArgsBaseGEP, ArgsGEP, ArgSizesGEP,
10079 MaptypesArg, MapnamesArg, NullPtr});
10080}
10081
10082void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
10083 TargetDataRTArgs &RTArgs,
10084 TargetDataInfo &Info,
10085 bool ForEndCall) {
10086 assert((!ForEndCall || Info.separateBeginEndCalls()) &&
10087 "expected region end call to runtime only when end call is separate");
10088 auto UnqualPtrTy = PointerType::getUnqual(C&: M.getContext());
10089 auto VoidPtrTy = UnqualPtrTy;
10090 auto VoidPtrPtrTy = UnqualPtrTy;
10091 auto Int64Ty = Type::getInt64Ty(C&: M.getContext());
10092 auto Int64PtrTy = UnqualPtrTy;
10093
10094 if (!Info.NumberOfPtrs) {
10095 RTArgs.BasePointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
10096 RTArgs.PointersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
10097 RTArgs.SizesArray = ConstantPointerNull::get(T: Int64PtrTy);
10098 RTArgs.MapTypesArray = ConstantPointerNull::get(T: Int64PtrTy);
10099 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
10100 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
10101 return;
10102 }
10103
10104 RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
10105 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs),
10106 Ptr: Info.RTArgs.BasePointersArray,
10107 /*Idx0=*/0, /*Idx1=*/0);
10108 RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
10109 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray,
10110 /*Idx0=*/0,
10111 /*Idx1=*/0);
10112 RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
10113 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
10114 /*Idx0=*/0, /*Idx1=*/0);
10115 RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
10116 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs),
10117 Ptr: ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
10118 : Info.RTArgs.MapTypesArray,
10119 /*Idx0=*/0,
10120 /*Idx1=*/0);
10121
10122 // Only emit the mapper information arrays if debug information is
10123 // requested.
10124 if (!Info.EmitDebug)
10125 RTArgs.MapNamesArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
10126 else
10127 RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
10128 Ty: ArrayType::get(ElementType: VoidPtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.MapNamesArray,
10129 /*Idx0=*/0,
10130 /*Idx1=*/0);
10131 // If there is no user-defined mapper, set the mapper array to nullptr to
10132 // avoid an unnecessary data privatization
10133 if (!Info.HasMapper)
10134 RTArgs.MappersArray = ConstantPointerNull::get(T: VoidPtrPtrTy);
10135 else
10136 RTArgs.MappersArray =
10137 Builder.CreatePointerCast(V: Info.RTArgs.MappersArray, DestTy: VoidPtrPtrTy);
10138}
10139
10140void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
10141 InsertPointTy CodeGenIP,
10142 MapInfosTy &CombinedInfo,
10143 TargetDataInfo &Info) {
10144 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
10145 CombinedInfo.NonContigInfo;
10146
10147 // Build an array of struct descriptor_dim and then assign it to
10148 // offload_args.
10149 //
10150 // struct descriptor_dim {
10151 // uint64_t offset;
10152 // uint64_t count;
10153 // uint64_t stride
10154 // };
10155 Type *Int64Ty = Builder.getInt64Ty();
10156 StructType *DimTy = StructType::create(
10157 Context&: M.getContext(), Elements: ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
10158 Name: "struct.descriptor_dim");
10159
10160 enum { OffsetFD = 0, CountFD, StrideFD };
10161 // We need two index variable here since the size of "Dims" is the same as
10162 // the size of Components, however, the size of offset, count, and stride is
10163 // equal to the size of base declaration that is non-contiguous.
10164 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
10165 // Skip emitting ir if dimension size is 1 since it cannot be
10166 // non-contiguous.
10167 if (NonContigInfo.Dims[I] == 1)
10168 continue;
10169 Builder.restoreIP(IP: AllocaIP);
10170 ArrayType *ArrayTy = ArrayType::get(ElementType: DimTy, NumElements: NonContigInfo.Dims[I]);
10171 AllocaInst *DimsAddr =
10172 Builder.CreateAlloca(Ty: ArrayTy, /* ArraySize = */ nullptr, Name: "dims");
10173 Builder.restoreIP(IP: CodeGenIP);
10174 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
10175 unsigned RevIdx = EE - II - 1;
10176 Value *DimsLVal = Builder.CreateInBoundsGEP(
10177 Ty: ArrayTy, Ptr: DimsAddr, IdxList: {Builder.getInt64(C: 0), Builder.getInt64(C: II)});
10178 // Offset
10179 Value *OffsetLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: OffsetFD);
10180 Builder.CreateAlignedStore(
10181 Val: NonContigInfo.Offsets[L][RevIdx], Ptr: OffsetLVal,
10182 Align: M.getDataLayout().getPrefTypeAlign(Ty: OffsetLVal->getType()));
10183 // Count
10184 Value *CountLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: CountFD);
10185 Builder.CreateAlignedStore(
10186 Val: NonContigInfo.Counts[L][RevIdx], Ptr: CountLVal,
10187 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
10188 // Stride
10189 Value *StrideLVal = Builder.CreateStructGEP(Ty: DimTy, Ptr: DimsLVal, Idx: StrideFD);
10190 Builder.CreateAlignedStore(
10191 Val: NonContigInfo.Strides[L][RevIdx], Ptr: StrideLVal,
10192 Align: M.getDataLayout().getPrefTypeAlign(Ty: CountLVal->getType()));
10193 }
10194 // args[I] = &dims
10195 Builder.restoreIP(IP: CodeGenIP);
10196 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
10197 V: DimsAddr, DestTy: Builder.getPtrTy());
10198 Value *P = Builder.CreateConstInBoundsGEP2_32(
10199 Ty: ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs),
10200 Ptr: Info.RTArgs.PointersArray, Idx0: 0, Idx1: I);
10201 Builder.CreateAlignedStore(
10202 Val: DAddr, Ptr: P, Align: M.getDataLayout().getPrefTypeAlign(Ty: Builder.getPtrTy()));
10203 ++L;
10204 }
10205}
10206
10207void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
10208 Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
10209 Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
10210 BasicBlock *ExitBB, bool IsInit) {
10211 StringRef Prefix = IsInit ? ".init" : ".del";
10212
10213 // Evaluate if this is an array section.
10214 BasicBlock *BodyBB = BasicBlock::Create(
10215 Context&: M.getContext(), Name: createPlatformSpecificName(Parts: {"omp.array", Prefix}));
10216 Value *IsArray =
10217 Builder.CreateICmpSGT(LHS: Size, RHS: Builder.getInt64(C: 1), Name: "omp.arrayinit.isarray");
10218 Value *DeleteBit = Builder.CreateAnd(
10219 LHS: MapType,
10220 RHS: Builder.getInt64(
10221 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10222 OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
10223 Value *DeleteCond;
10224 Value *Cond;
10225 if (IsInit) {
10226 // base != begin?
10227 Value *BaseIsBegin = Builder.CreateICmpNE(LHS: Base, RHS: Begin);
10228 Cond = Builder.CreateOr(LHS: IsArray, RHS: BaseIsBegin);
10229 DeleteCond = Builder.CreateIsNull(
10230 Arg: DeleteBit,
10231 Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
10232 } else {
10233 Cond = IsArray;
10234 DeleteCond = Builder.CreateIsNotNull(
10235 Arg: DeleteBit,
10236 Name: createPlatformSpecificName(Parts: {"omp.array", Prefix, ".delete"}));
10237 }
10238 Cond = Builder.CreateAnd(LHS: Cond, RHS: DeleteCond);
10239 Builder.CreateCondBr(Cond, True: BodyBB, False: ExitBB);
10240
10241 emitBlock(BB: BodyBB, CurFn: MapperFn);
10242 // Get the array size by multiplying element size and element number (i.e., \p
10243 // Size).
10244 Value *ArraySize = Builder.CreateNUWMul(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
10245 // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
10246 // memory allocation/deletion purpose only.
10247 Value *MapTypeArg = Builder.CreateAnd(
10248 LHS: MapType,
10249 RHS: Builder.getInt64(
10250 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10251 OpenMPOffloadMappingFlags::OMP_MAP_TO |
10252 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
10253 MapTypeArg = Builder.CreateOr(
10254 LHS: MapTypeArg,
10255 RHS: Builder.getInt64(
10256 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10257 OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));
10258
10259 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
10260 // data structure.
10261 Value *OffloadingArgs[] = {MapperHandle, Base, Begin,
10262 ArraySize, MapTypeArg, MapName};
10263 createRuntimeFunctionCall(
10264 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
10265 Args: OffloadingArgs);
10266}
10267
10268Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
10269 function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
10270 llvm::Value *BeginArg)>
10271 GenMapInfoCB,
10272 Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB,
10273 bool PreserveMemberOfFlags) {
10274 SmallVector<Type *> Params;
10275 Params.emplace_back(Args: Builder.getPtrTy());
10276 Params.emplace_back(Args: Builder.getPtrTy());
10277 Params.emplace_back(Args: Builder.getPtrTy());
10278 Params.emplace_back(Args: Builder.getInt64Ty());
10279 Params.emplace_back(Args: Builder.getInt64Ty());
10280 Params.emplace_back(Args: Builder.getPtrTy());
10281
10282 auto *FnTy =
10283 FunctionType::get(Result: Builder.getVoidTy(), Params, /* IsVarArg */ isVarArg: false);
10284
10285 SmallString<64> TyStr;
10286 raw_svector_ostream Out(TyStr);
10287 Function *MapperFn =
10288 Function::Create(Ty: FnTy, Linkage: GlobalValue::InternalLinkage, N: FuncName, M);
10289 MapperFn->addFnAttr(Kind: Attribute::NoInline);
10290 MapperFn->addFnAttr(Kind: Attribute::NoUnwind);
10291 MapperFn->addParamAttr(ArgNo: 0, Kind: Attribute::NoUndef);
10292 MapperFn->addParamAttr(ArgNo: 1, Kind: Attribute::NoUndef);
10293 MapperFn->addParamAttr(ArgNo: 2, Kind: Attribute::NoUndef);
10294 MapperFn->addParamAttr(ArgNo: 3, Kind: Attribute::NoUndef);
10295 MapperFn->addParamAttr(ArgNo: 4, Kind: Attribute::NoUndef);
10296 MapperFn->addParamAttr(ArgNo: 5, Kind: Attribute::NoUndef);
10297
10298 // Start the mapper function code generation.
10299 BasicBlock *EntryBB = BasicBlock::Create(Context&: M.getContext(), Name: "entry", Parent: MapperFn);
10300 auto SavedIP = Builder.saveIP();
10301 Builder.SetInsertPoint(EntryBB);
10302
10303 Value *MapperHandle = MapperFn->getArg(i: 0);
10304 Value *BaseIn = MapperFn->getArg(i: 1);
10305 Value *BeginIn = MapperFn->getArg(i: 2);
10306 Value *Size = MapperFn->getArg(i: 3);
10307 Value *MapType = MapperFn->getArg(i: 4);
10308 Value *MapName = MapperFn->getArg(i: 5);
10309
10310 // Compute the starting and end addresses of array elements.
10311 // Prepare common arguments for array initiation and deletion.
10312 // Convert the size in bytes into the number of array elements.
10313 TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(Ty: ElemTy);
10314 Size = Builder.CreateExactUDiv(LHS: Size, RHS: Builder.getInt64(C: ElementSize));
10315 Value *PtrBegin = BeginIn;
10316 Value *PtrEnd = Builder.CreateGEP(Ty: ElemTy, Ptr: PtrBegin, IdxList: Size);
10317
10318 // Emit array initiation if this is an array section and \p MapType indicates
10319 // that memory allocation is required.
10320 BasicBlock *HeadBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.head");
10321 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
10322 MapType, MapName, ElementSize, ExitBB: HeadBB,
10323 /*IsInit=*/true);
10324
10325 // Emit a for loop to iterate through SizeArg of elements and map all of them.
10326
10327 // Emit the loop header block.
10328 emitBlock(BB: HeadBB, CurFn: MapperFn);
10329 BasicBlock *BodyBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.body");
10330 BasicBlock *DoneBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.done");
10331 // Evaluate whether the initial condition is satisfied.
10332 Value *IsEmpty =
10333 Builder.CreateICmpEQ(LHS: PtrBegin, RHS: PtrEnd, Name: "omp.arraymap.isempty");
10334 Builder.CreateCondBr(Cond: IsEmpty, True: DoneBB, False: BodyBB);
10335
10336 // Emit the loop body block.
10337 emitBlock(BB: BodyBB, CurFn: MapperFn);
10338 BasicBlock *LastBB = BodyBB;
10339 PHINode *PtrPHI =
10340 Builder.CreatePHI(Ty: PtrBegin->getType(), NumReservedValues: 2, Name: "omp.arraymap.ptrcurrent");
10341 PtrPHI->addIncoming(V: PtrBegin, BB: HeadBB);
10342
10343 // Get map clause information. Fill up the arrays with all mapped variables.
10344 MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
10345 if (!Info)
10346 return Info.takeError();
10347
10348 // Call the runtime API __tgt_mapper_num_components to get the number of
10349 // pre-existing components.
10350 Value *OffloadingArgs[] = {MapperHandle};
10351 Value *PreviousSize = createRuntimeFunctionCall(
10352 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_mapper_num_components),
10353 Args: OffloadingArgs);
10354 Value *ShiftedPreviousSize =
10355 Builder.CreateShl(LHS: PreviousSize, RHS: Builder.getInt64(C: getFlagMemberOffset()));
10356
10357 // Fill up the runtime mapper handle for all components.
10358 for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
10359 Value *CurBaseArg = Info->BasePointers[I];
10360 Value *CurBeginArg = Info->Pointers[I];
10361 Value *CurSizeArg = Info->Sizes[I];
10362 Value *CurNameArg = Info->Names.size()
10363 ? Info->Names[I]
10364 : Constant::getNullValue(Ty: Builder.getPtrTy());
10365
10366 // Extract the MEMBER_OF field from the map type.
10367 Value *OriMapType = Builder.getInt64(
10368 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10369 Info->Types[I]));
10370 Value *MemberMapType;
10371 if (PreserveMemberOfFlags) {
10372 constexpr uint64_t MemberOfMask =
10373 static_cast<uint64_t>(OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
10374 uint64_t OrigFlags =
10375 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10376 Info->Types[I]);
10377 bool HasMemberOf = (OrigFlags & MemberOfMask) != 0;
10378 if (HasMemberOf)
10379 MemberMapType = Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);
10380 else
10381 MemberMapType = OriMapType;
10382 } else {
10383 MemberMapType = Builder.CreateNUWAdd(LHS: OriMapType, RHS: ShiftedPreviousSize);
10384 }
10385
10386 // Combine the map type inherited from user-defined mapper with that
10387 // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
10388 // bits of the \a MapType, which is the input argument of the mapper
10389 // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
10390 // bits of MemberMapType.
10391 // [OpenMP 5.0], 1.2.6. map-type decay.
10392 // | alloc | to | from | tofrom | release | delete
10393 // ----------------------------------------------------------
10394 // alloc | alloc | alloc | alloc | alloc | release | delete
10395 // to | alloc | to | alloc | to | release | delete
10396 // from | alloc | alloc | from | from | release | delete
10397 // tofrom | alloc | to | from | tofrom | release | delete
10398 Value *LeftToFrom = Builder.CreateAnd(
10399 LHS: MapType,
10400 RHS: Builder.getInt64(
10401 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10402 OpenMPOffloadMappingFlags::OMP_MAP_TO |
10403 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
10404 BasicBlock *AllocBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc");
10405 BasicBlock *AllocElseBB =
10406 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.alloc.else");
10407 BasicBlock *ToBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to");
10408 BasicBlock *ToElseBB =
10409 BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.to.else");
10410 BasicBlock *FromBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.from");
10411 BasicBlock *EndBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.type.end");
10412 Value *IsAlloc = Builder.CreateIsNull(Arg: LeftToFrom);
10413 Builder.CreateCondBr(Cond: IsAlloc, True: AllocBB, False: AllocElseBB);
10414 // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
10415 emitBlock(BB: AllocBB, CurFn: MapperFn);
10416 Value *AllocMapType = Builder.CreateAnd(
10417 LHS: MemberMapType,
10418 RHS: Builder.getInt64(
10419 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10420 OpenMPOffloadMappingFlags::OMP_MAP_TO |
10421 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
10422 Builder.CreateBr(Dest: EndBB);
10423 emitBlock(BB: AllocElseBB, CurFn: MapperFn);
10424 Value *IsTo = Builder.CreateICmpEQ(
10425 LHS: LeftToFrom,
10426 RHS: Builder.getInt64(
10427 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10428 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
10429 Builder.CreateCondBr(Cond: IsTo, True: ToBB, False: ToElseBB);
10430 // In case of to, clear OMP_MAP_FROM.
10431 emitBlock(BB: ToBB, CurFn: MapperFn);
10432 Value *ToMapType = Builder.CreateAnd(
10433 LHS: MemberMapType,
10434 RHS: Builder.getInt64(
10435 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10436 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
10437 Builder.CreateBr(Dest: EndBB);
10438 emitBlock(BB: ToElseBB, CurFn: MapperFn);
10439 Value *IsFrom = Builder.CreateICmpEQ(
10440 LHS: LeftToFrom,
10441 RHS: Builder.getInt64(
10442 C: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10443 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
10444 Builder.CreateCondBr(Cond: IsFrom, True: FromBB, False: EndBB);
10445 // In case of from, clear OMP_MAP_TO.
10446 emitBlock(BB: FromBB, CurFn: MapperFn);
10447 Value *FromMapType = Builder.CreateAnd(
10448 LHS: MemberMapType,
10449 RHS: Builder.getInt64(
10450 C: ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10451 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
10452 // In case of tofrom, do nothing.
10453 emitBlock(BB: EndBB, CurFn: MapperFn);
10454 LastBB = EndBB;
10455 PHINode *CurMapType =
10456 Builder.CreatePHI(Ty: Builder.getInt64Ty(), NumReservedValues: 4, Name: "omp.maptype");
10457 CurMapType->addIncoming(V: AllocMapType, BB: AllocBB);
10458 CurMapType->addIncoming(V: ToMapType, BB: ToBB);
10459 CurMapType->addIncoming(V: FromMapType, BB: FromBB);
10460 CurMapType->addIncoming(V: MemberMapType, BB: ToElseBB);
10461
10462 Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
10463 CurSizeArg, CurMapType, CurNameArg};
10464
10465 auto ChildMapperFn = CustomMapperCB(I);
10466 if (!ChildMapperFn)
10467 return ChildMapperFn.takeError();
10468 if (*ChildMapperFn) {
10469 // Call the corresponding mapper function.
10470 createRuntimeFunctionCall(Callee: *ChildMapperFn, Args: OffloadingArgs)
10471 ->setDoesNotThrow();
10472 } else {
10473 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
10474 // data structure.
10475 createRuntimeFunctionCall(
10476 Callee: getOrCreateRuntimeFunction(M, FnID: OMPRTL___tgt_push_mapper_component),
10477 Args: OffloadingArgs);
10478 }
10479 }
10480
10481 // Update the pointer to point to the next element that needs to be mapped,
10482 // and check whether we have mapped all elements.
10483 Value *PtrNext = Builder.CreateConstGEP1_32(Ty: ElemTy, Ptr: PtrPHI, /*Idx0=*/1,
10484 Name: "omp.arraymap.next");
10485 PtrPHI->addIncoming(V: PtrNext, BB: LastBB);
10486 Value *IsDone = Builder.CreateICmpEQ(LHS: PtrNext, RHS: PtrEnd, Name: "omp.arraymap.isdone");
10487 BasicBlock *ExitBB = BasicBlock::Create(Context&: M.getContext(), Name: "omp.arraymap.exit");
10488 Builder.CreateCondBr(Cond: IsDone, True: ExitBB, False: BodyBB);
10489
10490 emitBlock(BB: ExitBB, CurFn: MapperFn);
10491 // Emit array deletion if this is an array section and \p MapType indicates
10492 // that deletion is required.
10493 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, Base: BaseIn, Begin: BeginIn, Size,
10494 MapType, MapName, ElementSize, ExitBB: DoneBB,
10495 /*IsInit=*/false);
10496
10497 // Emit the function exit block.
10498 emitBlock(BB: DoneBB, CurFn: MapperFn, /*IsFinished=*/true);
10499
10500 Builder.CreateRetVoid();
10501 Builder.restoreIP(IP: SavedIP);
10502 return MapperFn;
10503}
10504
10505Error OpenMPIRBuilder::emitOffloadingArrays(
10506 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
10507 TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
10508 bool IsNonContiguous,
10509 function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
10510
10511 // Reset the array information.
10512 Info.clearArrayInfo();
10513 Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
10514
10515 if (Info.NumberOfPtrs == 0)
10516 return Error::success();
10517
10518 Builder.restoreIP(IP: AllocaIP);
10519 // Detect if we have any capture size requiring runtime evaluation of the
10520 // size so that a constant array could be eventually used.
10521 ArrayType *PointerArrayType =
10522 ArrayType::get(ElementType: Builder.getPtrTy(), NumElements: Info.NumberOfPtrs);
10523
10524 Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
10525 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_baseptrs");
10526
10527 Info.RTArgs.PointersArray = Builder.CreateAlloca(
10528 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_ptrs");
10529 AllocaInst *MappersArray = Builder.CreateAlloca(
10530 Ty: PointerArrayType, /* ArraySize = */ nullptr, Name: ".offload_mappers");
10531 Info.RTArgs.MappersArray = MappersArray;
10532
10533 // If we don't have any VLA types or other types that require runtime
10534 // evaluation, we can use a constant array for the map sizes, otherwise we
10535 // need to fill up the arrays as we do for the pointers.
10536 Type *Int64Ty = Builder.getInt64Ty();
10537 SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
10538 ConstantInt::get(Ty: Int64Ty, V: 0));
10539 SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
10540 for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
10541 bool IsNonContigEntry =
10542 IsNonContiguous &&
10543 (static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10544 CombinedInfo.Types[I] &
10545 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG) != 0);
10546 // For NON_CONTIG entries, ArgSizes stores the dimension count (number of
10547 // descriptor_dim records), not the byte size.
10548 if (IsNonContigEntry) {
10549 assert(I < CombinedInfo.NonContigInfo.Dims.size() &&
10550 "Index must be in-bounds for NON_CONTIG Dims array");
10551 const uint64_t DimCount = CombinedInfo.NonContigInfo.Dims[I];
10552 assert(DimCount > 0 && "NON_CONTIG DimCount must be > 0");
10553 ConstSizes[I] = ConstantInt::get(Ty: Int64Ty, V: DimCount);
10554 continue;
10555 }
10556 if (auto *CI = dyn_cast<Constant>(Val: CombinedInfo.Sizes[I])) {
10557 if (!isa<ConstantExpr>(Val: CI) && !isa<GlobalValue>(Val: CI)) {
10558 ConstSizes[I] = CI;
10559 continue;
10560 }
10561 }
10562 RuntimeSizes.set(I);
10563 }
10564
10565 if (RuntimeSizes.all()) {
10566 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
10567 Info.RTArgs.SizesArray = Builder.CreateAlloca(
10568 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
10569 restoreIPandDebugLoc(Builder, IP: CodeGenIP);
10570 } else {
10571 auto *SizesArrayInit = ConstantArray::get(
10572 T: ArrayType::get(ElementType: Int64Ty, NumElements: ConstSizes.size()), V: ConstSizes);
10573 std::string Name = createPlatformSpecificName(Parts: {"offload_sizes"});
10574 auto *SizesArrayGbl =
10575 new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
10576 GlobalValue::PrivateLinkage, SizesArrayInit, Name);
10577 SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
10578
10579 if (!RuntimeSizes.any()) {
10580 Info.RTArgs.SizesArray = SizesArrayGbl;
10581 } else {
10582 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
10583 Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth: 64);
10584 ArrayType *SizeArrayType = ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs);
10585 AllocaInst *Buffer = Builder.CreateAlloca(
10586 Ty: SizeArrayType, /* ArraySize = */ nullptr, Name: ".offload_sizes");
10587 Buffer->setAlignment(OffloadSizeAlign);
10588 restoreIPandDebugLoc(Builder, IP: CodeGenIP);
10589 Builder.CreateMemCpy(
10590 Dst: Buffer, DstAlign: M.getDataLayout().getPrefTypeAlign(Ty: Buffer->getType()),
10591 Src: SizesArrayGbl, SrcAlign: OffloadSizeAlign,
10592 Size: Builder.getIntN(
10593 N: IndexSize,
10594 C: Buffer->getAllocationSize(DL: M.getDataLayout())->getFixedValue()));
10595
10596 Info.RTArgs.SizesArray = Buffer;
10597 }
10598 restoreIPandDebugLoc(Builder, IP: CodeGenIP);
10599 }
10600
10601 // The map types are always constant so we don't need to generate code to
10602 // fill arrays. Instead, we create an array constant.
10603 SmallVector<uint64_t, 4> Mapping;
10604 for (auto mapFlag : CombinedInfo.Types)
10605 Mapping.push_back(
10606 Elt: static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10607 mapFlag));
10608 std::string MaptypesName = createPlatformSpecificName(Parts: {"offload_maptypes"});
10609 auto *MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
10610 Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
10611
10612 // The information types are only built if provided.
10613 if (!CombinedInfo.Names.empty()) {
10614 auto *MapNamesArrayGbl = createOffloadMapnames(
10615 Names&: CombinedInfo.Names, VarName: createPlatformSpecificName(Parts: {"offload_mapnames"}));
10616 Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
10617 Info.EmitDebug = true;
10618 } else {
10619 Info.RTArgs.MapNamesArray =
10620 Constant::getNullValue(Ty: PointerType::getUnqual(C&: Builder.getContext()));
10621 Info.EmitDebug = false;
10622 }
10623
10624 // If there's a present map type modifier, it must not be applied to the end
10625 // of a region, so generate a separate map type array in that case.
10626 if (Info.separateBeginEndCalls()) {
10627 bool EndMapTypesDiffer = false;
10628 for (uint64_t &Type : Mapping) {
10629 if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10630 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
10631 Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
10632 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
10633 EndMapTypesDiffer = true;
10634 }
10635 }
10636 if (EndMapTypesDiffer) {
10637 MapTypesArrayGbl = createOffloadMaptypes(Mappings&: Mapping, VarName: MaptypesName);
10638 Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
10639 }
10640 }
10641
10642 PointerType *PtrTy = Builder.getPtrTy();
10643 for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
10644 Value *BPVal = CombinedInfo.BasePointers[I];
10645 Value *BP = Builder.CreateConstInBoundsGEP2_32(
10646 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.BasePointersArray,
10647 Idx0: 0, Idx1: I);
10648 Builder.CreateAlignedStore(Val: BPVal, Ptr: BP,
10649 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
10650
10651 if (Info.requiresDevicePointerInfo()) {
10652 if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
10653 CodeGenIP = Builder.saveIP();
10654 Builder.restoreIP(IP: AllocaIP);
10655 Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(Ty: PtrTy)};
10656 Builder.restoreIP(IP: CodeGenIP);
10657 if (DeviceAddrCB)
10658 DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
10659 } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
10660 Info.DevicePtrInfoMap[BPVal] = {BP, BP};
10661 if (DeviceAddrCB)
10662 DeviceAddrCB(I, BP);
10663 }
10664 }
10665
10666 Value *PVal = CombinedInfo.Pointers[I];
10667 Value *P = Builder.CreateConstInBoundsGEP2_32(
10668 Ty: ArrayType::get(ElementType: PtrTy, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.PointersArray, Idx0: 0,
10669 Idx1: I);
10670 // TODO: Check alignment correct.
10671 Builder.CreateAlignedStore(Val: PVal, Ptr: P,
10672 Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
10673
10674 if (RuntimeSizes.test(Idx: I)) {
10675 Value *S = Builder.CreateConstInBoundsGEP2_32(
10676 Ty: ArrayType::get(ElementType: Int64Ty, NumElements: Info.NumberOfPtrs), Ptr: Info.RTArgs.SizesArray,
10677 /*Idx0=*/0,
10678 /*Idx1=*/I);
10679 Builder.CreateAlignedStore(Val: Builder.CreateIntCast(V: CombinedInfo.Sizes[I],
10680 DestTy: Int64Ty,
10681 /*isSigned=*/true),
10682 Ptr: S, Align: M.getDataLayout().getPrefTypeAlign(Ty: PtrTy));
10683 }
10684 // Fill up the mapper array.
10685 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(AS: 0);
10686 Value *MFunc = ConstantPointerNull::get(T: PtrTy);
10687
10688 auto CustomMFunc = CustomMapperCB(I);
10689 if (!CustomMFunc)
10690 return CustomMFunc.takeError();
10691 if (*CustomMFunc)
10692 MFunc = Builder.CreatePointerCast(V: *CustomMFunc, DestTy: PtrTy);
10693
10694 Value *MAddr = Builder.CreateInBoundsGEP(
10695 Ty: PointerArrayType, Ptr: MappersArray,
10696 IdxList: {Builder.getIntN(N: IndexSize, C: 0), Builder.getIntN(N: IndexSize, C: I)});
10697 Builder.CreateAlignedStore(
10698 Val: MFunc, Ptr: MAddr, Align: M.getDataLayout().getPrefTypeAlign(Ty: MAddr->getType()));
10699 }
10700
10701 if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
10702 Info.NumberOfPtrs == 0)
10703 return Error::success();
10704 emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
10705 return Error::success();
10706}
10707
10708void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
10709 BasicBlock *CurBB = Builder.GetInsertBlock();
10710
10711 if (!CurBB || CurBB->hasTerminator()) {
10712 // If there is no insert point or the previous block is already
10713 // terminated, don't touch it.
10714 } else {
10715 // Otherwise, create a fall-through branch.
10716 Builder.CreateBr(Dest: Target);
10717 }
10718
10719 Builder.ClearInsertionPoint();
10720}
10721
10722void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
10723 bool IsFinished) {
10724 BasicBlock *CurBB = Builder.GetInsertBlock();
10725
10726 // Fall out of the current block (if necessary).
10727 emitBranch(Target: BB);
10728
10729 if (IsFinished && BB->use_empty()) {
10730 BB->eraseFromParent();
10731 return;
10732 }
10733
10734 // Place the block after the current block, if possible, or else at
10735 // the end of the function.
10736 if (CurBB && CurBB->getParent())
10737 CurFn->insert(Position: std::next(x: CurBB->getIterator()), BB);
10738 else
10739 CurFn->insert(Position: CurFn->end(), BB);
10740 Builder.SetInsertPoint(BB);
10741}
10742
10743Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
10744 BodyGenCallbackTy ElseGen,
10745 InsertPointTy AllocaIP,
10746 ArrayRef<BasicBlock *> DeallocBlocks) {
10747 // If the condition constant folds and can be elided, try to avoid emitting
10748 // the condition and the dead arm of the if/else.
10749 if (auto *CI = dyn_cast<ConstantInt>(Val: Cond)) {
10750 auto CondConstant = CI->getSExtValue();
10751 if (CondConstant)
10752 return ThenGen(AllocaIP, Builder.saveIP(), DeallocBlocks);
10753
10754 return ElseGen(AllocaIP, Builder.saveIP(), DeallocBlocks);
10755 }
10756
10757 Function *CurFn = Builder.GetInsertBlock()->getParent();
10758
10759 // Otherwise, the condition did not fold, or we couldn't elide it. Just
10760 // emit the conditional branch.
10761 BasicBlock *ThenBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.then");
10762 BasicBlock *ElseBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.else");
10763 BasicBlock *ContBlock = BasicBlock::Create(Context&: M.getContext(), Name: "omp_if.end");
10764 Builder.CreateCondBr(Cond, True: ThenBlock, False: ElseBlock);
10765 // Emit the 'then' code.
10766 emitBlock(BB: ThenBlock, CurFn);
10767 if (Error Err = ThenGen(AllocaIP, Builder.saveIP(), DeallocBlocks))
10768 return Err;
10769 emitBranch(Target: ContBlock);
10770 // Emit the 'else' code if present.
10771 // There is no need to emit line number for unconditional branch.
10772 emitBlock(BB: ElseBlock, CurFn);
10773 if (Error Err = ElseGen(AllocaIP, Builder.saveIP(), DeallocBlocks))
10774 return Err;
10775 // There is no need to emit line number for unconditional branch.
10776 emitBranch(Target: ContBlock);
10777 // Emit the continuation block for code after the if.
10778 emitBlock(BB: ContBlock, CurFn, /*IsFinished=*/true);
10779 return Error::success();
10780}
10781
10782bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
10783 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
10784 assert(!(AO == AtomicOrdering::NotAtomic ||
10785 AO == llvm::AtomicOrdering::Unordered) &&
10786 "Unexpected Atomic Ordering.");
10787
10788 bool Flush = false;
10789 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
10790
10791 switch (AK) {
10792 case Read:
10793 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
10794 AO == AtomicOrdering::SequentiallyConsistent) {
10795 FlushAO = AtomicOrdering::Acquire;
10796 Flush = true;
10797 }
10798 break;
10799 case Write:
10800 case Compare:
10801 case Update:
10802 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
10803 AO == AtomicOrdering::SequentiallyConsistent) {
10804 FlushAO = AtomicOrdering::Release;
10805 Flush = true;
10806 }
10807 break;
10808 case Capture:
10809 switch (AO) {
10810 case AtomicOrdering::Acquire:
10811 FlushAO = AtomicOrdering::Acquire;
10812 Flush = true;
10813 break;
10814 case AtomicOrdering::Release:
10815 FlushAO = AtomicOrdering::Release;
10816 Flush = true;
10817 break;
10818 case AtomicOrdering::AcquireRelease:
10819 case AtomicOrdering::SequentiallyConsistent:
10820 FlushAO = AtomicOrdering::AcquireRelease;
10821 Flush = true;
10822 break;
10823 default:
10824 // do nothing - leave silently.
10825 break;
10826 }
10827 }
10828
10829 if (Flush) {
10830 // Currently Flush RT call still doesn't take memory_ordering, so for when
10831 // that happens, this tries to do the resolution of which atomic ordering
10832 // to use with but issue the flush call
10833 // TODO: pass `FlushAO` after memory ordering support is added
10834 (void)FlushAO;
10835 emitFlush(Loc);
10836 }
10837
10838 // for AO == AtomicOrdering::Monotonic and all other case combinations
10839 // do nothing
10840 return Flush;
10841}
10842
10843OpenMPIRBuilder::InsertPointTy
10844OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
10845 AtomicOpValue &X, AtomicOpValue &V,
10846 AtomicOrdering AO, InsertPointTy AllocaIP) {
10847 if (!updateToLocation(Loc))
10848 return Loc.IP;
10849
10850 assert(X.Var->getType()->isPointerTy() &&
10851 "OMP Atomic expects a pointer to target memory");
10852 Type *XElemTy = X.ElemTy;
10853 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10854 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
10855 "OMP atomic read expected a scalar type");
10856
10857 Value *XRead = nullptr;
10858
10859 if (XElemTy->isIntegerTy()) {
10860 LoadInst *XLD =
10861 Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.read");
10862 XLD->setAtomic(Ordering: AO);
10863 XRead = cast<Value>(Val: XLD);
10864 } else if (XElemTy->isStructTy()) {
10865 // FIXME: Add checks to ensure __atomic_load is emitted iff the
10866 // target does not support `atomicrmw` of the size of the struct
10867 LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
10868 OldVal->setAtomic(Ordering: AO);
10869 const DataLayout &DL = OldVal->getModule()->getDataLayout();
10870 unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
10871 OpenMPIRBuilder::AtomicInfo atomicInfo(
10872 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10873 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
10874 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
10875 XRead = AtomicLoadRes.first;
10876 OldVal->eraseFromParent();
10877 } else {
10878 // We need to perform atomic op as integer
10879 IntegerType *IntCastTy =
10880 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10881 LoadInst *XLoad =
10882 Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var, isVolatile: X.IsVolatile, Name: "omp.atomic.load");
10883 XLoad->setAtomic(Ordering: AO);
10884 if (XElemTy->isFloatingPointTy()) {
10885 XRead = Builder.CreateBitCast(V: XLoad, DestTy: XElemTy, Name: "atomic.flt.cast");
10886 } else {
10887 XRead = Builder.CreateIntToPtr(V: XLoad, DestTy: XElemTy, Name: "atomic.ptr.cast");
10888 }
10889 }
10890 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Read);
10891 Builder.CreateStore(Val: XRead, Ptr: V.Var, isVolatile: V.IsVolatile);
10892 return Builder.saveIP();
10893}
10894
10895OpenMPIRBuilder::InsertPointTy
10896OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
10897 AtomicOpValue &X, Value *Expr,
10898 AtomicOrdering AO, InsertPointTy AllocaIP) {
10899 if (!updateToLocation(Loc))
10900 return Loc.IP;
10901
10902 assert(X.Var->getType()->isPointerTy() &&
10903 "OMP Atomic expects a pointer to target memory");
10904 Type *XElemTy = X.ElemTy;
10905 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10906 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
10907 "OMP atomic write expected a scalar type");
10908
10909 if (XElemTy->isIntegerTy()) {
10910 StoreInst *XSt = Builder.CreateStore(Val: Expr, Ptr: X.Var, isVolatile: X.IsVolatile);
10911 XSt->setAtomic(Ordering: AO);
10912 } else if (XElemTy->isStructTy()) {
10913 LoadInst *OldVal = Builder.CreateLoad(Ty: XElemTy, Ptr: X.Var, Name: "omp.atomic.read");
10914 const DataLayout &DL = OldVal->getModule()->getDataLayout();
10915 unsigned LoadSize = DL.getTypeStoreSize(Ty: XElemTy);
10916 OpenMPIRBuilder::AtomicInfo atomicInfo(
10917 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
10918 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
10919 atomicInfo.EmitAtomicStoreLibcall(AO, Source: Expr);
10920 OldVal->eraseFromParent();
10921 } else {
10922 // We need to bitcast and perform atomic op as integers
10923 IntegerType *IntCastTy =
10924 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
10925 Value *ExprCast =
10926 Builder.CreateBitCast(V: Expr, DestTy: IntCastTy, Name: "atomic.src.int.cast");
10927 StoreInst *XSt = Builder.CreateStore(Val: ExprCast, Ptr: X.Var, isVolatile: X.IsVolatile);
10928 XSt->setAtomic(Ordering: AO);
10929 }
10930
10931 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Write);
10932 return Builder.saveIP();
10933}
10934
10935OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
10936 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
10937 Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
10938 AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr,
10939 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
10940 assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
10941 if (!updateToLocation(Loc))
10942 return Loc.IP;
10943
10944 LLVM_DEBUG({
10945 Type *XTy = X.Var->getType();
10946 assert(XTy->isPointerTy() &&
10947 "OMP Atomic expects a pointer to target memory");
10948 Type *XElemTy = X.ElemTy;
10949 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
10950 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
10951 "OMP atomic update expected a scalar or struct type");
10952 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
10953 (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
10954 "OpenMP atomic does not support LT or GT operations");
10955 });
10956
10957 Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
10958 AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp, UpdateOp, VolatileX: X.IsVolatile,
10959 IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
10960 if (!AtomicResult)
10961 return AtomicResult.takeError();
10962 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Update);
10963 return Builder.saveIP();
10964}
10965
10966// FIXME: Duplicating AtomicExpand
10967Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
10968 AtomicRMWInst::BinOp RMWOp) {
10969 switch (RMWOp) {
10970 case AtomicRMWInst::Add:
10971 return Builder.CreateAdd(LHS: Src1, RHS: Src2);
10972 case AtomicRMWInst::Sub:
10973 return Builder.CreateSub(LHS: Src1, RHS: Src2);
10974 case AtomicRMWInst::And:
10975 return Builder.CreateAnd(LHS: Src1, RHS: Src2);
10976 case AtomicRMWInst::Nand:
10977 return Builder.CreateNeg(V: Builder.CreateAnd(LHS: Src1, RHS: Src2));
10978 case AtomicRMWInst::Or:
10979 return Builder.CreateOr(LHS: Src1, RHS: Src2);
10980 case AtomicRMWInst::Xor:
10981 return Builder.CreateXor(LHS: Src1, RHS: Src2);
10982 case AtomicRMWInst::Xchg:
10983 case AtomicRMWInst::FAdd:
10984 case AtomicRMWInst::FSub:
10985 case AtomicRMWInst::BAD_BINOP:
10986 case AtomicRMWInst::Max:
10987 case AtomicRMWInst::Min:
10988 case AtomicRMWInst::UMax:
10989 case AtomicRMWInst::UMin:
10990 case AtomicRMWInst::FMax:
10991 case AtomicRMWInst::FMin:
10992 case AtomicRMWInst::FMaximum:
10993 case AtomicRMWInst::FMinimum:
10994 case AtomicRMWInst::FMaximumNum:
10995 case AtomicRMWInst::FMinimumNum:
10996 case AtomicRMWInst::UIncWrap:
10997 case AtomicRMWInst::UDecWrap:
10998 case AtomicRMWInst::USubCond:
10999 case AtomicRMWInst::USubSat:
11000 llvm_unreachable("Unsupported atomic update operation");
11001 }
11002 llvm_unreachable("Unsupported atomic update operation");
11003}
11004
11005static AtomicOrdering TransformReleaseAcquireRelease(AtomicOrdering AO) {
11006 // Loads cannot use Release or AcquireRelease ordering. This load is
11007 // just the initial value for the cmpxchg loop; the cmpxchg itself
11008 // retains the original ordering.
11009 AtomicOrdering LoadAO = AO;
11010
11011 if (AO == AtomicOrdering::Release) {
11012 LoadAO = AtomicOrdering::Monotonic;
11013 } else if (AO == AtomicOrdering::AcquireRelease) {
11014 LoadAO = AtomicOrdering::Acquire;
11015 }
11016
11017 return LoadAO;
11018}
11019
11020Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
11021 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
11022 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
11023 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
11024 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
11025 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2.
11026 bool emitRMWOp = false;
11027 switch (RMWOp) {
11028 case AtomicRMWInst::Add:
11029 case AtomicRMWInst::And:
11030 case AtomicRMWInst::Nand:
11031 case AtomicRMWInst::Or:
11032 case AtomicRMWInst::Xor:
11033 case AtomicRMWInst::Xchg:
11034 emitRMWOp = XElemTy;
11035 break;
11036 case AtomicRMWInst::Sub:
11037 emitRMWOp = (IsXBinopExpr && XElemTy);
11038 break;
11039 default:
11040 emitRMWOp = false;
11041 }
11042 emitRMWOp &= XElemTy->isIntegerTy();
11043
11044 std::pair<Value *, Value *> Res;
11045 if (emitRMWOp) {
11046 AtomicRMWInst *RMWInst =
11047 Builder.CreateAtomicRMW(Op: RMWOp, Ptr: X, Val: Expr, Align: llvm::MaybeAlign(), Ordering: AO);
11048 if (T.isAMDGPU()) {
11049 if (IsIgnoreDenormalMode)
11050 RMWInst->setMetadata(Kind: "amdgpu.ignore.denormal.mode",
11051 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
11052 if (!IsFineGrainedMemory)
11053 RMWInst->setMetadata(Kind: "amdgpu.no.fine.grained.memory",
11054 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
11055 if (!IsRemoteMemory)
11056 RMWInst->setMetadata(Kind: "amdgpu.no.remote.memory",
11057 Node: llvm::MDNode::get(Context&: Builder.getContext(), MDs: {}));
11058 }
11059 Res.first = RMWInst;
11060 // not needed except in case of postfix captures. Generate anyway for
11061 // consistency with the else part. Will be removed with any DCE pass.
11062 // AtomicRMWInst::Xchg does not have a coressponding instruction.
11063 if (RMWOp == AtomicRMWInst::Xchg)
11064 Res.second = Res.first;
11065 else
11066 Res.second = emitRMWOpAsInstruction(Src1: Res.first, Src2: Expr, RMWOp);
11067 } else if (XElemTy->isStructTy()) {
11068 LoadInst *OldVal =
11069 Builder.CreateLoad(Ty: XElemTy, Ptr: X, Name: X->getName() + ".atomic.load");
11070 AtomicOrdering LoadAO = TransformReleaseAcquireRelease(AO);
11071 OldVal->setAtomic(Ordering: LoadAO);
11072 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
11073 unsigned LoadSize = LoadDL.getTypeStoreSize(Ty: XElemTy);
11074
11075 OpenMPIRBuilder::AtomicInfo atomicInfo(
11076 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
11077 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
11078 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
11079 BasicBlock *CurBB = Builder.GetInsertBlock();
11080 Instruction *CurBBTI = CurBB->getTerminatorOrNull();
11081 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
11082 BasicBlock *ExitBB =
11083 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
11084 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
11085 BBName: X->getName() + ".atomic.cont");
11086 ContBB->getTerminator()->eraseFromParent();
11087 Builder.restoreIP(IP: AllocaIP);
11088 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
11089 NewAtomicAddr->setName(X->getName() + "x.new.val");
11090 Builder.SetInsertPoint(ContBB);
11091 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
11092 PHI->addIncoming(V: AtomicLoadRes.first, BB: CurBB);
11093 Value *OldExprVal = PHI;
11094 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
11095 if (!CBResult)
11096 return CBResult.takeError();
11097 Value *Upd = *CBResult;
11098 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
11099 AtomicOrdering Failure =
11100 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
11101 auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
11102 ExpectedVal: AtomicLoadRes.second, DesiredVal: NewAtomicAddr, Success: AO, Failure);
11103 LoadInst *PHILoad = Builder.CreateLoad(Ty: XElemTy, Ptr: Result.first);
11104 PHI->addIncoming(V: PHILoad, BB: Builder.GetInsertBlock());
11105 Builder.CreateCondBr(Cond: Result.second, True: ExitBB, False: ContBB);
11106 OldVal->eraseFromParent();
11107 Res.first = OldExprVal;
11108 Res.second = Upd;
11109
11110 if (UnreachableInst *ExitTI =
11111 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
11112 CurBBTI->eraseFromParent();
11113 Builder.SetInsertPoint(ExitBB);
11114 } else {
11115 Builder.SetInsertPoint(ExitTI);
11116 }
11117 } else {
11118 IntegerType *IntCastTy =
11119 IntegerType::get(C&: M.getContext(), NumBits: XElemTy->getScalarSizeInBits());
11120 LoadInst *OldVal =
11121 Builder.CreateLoad(Ty: IntCastTy, Ptr: X, Name: X->getName() + ".atomic.load");
11122 AtomicOrdering LoadAO = TransformReleaseAcquireRelease(AO);
11123 OldVal->setAtomic(Ordering: LoadAO);
11124 // CurBB
11125 // | /---\
11126 // ContBB |
11127 // | \---/
11128 // ExitBB
11129 BasicBlock *CurBB = Builder.GetInsertBlock();
11130 Instruction *CurBBTI = CurBB->getTerminatorOrNull();
11131 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
11132 BasicBlock *ExitBB =
11133 CurBB->splitBasicBlock(I: CurBBTI, BBName: X->getName() + ".atomic.exit");
11134 BasicBlock *ContBB = CurBB->splitBasicBlock(I: CurBB->getTerminator(),
11135 BBName: X->getName() + ".atomic.cont");
11136 ContBB->getTerminator()->eraseFromParent();
11137 Builder.restoreIP(IP: AllocaIP);
11138 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(Ty: XElemTy);
11139 NewAtomicAddr->setName(X->getName() + "x.new.val");
11140 Builder.SetInsertPoint(ContBB);
11141 llvm::PHINode *PHI = Builder.CreatePHI(Ty: OldVal->getType(), NumReservedValues: 2);
11142 PHI->addIncoming(V: OldVal, BB: CurBB);
11143 bool IsIntTy = XElemTy->isIntegerTy();
11144 Value *OldExprVal = PHI;
11145 if (!IsIntTy) {
11146 if (XElemTy->isFloatingPointTy()) {
11147 OldExprVal = Builder.CreateBitCast(V: PHI, DestTy: XElemTy,
11148 Name: X->getName() + ".atomic.fltCast");
11149 } else {
11150 OldExprVal = Builder.CreateIntToPtr(V: PHI, DestTy: XElemTy,
11151 Name: X->getName() + ".atomic.ptrCast");
11152 }
11153 }
11154
11155 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
11156 if (!CBResult)
11157 return CBResult.takeError();
11158 Value *Upd = *CBResult;
11159 Builder.CreateStore(Val: Upd, Ptr: NewAtomicAddr);
11160 LoadInst *DesiredVal = Builder.CreateLoad(Ty: IntCastTy, Ptr: NewAtomicAddr);
11161 AtomicOrdering Failure =
11162 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
11163 AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
11164 Ptr: X, Cmp: PHI, New: DesiredVal, Align: llvm::MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
11165 Result->setVolatile(VolatileX);
11166 Value *PreviousVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
11167 Value *SuccessFailureVal = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
11168 PHI->addIncoming(V: PreviousVal, BB: Builder.GetInsertBlock());
11169 Builder.CreateCondBr(Cond: SuccessFailureVal, True: ExitBB, False: ContBB);
11170
11171 Res.first = OldExprVal;
11172 Res.second = Upd;
11173
11174 // set Insertion point in exit block
11175 if (UnreachableInst *ExitTI =
11176 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
11177 CurBBTI->eraseFromParent();
11178 Builder.SetInsertPoint(ExitBB);
11179 } else {
11180 Builder.SetInsertPoint(ExitTI);
11181 }
11182 }
11183
11184 return Res;
11185}
11186
11187OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
11188 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
11189 AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
11190 AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
11191 bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr,
11192 bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
11193 if (!updateToLocation(Loc))
11194 return Loc.IP;
11195
11196 LLVM_DEBUG({
11197 Type *XTy = X.Var->getType();
11198 assert(XTy->isPointerTy() &&
11199 "OMP Atomic expects a pointer to target memory");
11200 Type *XElemTy = X.ElemTy;
11201 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
11202 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
11203 "OMP atomic capture expected a scalar or struct type");
11204 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
11205 "OpenMP atomic does not support LT or GT operations");
11206 });
11207
11208 // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
11209 // 'x' is simply atomically rewritten with 'expr'.
11210 AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
11211 Expected<std::pair<Value *, Value *>> AtomicResult = emitAtomicUpdate(
11212 AllocaIP, X: X.Var, XElemTy: X.ElemTy, Expr, AO, RMWOp: AtomicOp, UpdateOp, VolatileX: X.IsVolatile,
11213 IsXBinopExpr, IsIgnoreDenormalMode, IsFineGrainedMemory, IsRemoteMemory);
11214 if (!AtomicResult)
11215 return AtomicResult.takeError();
11216 Value *CapturedVal =
11217 (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
11218 Builder.CreateStore(Val: CapturedVal, Ptr: V.Var, isVolatile: V.IsVolatile);
11219
11220 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Capture);
11221 return Builder.saveIP();
11222}
11223
11224OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
11225 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
11226 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
11227 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
11228 bool IsFailOnly, bool IsWeak) {
11229
11230 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(SuccessOrdering: AO);
11231 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
11232 IsPostfixUpdate, IsFailOnly, Failure, IsWeak);
11233}
11234
11235OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
11236 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
11237 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
11238 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
11239 bool IsFailOnly, AtomicOrdering Failure, bool IsWeak) {
11240
11241 if (!updateToLocation(Loc))
11242 return Loc.IP;
11243
11244 assert(X.Var->getType()->isPointerTy() &&
11245 "OMP atomic expects a pointer to target memory");
11246 // compare capture
11247 if (V.Var) {
11248 assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
11249 assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
11250 }
11251
11252 bool IsInteger = E->getType()->isIntegerTy();
11253
11254 if (Op == OMPAtomicCompareOp::EQ) {
11255 // OldValue and SuccessOrFail are set below and used in the shared V.Var /
11256 // R.Var handling.
11257 Value *OldValue = nullptr;
11258 Value *SuccessOrFail = nullptr;
11259
11260 if (!IsInteger && HandleFPNegZero) {
11261 // IEEE 754 special cases for cmpxchg (which is bitwise):
11262 // 1. -0.0 == +0.0 but they have different bit patterns.
11263 // 2. NaN != NaN but identical NaN bit patterns would match.
11264 //
11265 // CurBB:
11266 // %e_int = bitcast E to intN
11267 // %d_int = bitcast D to intN
11268 // %x_curr = load atomic intN, X
11269 // %x_fp = bitcast %x_curr to FP
11270 // %e_is_nan = fcmp uno E, E
11271 // %x_is_nan = fcmp uno %x_fp, %x_fp
11272 // %either_nan = or %e_is_nan, %x_is_nan
11273 // br %either_nan, NaNBB, NotNaNBB
11274 // NaNBB: ; NaN == anything is always false
11275 // br ExitBB
11276 // NotNaNBB:
11277 // %x_is_zero = fcmp oeq %x_fp, 0.0
11278 // %e_is_zero = fcmp oeq E, 0.0
11279 // %both_zero = and %x_is_zero, %e_is_zero
11280 // br %both_zero, ZeroBB, NormalBB
11281 // ZeroBB: ; both ±0.0 → x = d
11282 // cmpxchg X, %x_curr, %d_int
11283 // br ExitBB
11284 // NormalBB: ; original path
11285 // cmpxchg X, %e_int, %d_int
11286 // br ExitBB
11287 // ExitBB:
11288 // phi merge
11289 IntegerType *IntCastTy =
11290 IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
11291 Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
11292 Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
11293
11294 // Load X atomically.
11295 LoadInst *XCurr = Builder.CreateLoad(Ty: IntCastTy, Ptr: X.Var,
11296 Name: X.Var->getName() + ".atomic.load");
11297 XCurr->setAtomic(Ordering: AtomicOrdering::Monotonic);
11298 Value *XFP = Builder.CreateBitCast(V: XCurr, DestTy: X.ElemTy);
11299
11300 // IEEE 754: NaN != NaN, but cmpxchg would succeed if E and X have
11301 // the same NaN bit pattern. Skip cmpxchg when either is NaN.
11302 Value *EIsNaN = Builder.CreateFCmpUNO(LHS: E, RHS: E, Name: "atomic.e.isnan");
11303 Value *XIsNaN = Builder.CreateFCmpUNO(LHS: XFP, RHS: XFP, Name: "atomic.x.isnan");
11304 Value *EitherNaN = Builder.CreateOr(LHS: EIsNaN, RHS: XIsNaN, Name: "atomic.either.nan");
11305
11306 BasicBlock *CurBB = Builder.GetInsertBlock();
11307 Function *F = CurBB->getParent();
11308 Instruction *CurBBTI = CurBB->getTerminatorOrNull();
11309 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
11310 BasicBlock *ExitBB =
11311 CurBB->splitBasicBlock(I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
11312 BasicBlock *NaNBB = BasicBlock::Create(
11313 Context&: M.getContext(), Name: X.Var->getName() + ".atomic.nan", Parent: F, InsertBefore: ExitBB);
11314 BasicBlock *NotNaNBB = BasicBlock::Create(
11315 Context&: M.getContext(), Name: X.Var->getName() + ".atomic.notnan", Parent: F, InsertBefore: ExitBB);
11316 BasicBlock *ZeroBB = BasicBlock::Create(
11317 Context&: M.getContext(), Name: X.Var->getName() + ".atomic.zero", Parent: F, InsertBefore: ExitBB);
11318 BasicBlock *NormalBB = BasicBlock::Create(
11319 Context&: M.getContext(), Name: X.Var->getName() + ".atomic.normal", Parent: F, InsertBefore: ExitBB);
11320
11321 // If either E or X is NaN → NaNBB (always fails), else check for ±0.0.
11322 CurBB->getTerminator()->eraseFromParent();
11323 Builder.SetInsertPoint(CurBB);
11324 Builder.CreateCondBr(Cond: EitherNaN, True: NaNBB, False: NotNaNBB);
11325
11326 // NaNBB: NaN == anything is always false; skip cmpxchg.
11327 Builder.SetInsertPoint(NaNBB);
11328 Builder.CreateBr(Dest: ExitBB);
11329
11330 // NotNaNBB: check both X and E for ±0.0.
11331 Builder.SetInsertPoint(NotNaNBB);
11332 Value *XIsZero =
11333 Builder.CreateFCmpOEQ(LHS: XFP, RHS: ConstantFP::getZero(Ty: X.ElemTy),
11334 Name: X.Var->getName() + ".atomic.xiszero");
11335 Value *EIsZero = Builder.CreateFCmpOEQ(LHS: E, RHS: ConstantFP::getZero(Ty: X.ElemTy),
11336 Name: "atomic.e.iszero");
11337 Value *BothZero = Builder.CreateAnd(LHS: XIsZero, RHS: EIsZero, Name: "atomic.both.zero");
11338 Builder.CreateCondBr(Cond: BothZero, True: ZeroBB, False: NormalBB);
11339
11340 // ZeroBB: cmpxchg with X's loaded bit-pattern.
11341 Builder.SetInsertPoint(ZeroBB);
11342 AtomicCmpXchgInst *ResZero = Builder.CreateAtomicCmpXchg(
11343 Ptr: X.Var, Cmp: XCurr, New: DBCast, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
11344 ResZero->setWeak(IsWeak);
11345 Value *OldZero = Builder.CreateExtractValue(Agg: ResZero, /*Idxs=*/0);
11346 Value *OkZero = Builder.CreateExtractValue(Agg: ResZero, /*Idxs=*/1);
11347 Builder.CreateBr(Dest: ExitBB);
11348
11349 // NormalBB: original bitwise cmpxchg.
11350 Builder.SetInsertPoint(NormalBB);
11351 AtomicCmpXchgInst *ResNormal = Builder.CreateAtomicCmpXchg(
11352 Ptr: X.Var, Cmp: EBCast, New: DBCast, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
11353 ResNormal->setWeak(IsWeak);
11354 Value *OldNormal = Builder.CreateExtractValue(Agg: ResNormal, /*Idxs=*/0);
11355 Value *OkNormal = Builder.CreateExtractValue(Agg: ResNormal, /*Idxs=*/1);
11356 Builder.CreateBr(Dest: ExitBB);
11357
11358 // ExitBB: merge results from NaN, Zero, and Normal paths.
11359 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
11360 PHINode *OldIntPHI =
11361 Builder.CreatePHI(Ty: IntCastTy, NumReservedValues: 3, Name: X.Var->getName() + ".atomic.old");
11362 OldIntPHI->addIncoming(V: XCurr, BB: NaNBB);
11363 OldIntPHI->addIncoming(V: OldZero, BB: ZeroBB);
11364 OldIntPHI->addIncoming(V: OldNormal, BB: NormalBB);
11365 PHINode *SuccessPHI = Builder.CreatePHI(Ty: Builder.getInt1Ty(), NumReservedValues: 3,
11366 Name: X.Var->getName() + ".atomic.ok");
11367 SuccessPHI->addIncoming(V: Builder.getFalse(), BB: NaNBB);
11368 SuccessPHI->addIncoming(V: OkZero, BB: ZeroBB);
11369 SuccessPHI->addIncoming(V: OkNormal, BB: NormalBB);
11370
11371 if (isa<UnreachableInst>(Val: ExitBB->getTerminator())) {
11372 CurBBTI->eraseFromParent();
11373 Builder.SetInsertPoint(ExitBB);
11374 } else {
11375 Builder.SetInsertPoint(&*ExitBB->getFirstNonPHIIt());
11376 }
11377
11378 OldValue = Builder.CreateBitCast(V: OldIntPHI, DestTy: X.ElemTy,
11379 Name: X.Var->getName() + ".atomic.old.fp");
11380 SuccessOrFail = SuccessPHI;
11381 } else {
11382 AtomicCmpXchgInst *Result = nullptr;
11383 if (!IsInteger) {
11384 IntegerType *IntCastTy =
11385 IntegerType::get(C&: M.getContext(), NumBits: X.ElemTy->getScalarSizeInBits());
11386 Value *EBCast = Builder.CreateBitCast(V: E, DestTy: IntCastTy);
11387 Value *DBCast = Builder.CreateBitCast(V: D, DestTy: IntCastTy);
11388 Result = Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: EBCast, New: DBCast,
11389 Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
11390 } else {
11391 Result =
11392 Builder.CreateAtomicCmpXchg(Ptr: X.Var, Cmp: E, New: D, Align: MaybeAlign(), SuccessOrdering: AO, FailureOrdering: Failure);
11393 }
11394 Result->setWeak(IsWeak);
11395
11396 if (V.Var) {
11397 OldValue = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/0);
11398 if (!IsInteger)
11399 OldValue = Builder.CreateBitCast(V: OldValue, DestTy: X.ElemTy);
11400 assert(OldValue->getType() == V.ElemTy &&
11401 "OldValue and V must be of same type");
11402 if (IsPostfixUpdate) {
11403 Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
11404 } else {
11405 SuccessOrFail = Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
11406 if (IsFailOnly) {
11407 BasicBlock *CurBB = Builder.GetInsertBlock();
11408 Instruction *CurBBTI = CurBB->getTerminatorOrNull();
11409 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
11410 BasicBlock *ExitBB = CurBB->splitBasicBlock(
11411 I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
11412 BasicBlock *ContBB = CurBB->splitBasicBlock(
11413 I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
11414 ContBB->getTerminator()->eraseFromParent();
11415 CurBB->getTerminator()->eraseFromParent();
11416
11417 Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);
11418
11419 Builder.SetInsertPoint(ContBB);
11420 Builder.CreateStore(Val: OldValue, Ptr: V.Var);
11421 Builder.CreateBr(Dest: ExitBB);
11422
11423 if (UnreachableInst *ExitTI =
11424 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
11425 CurBBTI->eraseFromParent();
11426 Builder.SetInsertPoint(ExitBB);
11427 } else {
11428 Builder.SetInsertPoint(ExitTI);
11429 }
11430 } else {
11431 Value *CapturedValue =
11432 Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
11433 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
11434 }
11435 }
11436 }
11437 // The comparison result has to be stored.
11438 if (R.Var) {
11439 assert(R.Var->getType()->isPointerTy() &&
11440 "r.var must be of pointer type");
11441 assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
11442
11443 Value *SuccessFailureVal =
11444 Builder.CreateExtractValue(Agg: Result, /*Idxs=*/1);
11445 Value *ResultCast =
11446 R.IsSigned ? Builder.CreateSExt(V: SuccessFailureVal, DestTy: R.ElemTy)
11447 : Builder.CreateZExt(V: SuccessFailureVal, DestTy: R.ElemTy);
11448 Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
11449 }
11450 }
11451
11452 // For the HandleFPNegZero path, handle V.Var and R.Var using the
11453 // pre-computed OldValue and SuccessOrFail.
11454 if (HandleFPNegZero && !IsInteger) {
11455 if (V.Var) {
11456 assert(OldValue->getType() == V.ElemTy &&
11457 "OldValue and V must be of same type");
11458 if (IsPostfixUpdate) {
11459 Builder.CreateStore(Val: OldValue, Ptr: V.Var, isVolatile: V.IsVolatile);
11460 } else {
11461 if (IsFailOnly) {
11462 BasicBlock *CurBB = Builder.GetInsertBlock();
11463 Instruction *CurBBTI = CurBB->getTerminatorOrNull();
11464 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
11465 BasicBlock *ExitBB = CurBB->splitBasicBlock(
11466 I: CurBBTI, BBName: X.Var->getName() + ".atomic.exit");
11467 BasicBlock *ContBB = CurBB->splitBasicBlock(
11468 I: CurBB->getTerminator(), BBName: X.Var->getName() + ".atomic.cont");
11469 ContBB->getTerminator()->eraseFromParent();
11470 CurBB->getTerminator()->eraseFromParent();
11471
11472 Builder.CreateCondBr(Cond: SuccessOrFail, True: ExitBB, False: ContBB);
11473
11474 Builder.SetInsertPoint(ContBB);
11475 Builder.CreateStore(Val: OldValue, Ptr: V.Var);
11476 Builder.CreateBr(Dest: ExitBB);
11477
11478 if (UnreachableInst *ExitTI =
11479 dyn_cast<UnreachableInst>(Val: ExitBB->getTerminator())) {
11480 CurBBTI->eraseFromParent();
11481 Builder.SetInsertPoint(ExitBB);
11482 } else {
11483 Builder.SetInsertPoint(ExitTI);
11484 }
11485 } else {
11486 Value *CapturedValue =
11487 Builder.CreateSelect(C: SuccessOrFail, True: E, False: OldValue);
11488 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
11489 }
11490 }
11491 }
11492 // The comparison result has to be stored.
11493 if (R.Var) {
11494 assert(R.Var->getType()->isPointerTy() &&
11495 "r.var must be of pointer type");
11496 assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
11497
11498 Value *ResultCast = R.IsSigned
11499 ? Builder.CreateSExt(V: SuccessOrFail, DestTy: R.ElemTy)
11500 : Builder.CreateZExt(V: SuccessOrFail, DestTy: R.ElemTy);
11501 Builder.CreateStore(Val: ResultCast, Ptr: R.Var, isVolatile: R.IsVolatile);
11502 }
11503 }
11504 } else {
11505 assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
11506 "Op should be either max or min at this point");
11507 assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
11508
11509 // Reverse the ordop as the OpenMP forms are different from LLVM forms.
11510 // Let's take max as example.
11511 // OpenMP form:
11512 // x = x > expr ? expr : x;
11513 // LLVM form:
11514 // *ptr = *ptr > val ? *ptr : val;
11515 // We need to transform to LLVM form.
11516 // x = x <= expr ? x : expr;
11517 AtomicRMWInst::BinOp NewOp;
11518 if (IsXBinopExpr) {
11519 if (IsInteger) {
11520 if (X.IsSigned)
11521 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
11522 : AtomicRMWInst::Max;
11523 else
11524 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
11525 : AtomicRMWInst::UMax;
11526 } else {
11527 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
11528 : AtomicRMWInst::FMax;
11529 }
11530 } else {
11531 if (IsInteger) {
11532 if (X.IsSigned)
11533 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
11534 : AtomicRMWInst::Min;
11535 else
11536 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
11537 : AtomicRMWInst::UMin;
11538 } else {
11539 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
11540 : AtomicRMWInst::FMin;
11541 }
11542 }
11543
11544 AtomicRMWInst *OldValue =
11545 Builder.CreateAtomicRMW(Op: NewOp, Ptr: X.Var, Val: E, Align: MaybeAlign(), Ordering: AO);
11546 if (V.Var) {
11547 Value *CapturedValue = nullptr;
11548 if (IsPostfixUpdate) {
11549 CapturedValue = OldValue;
11550 } else {
11551 CmpInst::Predicate Pred;
11552 switch (NewOp) {
11553 case AtomicRMWInst::Max:
11554 Pred = CmpInst::ICMP_SGT;
11555 break;
11556 case AtomicRMWInst::UMax:
11557 Pred = CmpInst::ICMP_UGT;
11558 break;
11559 case AtomicRMWInst::FMax:
11560 Pred = CmpInst::FCMP_OGT;
11561 break;
11562 case AtomicRMWInst::Min:
11563 Pred = CmpInst::ICMP_SLT;
11564 break;
11565 case AtomicRMWInst::UMin:
11566 Pred = CmpInst::ICMP_ULT;
11567 break;
11568 case AtomicRMWInst::FMin:
11569 Pred = CmpInst::FCMP_OLT;
11570 break;
11571 default:
11572 llvm_unreachable("unexpected comparison op");
11573 }
11574 Value *NonAtomicCmp = Builder.CreateCmp(Pred, LHS: OldValue, RHS: E);
11575 CapturedValue = Builder.CreateSelect(C: NonAtomicCmp, True: E, False: OldValue);
11576 }
11577 Builder.CreateStore(Val: CapturedValue, Ptr: V.Var, isVolatile: V.IsVolatile);
11578 }
11579 }
11580
11581 checkAndEmitFlushAfterAtomic(Loc, AO, AK: AtomicKind::Compare);
11582
11583 return Builder.saveIP();
11584}
11585
11586OpenMPIRBuilder::InsertPointOrErrorTy
11587OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
11588 BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
11589 Value *NumTeamsUpper, Value *ThreadLimit,
11590 Value *IfExpr) {
11591 if (!updateToLocation(Loc))
11592 return InsertPointTy();
11593
11594 uint32_t SrcLocStrSize;
11595 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
11596 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
11597 Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
11598
11599 // Outer allocation basicblock is the entry block of the current function.
11600 BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
11601 if (&OuterAllocaBB == Builder.GetInsertBlock()) {
11602 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.entry");
11603 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
11604 }
11605
11606 // The current basic block is split into four basic blocks. After outlining,
11607 // they will be mapped as follows:
11608 // ```
11609 // def current_fn() {
11610 // current_basic_block:
11611 // br label %teams.exit
11612 // teams.exit:
11613 // ; instructions after teams
11614 // }
11615 //
11616 // def outlined_fn() {
11617 // teams.alloca:
11618 // br label %teams.body
11619 // teams.body:
11620 // ; instructions within teams body
11621 // }
11622 // ```
11623 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.exit");
11624 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, Name: "teams.body");
11625 BasicBlock *AllocaBB =
11626 splitBB(Builder, /*CreateBranch=*/true, Name: "teams.alloca");
11627
11628 bool SubClausesPresent =
11629 (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
11630 // Push num_teams
11631 if (!Config.isTargetDevice() && SubClausesPresent) {
11632 assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
11633 "if lowerbound is non-null, then upperbound must also be non-null "
11634 "for bounds on num_teams");
11635
11636 if (NumTeamsUpper == nullptr)
11637 NumTeamsUpper = Builder.getInt32(C: 0);
11638
11639 if (NumTeamsLower == nullptr)
11640 NumTeamsLower = NumTeamsUpper;
11641
11642 if (IfExpr) {
11643 assert(IfExpr->getType()->isIntegerTy() &&
11644 "argument to if clause must be an integer value");
11645
11646 // upper = ifexpr ? upper : 1
11647 if (IfExpr->getType() != Int1)
11648 IfExpr = Builder.CreateICmpNE(LHS: IfExpr,
11649 RHS: ConstantInt::get(Ty: IfExpr->getType(), V: 0));
11650 NumTeamsUpper = Builder.CreateSelect(
11651 C: IfExpr, True: NumTeamsUpper, False: Builder.getInt32(C: 1), Name: "numTeamsUpper");
11652
11653 // lower = ifexpr ? lower : 1
11654 NumTeamsLower = Builder.CreateSelect(
11655 C: IfExpr, True: NumTeamsLower, False: Builder.getInt32(C: 1), Name: "numTeamsLower");
11656 }
11657
11658 if (ThreadLimit == nullptr)
11659 ThreadLimit = Builder.getInt32(C: 0);
11660
11661 // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
11662 // truncate or sign extend the passed values to match the int32 parameters.
11663 Value *NumTeamsLowerInt32 =
11664 Builder.CreateSExtOrTrunc(V: NumTeamsLower, DestTy: Builder.getInt32Ty());
11665 Value *NumTeamsUpperInt32 =
11666 Builder.CreateSExtOrTrunc(V: NumTeamsUpper, DestTy: Builder.getInt32Ty());
11667 Value *ThreadLimitInt32 =
11668 Builder.CreateSExtOrTrunc(V: ThreadLimit, DestTy: Builder.getInt32Ty());
11669
11670 Value *ThreadNum = getOrCreateThreadID(Ident);
11671
11672 createRuntimeFunctionCall(
11673 Callee: getOrCreateRuntimeFunctionPtr(FnID: OMPRTL___kmpc_push_num_teams_51),
11674 Args: {Ident, ThreadNum, NumTeamsLowerInt32, NumTeamsUpperInt32,
11675 ThreadLimitInt32});
11676 }
11677 // Generate the body of teams.
11678 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
11679 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
11680 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP, ExitBB))
11681 return Err;
11682
11683 auto OI = std::make_unique<OutlineInfo>();
11684 OI->EntryBB = AllocaBB;
11685 OI->ExitBB = ExitBB;
11686 OI->OuterAllocBB = &OuterAllocaBB;
11687
11688 // Insert fake values for global tid and bound tid.
11689 SmallVector<Instruction *, 8> ToBeDeleted;
11690 InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
11691 OI->ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
11692 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "gid", AsPtr: true));
11693 OI->ExcludeArgsFromAggregate.push_back(Elt: createFakeIntVal(
11694 Builder, OuterAllocaIP, ToBeDeleted, InnerAllocaIP: AllocaIP, Name: "tid", AsPtr: true));
11695
11696 auto HostPostOutlineCB = [this, Ident,
11697 ToBeDeleted](Function &OutlinedFn) mutable {
11698 // The stale call instruction will be replaced with a new call instruction
11699 // for runtime call with the outlined function.
11700
11701 assert(OutlinedFn.hasOneUse() &&
11702 "there must be a single user for the outlined function");
11703 CallInst *StaleCI = cast<CallInst>(Val: OutlinedFn.user_back());
11704 ToBeDeleted.push_back(Elt: StaleCI);
11705
11706 assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
11707 "Outlined function must have two or three arguments only");
11708
11709 bool HasShared = OutlinedFn.arg_size() == 3;
11710
11711 OutlinedFn.getArg(i: 0)->setName("global.tid.ptr");
11712 OutlinedFn.getArg(i: 1)->setName("bound.tid.ptr");
11713 if (HasShared)
11714 OutlinedFn.getArg(i: 2)->setName("data");
11715
11716 // Call to the runtime function for teams in the current function.
11717 assert(StaleCI && "Error while outlining - no CallInst user found for the "
11718 "outlined function.");
11719 Builder.SetInsertPoint(StaleCI);
11720 SmallVector<Value *> Args = {
11721 Ident, Builder.getInt32(C: StaleCI->arg_size() - 2), &OutlinedFn};
11722 if (HasShared)
11723 Args.push_back(Elt: StaleCI->getArgOperand(i: 2));
11724 createRuntimeFunctionCall(
11725 Callee: getOrCreateRuntimeFunctionPtr(
11726 FnID: omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
11727 Args);
11728
11729 for (Instruction *I : llvm::reverse(C&: ToBeDeleted))
11730 I->eraseFromParent();
11731 };
11732
11733 if (!Config.isTargetDevice())
11734 OI->PostOutlineCB = HostPostOutlineCB;
11735
11736 addOutlineInfo(OI: std::move(OI));
11737
11738 Builder.SetInsertPoint(ExitBB);
11739
11740 return Builder.saveIP();
11741}
11742
11743OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createDistribute(
11744 const LocationDescription &Loc, InsertPointTy OuterAllocIP,
11745 ArrayRef<BasicBlock *> OuterDeallocBlocks, BodyGenCallbackTy BodyGenCB) {
11746 if (!updateToLocation(Loc))
11747 return InsertPointTy();
11748
11749 BasicBlock *OuterAllocaBB = OuterAllocIP.getBlock();
11750
11751 if (OuterAllocaBB == Builder.GetInsertBlock()) {
11752 BasicBlock *BodyBB =
11753 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.entry");
11754 Builder.SetInsertPoint(TheBB: BodyBB, IP: BodyBB->begin());
11755 }
11756 BasicBlock *ExitBB =
11757 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.exit");
11758 BasicBlock *BodyBB =
11759 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.body");
11760 BasicBlock *AllocaBB =
11761 splitBB(Builder, /*CreateBranch=*/true, Name: "distribute.alloca");
11762
11763 // Generate the body of distribute clause
11764 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
11765 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
11766 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP, ExitBB))
11767 return Err;
11768
11769 // When using target we use different runtime functions which require a
11770 // callback.
11771 if (Config.isTargetDevice()) {
11772 auto OI = std::make_unique<OutlineInfo>();
11773 OI->OuterAllocBB = OuterAllocIP.getBlock();
11774 OI->EntryBB = AllocaBB;
11775 OI->ExitBB = ExitBB;
11776 OI->OuterDeallocBBs.reserve(N: OuterDeallocBlocks.size());
11777 copy(Range&: OuterDeallocBlocks, Out: OI->OuterDeallocBBs.end());
11778
11779 addOutlineInfo(OI: std::move(OI));
11780 }
11781 Builder.SetInsertPoint(ExitBB);
11782
11783 return Builder.saveIP();
11784}
11785
11786GlobalVariable *
11787OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
11788 std::string VarName) {
11789 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
11790 T: llvm::ArrayType::get(ElementType: llvm::PointerType::getUnqual(C&: M.getContext()),
11791 NumElements: Names.size()),
11792 V: Names);
11793 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
11794 M, MapNamesArrayInit->getType(),
11795 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
11796 VarName);
11797 return MapNamesArrayGlobal;
11798}
11799
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
//
// The OMP_* macros defined below are expanded by the entries of
// OMPKinds.def:
//  - OMP_TYPE assigns the type member directly.
//  - OMP_ARRAY_TYPE sets <Var>Ty to an array type and <Var>PtrTy to a pointer
//    in the default target address space.
//  - OMP_FUNCTION_TYPE sets the function type and <Var>Ptr to a pointer in the
//    module's program address space.
//  - OMP_STRUCT_TYPE reuses the named struct type already registered in the
//    context if there is one, otherwise creates it, and sets <Var>Ptr to a
//    pointer in the default target address space.
// NOTE(review): no #undef follows the #include below; presumably
// OMPKinds.def undefines these macros itself — confirm.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  // Scratch variable reused by every OMP_STRUCT_TYPE expansion.
  StructType *T;
  // Address space for the data pointer types created below.
  unsigned DefaultTargetAS = Config.getDefaultTargetAS();
  // Address space for function pointer types, from the module data layout.
  unsigned ProgramAS = M.getDataLayout().getProgramAddressSpace();
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
  VarName##PtrTy = PointerType::get(Ctx, DefaultTargetAS);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
  VarName##Ptr = PointerType::get(Ctx, ProgramAS);
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
  T = StructType::getTypeByName(Ctx, StructName); \
  if (!T) \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed); \
  VarName = T; \
  VarName##Ptr = PointerType::get(Ctx, DefaultTargetAS);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
11822
11823void OpenMPIRBuilder::OutlineInfo::collectBlocks(
11824 SmallPtrSetImpl<BasicBlock *> &BlockSet,
11825 SmallVectorImpl<BasicBlock *> &BlockVector) {
11826 SmallVector<BasicBlock *, 32> Worklist;
11827 BlockSet.insert(Ptr: EntryBB);
11828 BlockSet.insert(Ptr: ExitBB);
11829
11830 Worklist.push_back(Elt: EntryBB);
11831 while (!Worklist.empty()) {
11832 BasicBlock *BB = Worklist.pop_back_val();
11833 BlockVector.push_back(Elt: BB);
11834 for (BasicBlock *SuccBB : successors(BB))
11835 if (BlockSet.insert(Ptr: SuccBB).second)
11836 Worklist.push_back(Elt: SuccBB);
11837 }
11838}
11839
11840std::unique_ptr<CodeExtractor>
11841OpenMPIRBuilder::OutlineInfo::createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
11842 bool ArgsInZeroAddressSpace,
11843 Twine Suffix) {
11844 return std::make_unique<CodeExtractor>(
11845 args&: Blocks, /* DominatorTree */ args: nullptr,
11846 /* AggregateArgs */ args: true,
11847 /* BlockFrequencyInfo */ args: nullptr,
11848 /* BranchProbabilityInfo */ args: nullptr,
11849 /* AssumptionCache */ args: nullptr,
11850 /* AllowVarArgs */ args: true,
11851 /* AllowAlloca */ args: true,
11852 /* AllocationBlock*/ args&: OuterAllocBB,
11853 /* DeallocationBlocks */ args: ArrayRef<BasicBlock *>(),
11854 /* Suffix */ args: Suffix.str(), args&: ArgsInZeroAddressSpace);
11855}
11856
11857std::unique_ptr<CodeExtractor> DeviceSharedMemOutlineInfo::createCodeExtractor(
11858 ArrayRef<BasicBlock *> Blocks, bool ArgsInZeroAddressSpace, Twine Suffix) {
11859 return std::make_unique<DeviceSharedMemCodeExtractor>(
11860 args&: OMPBuilder, args&: Blocks, /* DominatorTree */ args: nullptr,
11861 /* AggregateArgs */ args: true,
11862 /* BlockFrequencyInfo */ args: nullptr,
11863 /* BranchProbabilityInfo */ args: nullptr,
11864 /* AssumptionCache */ args: nullptr,
11865 /* AllowVarArgs */ args: true,
11866 /* AllowAlloca */ args: true,
11867 /* AllocationBlock*/ args&: OuterAllocBB,
11868 /* DeallocationBlocks */ args: OuterDeallocBBs.empty()
11869 ? SmallVector<BasicBlock *>{ExitBB}
11870 : OuterDeallocBBs,
11871 /* Suffix */ args: Suffix.str(), args&: ArgsInZeroAddressSpace);
11872}
11873
11874void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
11875 uint64_t Size, int32_t Flags,
11876 GlobalValue::LinkageTypes,
11877 StringRef Name) {
11878 if (!Config.isGPU()) {
11879 llvm::offloading::emitOffloadingEntry(
11880 M, Kind: object::OffloadKind::OFK_OpenMP, Addr: ID,
11881 Name: Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
11882 return;
11883 }
11884 // TODO: Add support for global variables on the device after declare target
11885 // support.
11886 Function *Fn = dyn_cast<Function>(Val: Addr);
11887 if (!Fn)
11888 return;
11889
11890 // Add a function attribute for the kernel.
11891 Fn->addFnAttr(Kind: "kernel");
11892 if (T.isAMDGCN())
11893 Fn->addFnAttr(Kind: "uniform-work-group-size");
11894 Fn->addFnAttr(Kind: Attribute::MustProgress);
11895}
11896
11897// We only generate metadata for function that contain target regions.
11898void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
11899 EmitMetadataErrorReportFunctionTy &ErrorFn) {
11900
11901 // If there are no entries, we don't need to do anything.
11902 if (OffloadInfoManager.empty())
11903 return;
11904
11905 LLVMContext &C = M.getContext();
11906 SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
11907 TargetRegionEntryInfo>,
11908 16>
11909 OrderedEntries(OffloadInfoManager.size());
11910
11911 // Auxiliary methods to create metadata values and strings.
11912 auto &&GetMDInt = [this](unsigned V) {
11913 return ConstantAsMetadata::get(C: ConstantInt::get(Ty: Builder.getInt32Ty(), V));
11914 };
11915
11916 auto &&GetMDString = [&C](StringRef V) { return MDString::get(Context&: C, Str: V); };
11917
11918 // Create the offloading info metadata node.
11919 NamedMDNode *MD = M.getOrInsertNamedMetadata(Name: "omp_offload.info");
11920 auto &&TargetRegionMetadataEmitter =
11921 [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
11922 const TargetRegionEntryInfo &EntryInfo,
11923 const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
11924 // Generate metadata for target regions. Each entry of this metadata
11925 // contains:
11926 // - Entry 0 -> Kind of this type of metadata (0).
11927 // - Entry 1 -> Device ID of the file where the entry was identified.
11928 // - Entry 2 -> File ID of the file where the entry was identified.
11929 // - Entry 3 -> Mangled name of the function where the entry was
11930 // identified.
11931 // - Entry 4 -> Line in the file where the entry was identified.
11932 // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
11933 // - Entry 6 -> Order the entry was created.
11934 // The first element of the metadata node is the kind.
11935 Metadata *Ops[] = {
11936 GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
11937 GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
11938 GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
11939 GetMDInt(E.getOrder())};
11940
11941 // Save this entry in the right position of the ordered entries array.
11942 OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y: EntryInfo);
11943
11944 // Add metadata to the named metadata node.
11945 MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
11946 };
11947
11948 OffloadInfoManager.actOnTargetRegionEntriesInfo(Action: TargetRegionMetadataEmitter);
11949
11950 // Create function that emits metadata for each device global variable entry;
11951 auto &&DeviceGlobalVarMetadataEmitter =
11952 [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
11953 StringRef MangledName,
11954 const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
11955 // Generate metadata for global variables. Each entry of this metadata
11956 // contains:
11957 // - Entry 0 -> Kind of this type of metadata (1).
11958 // - Entry 1 -> Mangled name of the variable.
11959 // - Entry 2 -> Declare target kind.
11960 // - Entry 3 -> Order the entry was created.
11961 // The first element of the metadata node is the kind.
11962 Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
11963 GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
11964
11965 // Save this entry in the right position of the ordered entries array.
11966 TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
11967 OrderedEntries[E.getOrder()] = std::make_pair(x: &E, y&: varInfo);
11968
11969 // Add metadata to the named metadata node.
11970 MD->addOperand(M: MDNode::get(Context&: C, MDs: Ops));
11971 };
11972
11973 OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
11974 Action: DeviceGlobalVarMetadataEmitter);
11975
11976 for (const auto &E : OrderedEntries) {
11977 assert(E.first && "All ordered entries must exist!");
11978 if (const auto *CE =
11979 dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
11980 Val: E.first)) {
11981 if (!CE->getID() || !CE->getAddress()) {
11982 // Do not blame the entry if the parent funtion is not emitted.
11983 TargetRegionEntryInfo EntryInfo = E.second;
11984 StringRef FnName = EntryInfo.ParentName;
11985 if (!M.getNamedValue(Name: FnName))
11986 continue;
11987 ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
11988 continue;
11989 }
11990 createOffloadEntry(ID: CE->getID(), Addr: CE->getAddress(),
11991 /*Size=*/0, Flags: CE->getFlags(),
11992 GlobalValue::WeakAnyLinkage);
11993 } else if (const auto *CE = dyn_cast<
11994 OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
11995 Val: E.first)) {
11996 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
11997 static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
11998 CE->getFlags());
11999 switch (Flags) {
12000 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
12001 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
12002 if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
12003 continue;
12004 if (!CE->getAddress()) {
12005 ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
12006 continue;
12007 }
12008 // The vaiable has no definition - no need to add the entry.
12009 if (CE->getVarSize() == 0)
12010 continue;
12011 break;
12012 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
12013 assert(((Config.isTargetDevice() && !CE->getAddress()) ||
12014 (!Config.isTargetDevice() && CE->getAddress())) &&
12015 "Declaret target link address is set.");
12016 if (Config.isTargetDevice())
12017 continue;
12018 if (!CE->getAddress()) {
12019 ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
12020 continue;
12021 }
12022 break;
12023 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect:
12024 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable:
12025 if (!CE->getAddress()) {
12026 ErrorFn(EMIT_MD_GLOBAL_VAR_INDIRECT_ERROR, E.second);
12027 continue;
12028 }
12029 break;
12030 default:
12031 break;
12032 }
12033
12034 // Hidden or internal symbols on the device are not externally visible.
12035 // We should not attempt to register them by creating an offloading
12036 // entry. Indirect variables are handled separately on the device.
12037 if (auto *GV = dyn_cast<GlobalValue>(Val: CE->getAddress()))
12038 if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
12039 (Flags !=
12040 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect &&
12041 Flags != OffloadEntriesInfoManager::
12042 OMPTargetGlobalVarEntryIndirectVTable))
12043 continue;
12044
12045 // Indirect globals need to use a special name that doesn't match the name
12046 // of the associated host global.
12047 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
12048 Flags ==
12049 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
12050 createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
12051 Flags, CE->getLinkage(), Name: CE->getVarName());
12052 else
12053 createOffloadEntry(ID: CE->getAddress(), Addr: CE->getAddress(), Size: CE->getVarSize(),
12054 Flags, CE->getLinkage());
12055
12056 } else {
12057 llvm_unreachable("Unsupported entry kind.");
12058 }
12059 }
12060
12061 // Emit requires directive globals to a special entry so the runtime can
12062 // register them when the device image is loaded.
12063 // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
12064 // entries should be redesigned to better suit this use-case.
12065 if (Config.hasRequiresFlags() && !Config.isTargetDevice())
12066 offloading::emitOffloadingEntry(
12067 M, Kind: object::OffloadKind::OFK_OpenMP,
12068 Addr: Constant::getNullValue(Ty: PointerType::getUnqual(C&: M.getContext())),
12069 Name: ".requires", /*Size=*/0,
12070 Flags: OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
12071 Data: Config.getRequiresFlags());
12072}
12073
12074void TargetRegionEntryInfo::getTargetRegionEntryFnName(
12075 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
12076 unsigned FileID, unsigned Line, unsigned Count) {
12077 raw_svector_ostream OS(Name);
12078 OS << KernelNamePrefix << llvm::format(Fmt: "%x", Vals: DeviceID)
12079 << llvm::format(Fmt: "_%x_", Vals: FileID) << ParentName << "_l" << Line;
12080 if (Count)
12081 OS << "_" << Count;
12082}
12083
12084void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
12085 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
12086 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
12087 TargetRegionEntryInfo::getTargetRegionEntryFnName(
12088 Name, ParentName: EntryInfo.ParentName, DeviceID: EntryInfo.DeviceID, FileID: EntryInfo.FileID,
12089 Line: EntryInfo.Line, Count: NewCount);
12090}
12091
12092TargetRegionEntryInfo
12093OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
12094 vfs::FileSystem &VFS,
12095 StringRef ParentName) {
12096 sys::fs::UniqueID ID(0xdeadf17e, 0);
12097 auto FileIDInfo = CallBack();
12098 uint64_t FileID = 0;
12099 if (ErrorOr<vfs::Status> Status = VFS.status(Path: std::get<0>(t&: FileIDInfo))) {
12100 ID = Status->getUniqueID();
12101 FileID = Status->getUniqueID().getFile();
12102 } else {
12103 // If the inode ID could not be determined, create a hash value
12104 // the current file name and use that as an ID.
12105 FileID = hash_value(arg: std::get<0>(t&: FileIDInfo));
12106 }
12107
12108 return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
12109 std::get<1>(t&: FileIDInfo));
12110}
12111
12112unsigned OpenMPIRBuilder::getFlagMemberOffset() {
12113 unsigned Offset = 0;
12114 for (uint64_t Remain =
12115 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
12116 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
12117 !(Remain & 1); Remain = Remain >> 1)
12118 Offset++;
12119 return Offset;
12120}
12121
12122omp::OpenMPOffloadMappingFlags
12123OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
12124 // Rotate by getFlagMemberOffset() bits.
12125 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
12126 << getFlagMemberOffset());
12127}
12128
12129void OpenMPIRBuilder::setCorrectMemberOfFlag(
12130 omp::OpenMPOffloadMappingFlags &Flags,
12131 omp::OpenMPOffloadMappingFlags MemberOfFlag) {
12132 // If the entry is PTR_AND_OBJ but has not been marked with the special
12133 // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
12134 // marked as MEMBER_OF.
12135 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
12136 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
12137 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
12138 (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
12139 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
12140 return;
12141
12142 // Entries with ATTACH are not members-of anything. They are handled
12143 // separately by the runtime after other maps have been handled.
12144 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
12145 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH))
12146 return;
12147
12148 // Reset the placeholder value to prepare the flag for the assignment of the
12149 // proper MEMBER_OF value.
12150 Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
12151 Flags |= MemberOfFlag;
12152}
12153
12154Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
12155 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
12156 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
12157 bool IsDeclaration, bool IsExternallyVisible,
12158 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
12159 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
12160 std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
12161 std::function<Constant *()> GlobalInitializer,
12162 std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
12163 // TODO: convert this to utilise the IRBuilder Config rather than
12164 // a passed down argument.
12165 if (OpenMPSIMD)
12166 return nullptr;
12167
12168 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
12169 ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
12170 CaptureClause ==
12171 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
12172 Config.hasRequiresUnifiedSharedMemory())) {
12173 SmallString<64> PtrName;
12174 {
12175 raw_svector_ostream OS(PtrName);
12176 OS << MangledName;
12177 if (!IsExternallyVisible)
12178 OS << format(Fmt: "_%x", Vals: EntryInfo.FileID);
12179 OS << "_decl_tgt_ref_ptr";
12180 }
12181
12182 Value *Ptr = M.getNamedValue(Name: PtrName);
12183
12184 if (!Ptr) {
12185 GlobalValue *GlobalValue = M.getNamedValue(Name: MangledName);
12186 Ptr = getOrCreateInternalVariable(Ty: LlvmPtrTy, Name: PtrName);
12187
12188 auto *GV = cast<GlobalVariable>(Val: Ptr);
12189 GV->setLinkage(GlobalValue::WeakAnyLinkage);
12190
12191 if (!Config.isTargetDevice()) {
12192 if (GlobalInitializer)
12193 GV->setInitializer(GlobalInitializer());
12194 else
12195 GV->setInitializer(GlobalValue);
12196 }
12197
12198 registerTargetGlobalVariable(
12199 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
12200 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
12201 GlobalInitializer, VariableLinkage, LlvmPtrTy, Addr: cast<Constant>(Val: Ptr));
12202 }
12203
12204 return cast<Constant>(Val: Ptr);
12205 }
12206
12207 return nullptr;
12208}
12209
/// Register a declare target global variable with OffloadInfoManager.
///
/// Nothing is registered unless \p DeviceClause is OMPTargetDeviceClauseAny.
/// On the host, nothing is registered unless \p TargetTriple is also
/// non-empty.
///
/// For `to`/`enter` variables without `requires unified_shared_memory`:
///  - The variable \p MangledName itself is registered with its size in
///    bytes (0 for declarations) and its linkage. The linkage comes from
///    \p VariableLinkage when provided, otherwise from the existing global.
///  - On the device, non-externally-visible or linkonce_odr variables that
///    also have a host entry get an internal constant "ref" global
///    initialized with \p Addr. That global is appended to \p GeneratedRefs.
///
/// Otherwise (`link`, or `to`/`enter` under unified shared memory), an entry
/// with pointer size and weak linkage is registered:
///  - On the device, it uses the name of \p Addr and no address.
///  - On the host, it uses the address returned by
///    getAddrOfDeclareTargetVar().
void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
    Constant *Addr) {
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  // Every branch below assigns all four of these before registration.
  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    // NOTE(review): LlvmVal is dereferenced below without a null check. This
    // assumes a global named MangledName exists in M (or, for declarations,
    // that VariableLinkage is provided) — confirm.
    GlobalValue *LlvmVal = M.getNamedValue(Name: VarName);

    // Size in bytes, rounded up from bits; declarations have no storage.
    if (!IsDeclaration)
      VarSize = divideCeil(
          Numerator: M.getDataLayout().getTypeSizeInBits(Ty: LlvmVal->getValueType()), Denominator: 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName(Parts: {VarName, "ref"});

      // Create the internal constant ref global only once per variable.
      if (!M.getNamedValue(Name: RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Ty: Addr->getType(), Name: RefName);
        auto *GvAddrRef = cast<GlobalVariable>(Val: AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(x: GvAddrRef);
      }
    }
  } else {
    if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    if (Config.isTargetDevice()) {
      // The device records only the name; no address is registered.
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      // On the host, register the reference pointer global. If it does not
      // exist yet, getAddrOfDeclareTargetVar() creates it (and calls back
      // into this function for it).
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    // The entry describes a pointer, so its size is the pointer size.
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}
12288
12289/// Loads all the offload entries information from the host IR
12290/// metadata.
12291void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
12292 // If we are in target mode, load the metadata from the host IR. This code has
12293 // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
12294
12295 NamedMDNode *MD = M.getNamedMetadata(Name: ompOffloadInfoName);
12296 if (!MD)
12297 return;
12298
12299 for (MDNode *MN : MD->operands()) {
12300 auto &&GetMDInt = [MN](unsigned Idx) {
12301 auto *V = cast<ConstantAsMetadata>(Val: MN->getOperand(I: Idx));
12302 return cast<ConstantInt>(Val: V->getValue())->getZExtValue();
12303 };
12304
12305 auto &&GetMDString = [MN](unsigned Idx) {
12306 auto *V = cast<MDString>(Val: MN->getOperand(I: Idx));
12307 return V->getString();
12308 };
12309
12310 switch (GetMDInt(0)) {
12311 default:
12312 llvm_unreachable("Unexpected metadata!");
12313 break;
12314 case OffloadEntriesInfoManager::OffloadEntryInfo::
12315 OffloadingEntryInfoTargetRegion: {
12316 TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
12317 /*DeviceID=*/GetMDInt(1),
12318 /*FileID=*/GetMDInt(2),
12319 /*Line=*/GetMDInt(4),
12320 /*Count=*/GetMDInt(5));
12321 OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
12322 /*Order=*/GetMDInt(6));
12323 break;
12324 }
12325 case OffloadEntriesInfoManager::OffloadEntryInfo::
12326 OffloadingEntryInfoDeviceGlobalVar:
12327 OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
12328 /*MangledName=*/Name: GetMDString(1),
12329 Flags: static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
12330 /*Flags=*/GetMDInt(2)),
12331 /*Order=*/GetMDInt(3));
12332 break;
12333 }
12334 }
12335}
12336
12337void OpenMPIRBuilder::loadOffloadInfoMetadata(vfs::FileSystem &VFS,
12338 StringRef HostFilePath) {
12339 if (HostFilePath.empty())
12340 return;
12341
12342 auto Buf = VFS.getBufferForFile(Name: HostFilePath);
12343 if (std::error_code Err = Buf.getError()) {
12344 report_fatal_error(reason: ("error opening host file from host file path inside of "
12345 "OpenMPIRBuilder: " +
12346 Err.message())
12347 .c_str());
12348 }
12349
12350 LLVMContext Ctx;
12351 auto M = expectedToErrorOrAndEmitErrors(
12352 Ctx, Val: parseBitcodeFile(Buffer: Buf.get()->getMemBufferRef(), Context&: Ctx));
12353 if (std::error_code Err = M.getError()) {
12354 report_fatal_error(
12355 reason: ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
12356 .c_str());
12357 }
12358
12359 loadOffloadInfoMetadata(M&: *M.get());
12360}
12361
12362OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createIteratorLoop(
12363 LocationDescription Loc, llvm::Value *TripCount, IteratorBodyGenTy BodyGen,
12364 llvm::StringRef Name) {
12365 Builder.restoreIP(IP: Loc.IP);
12366
12367 BasicBlock *CurBB = Builder.GetInsertBlock();
12368 assert(CurBB &&
12369 "expected a valid insertion block for creating an iterator loop");
12370 Function *F = CurBB->getParent();
12371
12372 InsertPointTy SplitIP = Builder.saveIP();
12373 if (SplitIP.getPoint() == CurBB->end())
12374 if (Instruction *Terminator = CurBB->getTerminatorOrNull())
12375 SplitIP = InsertPointTy(CurBB, Terminator->getIterator());
12376
12377 BasicBlock *ContBB =
12378 splitBB(IP: SplitIP, /*CreateBranch=*/false,
12379 DL: Builder.getCurrentDebugLocation(), Name: "omp.it.cont");
12380
12381 CanonicalLoopInfo *CLI =
12382 createLoopSkeleton(DL: Builder.getCurrentDebugLocation(), TripCount, F,
12383 /*PreInsertBefore=*/ContBB,
12384 /*PostInsertBefore=*/ContBB, Name);
12385
12386 // Enter loop from original block.
12387 redirectTo(Source: CurBB, Target: CLI->getPreheader(), DL: Builder.getCurrentDebugLocation());
12388
12389 // Remove the unconditional branch inserted by createLoopSkeleton in the body
12390 if (Instruction *T = CLI->getBody()->getTerminatorOrNull())
12391 T->eraseFromParent();
12392
12393 InsertPointTy BodyIP = CLI->getBodyIP();
12394 if (llvm::Error Err = BodyGen(BodyIP, CLI->getIndVar()))
12395 return Err;
12396
12397 // Body must either fallthrough to the latch or branch directly to it.
12398 if (Instruction *BodyTerminator = CLI->getBody()->getTerminatorOrNull()) {
12399 auto *BodyBr = dyn_cast<UncondBrInst>(Val: BodyTerminator);
12400 if (!BodyBr || BodyBr->getSuccessor() != CLI->getLatch()) {
12401 return make_error<StringError>(
12402 Args: "iterator bodygen must terminate the canonical body with an "
12403 "unconditional branch to the loop latch",
12404 Args: inconvertibleErrorCode());
12405 }
12406 } else {
12407 // Ensure we end the loop body by jumping to the latch.
12408 Builder.SetInsertPoint(CLI->getBody());
12409 Builder.CreateBr(Dest: CLI->getLatch());
12410 }
12411
12412 // Link After -> ContBB
12413 Builder.SetInsertPoint(TheBB: CLI->getAfter(), IP: CLI->getAfter()->begin());
12414 if (!CLI->getAfter()->hasTerminator())
12415 Builder.CreateBr(Dest: ContBB);
12416
12417 return InsertPointTy{ContBB, ContBB->begin()};
12418}
12419
12420/// Mangle the parameter part of the vector function name according to
12421/// their OpenMP classification. The mangling function is defined in
12422/// section 4.5 of the AAVFABI(2021Q1).
12423static std::string mangleVectorParameters(
12424 ArrayRef<llvm::OpenMPIRBuilder::DeclareSimdAttrTy> ParamAttrs) {
12425 SmallString<256> Buffer;
12426 llvm::raw_svector_ostream Out(Buffer);
12427 for (const auto &ParamAttr : ParamAttrs) {
12428 switch (ParamAttr.Kind) {
12429 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Linear:
12430 Out << 'l';
12431 break;
12432 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearRef:
12433 Out << 'R';
12434 break;
12435 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearUVal:
12436 Out << 'U';
12437 break;
12438 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearVal:
12439 Out << 'L';
12440 break;
12441 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Uniform:
12442 Out << 'u';
12443 break;
12444 case llvm::OpenMPIRBuilder::DeclareSimdKindTy::Vector:
12445 Out << 'v';
12446 break;
12447 }
12448 if (ParamAttr.HasVarStride)
12449 Out << "s" << ParamAttr.StrideOrArg;
12450 else if (ParamAttr.Kind ==
12451 llvm::OpenMPIRBuilder::DeclareSimdKindTy::Linear ||
12452 ParamAttr.Kind ==
12453 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearRef ||
12454 ParamAttr.Kind ==
12455 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearUVal ||
12456 ParamAttr.Kind ==
12457 llvm::OpenMPIRBuilder::DeclareSimdKindTy::LinearVal) {
12458 // Don't print the step value if it is not present or if it is
12459 // equal to 1.
12460 if (ParamAttr.StrideOrArg < 0)
12461 Out << 'n' << -ParamAttr.StrideOrArg;
12462 else if (ParamAttr.StrideOrArg != 1)
12463 Out << ParamAttr.StrideOrArg;
12464 }
12465
12466 if (!!ParamAttr.Alignment)
12467 Out << 'a' << ParamAttr.Alignment;
12468 }
12469
12470 return std::string(Out.str());
12471}
12472
12473void OpenMPIRBuilder::emitX86DeclareSimdFunction(
12474 llvm::Function *Fn, unsigned NumElts, const llvm::APSInt &VLENVal,
12475 llvm::ArrayRef<DeclareSimdAttrTy> ParamAttrs, DeclareSimdBranch Branch) {
12476 struct ISADataTy {
12477 char ISA;
12478 unsigned VecRegSize;
12479 };
12480 ISADataTy ISAData[] = {
12481 {.ISA: 'b', .VecRegSize: 128}, // SSE
12482 {.ISA: 'c', .VecRegSize: 256}, // AVX
12483 {.ISA: 'd', .VecRegSize: 256}, // AVX2
12484 {.ISA: 'e', .VecRegSize: 512}, // AVX512
12485 };
12486 llvm::SmallVector<char, 2> Masked;
12487 switch (Branch) {
12488 case DeclareSimdBranch::Undefined:
12489 Masked.push_back(Elt: 'N');
12490 Masked.push_back(Elt: 'M');
12491 break;
12492 case DeclareSimdBranch::Notinbranch:
12493 Masked.push_back(Elt: 'N');
12494 break;
12495 case DeclareSimdBranch::Inbranch:
12496 Masked.push_back(Elt: 'M');
12497 break;
12498 }
12499 for (char Mask : Masked) {
12500 for (const ISADataTy &Data : ISAData) {
12501 llvm::SmallString<256> Buffer;
12502 llvm::raw_svector_ostream Out(Buffer);
12503 Out << "_ZGV" << Data.ISA << Mask;
12504 if (!VLENVal) {
12505 assert(NumElts && "Non-zero simdlen/cdtsize expected");
12506 Out << llvm::APSInt::getUnsigned(X: Data.VecRegSize / NumElts);
12507 } else {
12508 Out << VLENVal;
12509 }
12510 Out << mangleVectorParameters(ParamAttrs);
12511 Out << '_' << Fn->getName();
12512 Fn->addFnAttr(Kind: Out.str());
12513 }
12514 }
12515}
12516
12517// Function used to add the attribute. The parameter `VLEN` is templated to
12518// allow the use of `x` when targeting scalable functions for SVE.
12519template <typename T>
12520static void addAArch64VectorName(T VLEN, StringRef LMask, StringRef Prefix,
12521 char ISA, StringRef ParSeq,
12522 StringRef MangledName, bool OutputBecomesInput,
12523 llvm::Function *Fn) {
12524 SmallString<256> Buffer;
12525 llvm::raw_svector_ostream Out(Buffer);
12526 Out << Prefix << ISA << LMask << VLEN;
12527 if (OutputBecomesInput)
12528 Out << 'v';
12529 Out << ParSeq << '_' << MangledName;
12530 Fn->addFnAttr(Kind: Out.str());
12531}
12532
12533// Helper function to generate the Advanced SIMD names depending on the value
12534// of the NDS when simdlen is not present.
12535static void addAArch64AdvSIMDNDSNames(unsigned NDS, StringRef Mask,
12536 StringRef Prefix, char ISA,
12537 StringRef ParSeq, StringRef MangledName,
12538 bool OutputBecomesInput,
12539 llvm::Function *Fn) {
12540 switch (NDS) {
12541 case 8:
12542 addAArch64VectorName(VLEN: 8, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12543 OutputBecomesInput, Fn);
12544 addAArch64VectorName(VLEN: 16, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12545 OutputBecomesInput, Fn);
12546 break;
12547 case 16:
12548 addAArch64VectorName(VLEN: 4, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12549 OutputBecomesInput, Fn);
12550 addAArch64VectorName(VLEN: 8, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12551 OutputBecomesInput, Fn);
12552 break;
12553 case 32:
12554 addAArch64VectorName(VLEN: 2, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12555 OutputBecomesInput, Fn);
12556 addAArch64VectorName(VLEN: 4, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12557 OutputBecomesInput, Fn);
12558 break;
12559 case 64:
12560 case 128:
12561 addAArch64VectorName(VLEN: 2, LMask: Mask, Prefix, ISA, ParSeq, MangledName,
12562 OutputBecomesInput, Fn);
12563 break;
12564 default:
12565 llvm_unreachable("Scalar type is too wide.");
12566 }
12567}
12568
12569/// Emit vector function attributes for AArch64, as defined in the AAVFABI.
12570void OpenMPIRBuilder::emitAArch64DeclareSimdFunction(
12571 llvm::Function *Fn, unsigned UserVLEN,
12572 llvm::ArrayRef<DeclareSimdAttrTy> ParamAttrs, DeclareSimdBranch Branch,
12573 char ISA, unsigned NarrowestDataSize, bool OutputBecomesInput) {
12574 assert((ISA == 'n' || ISA == 's') && "Expected ISA either 's' or 'n'.");
12575
12576 // Sort out parameter sequence.
12577 const std::string ParSeq = mangleVectorParameters(ParamAttrs);
12578 StringRef Prefix = "_ZGV";
12579 StringRef MangledName = Fn->getName();
12580
12581 // Generate simdlen from user input (if any).
12582 if (UserVLEN) {
12583 if (ISA == 's') {
12584 // SVE generates only a masked function.
12585 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
12586 OutputBecomesInput, Fn);
12587 return;
12588 }
12589
12590 switch (Branch) {
12591 case DeclareSimdBranch::Undefined:
12592 addAArch64VectorName(VLEN: UserVLEN, LMask: "N", Prefix, ISA, ParSeq, MangledName,
12593 OutputBecomesInput, Fn);
12594 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
12595 OutputBecomesInput, Fn);
12596 break;
12597 case DeclareSimdBranch::Inbranch:
12598 addAArch64VectorName(VLEN: UserVLEN, LMask: "M", Prefix, ISA, ParSeq, MangledName,
12599 OutputBecomesInput, Fn);
12600 break;
12601 case DeclareSimdBranch::Notinbranch:
12602 addAArch64VectorName(VLEN: UserVLEN, LMask: "N", Prefix, ISA, ParSeq, MangledName,
12603 OutputBecomesInput, Fn);
12604 break;
12605 }
12606 return;
12607 }
12608
12609 if (ISA == 's') {
12610 // SVE, section 3.4.1, item 1.
12611 addAArch64VectorName(VLEN: "x", LMask: "M", Prefix, ISA, ParSeq, MangledName,
12612 OutputBecomesInput, Fn);
12613 return;
12614 }
12615
12616 switch (Branch) {
12617 case DeclareSimdBranch::Undefined:
12618 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "N", Prefix, ISA, ParSeq,
12619 MangledName, OutputBecomesInput, Fn);
12620 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "M", Prefix, ISA, ParSeq,
12621 MangledName, OutputBecomesInput, Fn);
12622 break;
12623 case DeclareSimdBranch::Inbranch:
12624 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "M", Prefix, ISA, ParSeq,
12625 MangledName, OutputBecomesInput, Fn);
12626 break;
12627 case DeclareSimdBranch::Notinbranch:
12628 addAArch64AdvSIMDNDSNames(NDS: NarrowestDataSize, Mask: "N", Prefix, ISA, ParSeq,
12629 MangledName, OutputBecomesInput, Fn);
12630 break;
12631 }
12632}
12633
12634//===----------------------------------------------------------------------===//
12635// OffloadEntriesInfoManager
12636//===----------------------------------------------------------------------===//
12637
12638bool OffloadEntriesInfoManager::empty() const {
12639 return OffloadEntriesTargetRegion.empty() &&
12640 OffloadEntriesDeviceGlobalVar.empty();
12641}
12642
12643unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
12644 const TargetRegionEntryInfo &EntryInfo) const {
12645 auto It = OffloadEntriesTargetRegionCount.find(
12646 x: getTargetRegionEntryCountKey(EntryInfo));
12647 if (It == OffloadEntriesTargetRegionCount.end())
12648 return 0;
12649 return It->second;
12650}
12651
12652void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
12653 const TargetRegionEntryInfo &EntryInfo) {
12654 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
12655 EntryInfo.Count + 1;
12656}
12657
12658/// Initialize target region entry.
12659void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
12660 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
12661 OffloadEntriesTargetRegion[EntryInfo] =
12662 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
12663 OMPTargetRegionEntryTargetRegion);
12664 ++OffloadingEntriesNum;
12665}
12666
/// Register the target region described by \p EntryInfo with address \p Addr,
/// ID \p ID and kind \p Flags.
///
/// The caller must pass a default EntryInfo (Count == 0); the count for its
/// location is looked up here. The per-location counter is advanced only when
/// the function reaches its end, i.e. not on the early-return paths.
///
/// Device side: the entry is expected to have been created beforehand. If
/// hasTargetRegionEntryInfo (called with its default arguments) reports no
/// such entry, nothing is changed. Otherwise the existing entry's address, ID
/// and flags are overwritten.
///
/// Host side: a plain target-region registration for a location that already
/// has an entry (address/ID ignored) is silently dropped. In every other case
/// a new entry is appended at position OffloadingEntriesNum.
void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
    OMPTargetRegionEntryKind Flags) {
  assert(EntryInfo.Count == 0 && "expected default EntryInfo");

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  // If we are emitting code for a target, the entry is already initialized,
  // only has to be registered.
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasTargetRegionEntryInfo(EntryInfo)) {
      return;
    }
    auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
    Entry.setAddress(Addr);
    Entry.setID(ID);
    Entry.setFlags(Flags);
  } else {
    // A plain target region already known at this location is kept as-is;
    // returning here also skips the counter increment below.
    if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
        hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
      return;
    assert(!hasTargetRegionEntryInfo(EntryInfo) &&
           "Target region entry already registered!");
    // New entries are numbered by the running total of offloading entries.
    OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
    OffloadEntriesTargetRegion[EntryInfo] = Entry;
    ++OffloadingEntriesNum;
  }
  // Reserve this count so the next registration at the same location gets a
  // fresh one.
  incrementTargetRegionEntryInfoCount(EntryInfo);
}
12698
12699bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
12700 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
12701
12702 // Update the EntryInfo with the next available count for this location.
12703 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
12704
12705 auto It = OffloadEntriesTargetRegion.find(x: EntryInfo);
12706 if (It == OffloadEntriesTargetRegion.end()) {
12707 return false;
12708 }
12709 // Fail if this entry is already registered.
12710 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
12711 return false;
12712 return true;
12713}
12714
12715void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
12716 const OffloadTargetRegionEntryInfoActTy &Action) {
12717 // Scan all target region entries and perform the provided action.
12718 for (const auto &It : OffloadEntriesTargetRegion) {
12719 Action(It.first, It.second);
12720 }
12721}
12722
12723void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
12724 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
12725 OffloadEntriesDeviceGlobalVar.try_emplace(Key: Name, Args&: Order, Args&: Flags);
12726 ++OffloadingEntriesNum;
12727}
12728
12729void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
12730 StringRef VarName, Constant *Addr, int64_t VarSize,
12731 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
12732 if (OMPBuilder->Config.isTargetDevice()) {
12733 // This could happen if the device compilation is invoked standalone.
12734 if (!hasDeviceGlobalVarEntryInfo(VarName))
12735 return;
12736 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
12737 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
12738 if (Entry.getVarSize() == 0) {
12739 Entry.setVarSize(VarSize);
12740 Entry.setLinkage(Linkage);
12741 }
12742 return;
12743 }
12744 Entry.setVarSize(VarSize);
12745 Entry.setLinkage(Linkage);
12746 Entry.setAddress(Addr);
12747 } else {
12748 if (hasDeviceGlobalVarEntryInfo(VarName)) {
12749 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
12750 assert(Entry.isValid() && Entry.getFlags() == Flags &&
12751 "Entry not initialized!");
12752 if (Entry.getVarSize() == 0) {
12753 Entry.setVarSize(VarSize);
12754 Entry.setLinkage(Linkage);
12755 }
12756 return;
12757 }
12758 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect ||
12759 Flags ==
12760 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable)
12761 OffloadEntriesDeviceGlobalVar.try_emplace(Key: VarName, Args&: OffloadingEntriesNum,
12762 Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage,
12763 Args: VarName.str());
12764 else
12765 OffloadEntriesDeviceGlobalVar.try_emplace(
12766 Key: VarName, Args&: OffloadingEntriesNum, Args&: Addr, Args&: VarSize, Args&: Flags, Args&: Linkage, Args: "");
12767 ++OffloadingEntriesNum;
12768 }
12769}
12770
12771void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
12772 const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
12773 // Scan all target region entries and perform the provided action.
12774 for (const auto &E : OffloadEntriesDeviceGlobalVar)
12775 Action(E.getKey(), E.getValue());
12776}
12777
12778//===----------------------------------------------------------------------===//
12779// CanonicalLoopInfo
12780//===----------------------------------------------------------------------===//
12781
12782void CanonicalLoopInfo::collectControlBlocks(
12783 SmallVectorImpl<BasicBlock *> &BBs) {
12784 // We only count those BBs as control block for which we do not need to
12785 // reverse the CFG, i.e. not the loop body which can contain arbitrary control
12786 // flow. For consistency, this also means we do not add the Body block, which
12787 // is just the entry to the body code.
12788 BBs.reserve(N: BBs.size() + 6);
12789 BBs.append(IL: {getPreheader(), Header, Cond, Latch, Exit, getAfter()});
12790}
12791
12792BasicBlock *CanonicalLoopInfo::getPreheader() const {
12793 assert(isValid() && "Requires a valid canonical loop");
12794 for (BasicBlock *Pred : predecessors(BB: Header)) {
12795 if (Pred != Latch)
12796 return Pred;
12797 }
12798 llvm_unreachable("Missing preheader");
12799}
12800
12801void CanonicalLoopInfo::setTripCount(Value *TripCount) {
12802 assert(isValid() && "Requires a valid canonical loop");
12803
12804 Instruction *CmpI = &getCond()->front();
12805 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
12806 CmpI->setOperand(i: 1, Val: TripCount);
12807
12808#ifndef NDEBUG
12809 assertOK();
12810#endif
12811}
12812
12813void CanonicalLoopInfo::mapIndVar(
12814 llvm::function_ref<Value *(Instruction *)> Updater) {
12815 assert(isValid() && "Requires a valid canonical loop");
12816
12817 Instruction *OldIV = getIndVar();
12818
12819 // Record all uses excluding those introduced by the updater. Uses by the
12820 // CanonicalLoopInfo itself to keep track of the number of iterations are
12821 // excluded.
12822 SmallVector<Use *> ReplacableUses;
12823 for (Use &U : OldIV->uses()) {
12824 auto *User = dyn_cast<Instruction>(Val: U.getUser());
12825 if (!User)
12826 continue;
12827 if (User->getParent() == getCond())
12828 continue;
12829 if (User->getParent() == getLatch())
12830 continue;
12831 ReplacableUses.push_back(Elt: &U);
12832 }
12833
12834 // Run the updater that may introduce new uses
12835 Value *NewIV = Updater(OldIV);
12836
12837 // Replace the old uses with the value returned by the updater.
12838 for (Use *U : ReplacableUses)
12839 U->set(NewIV);
12840
12841#ifndef NDEBUG
12842 assertOK();
12843#endif
12844}
12845
12846void CanonicalLoopInfo::assertOK() const {
12847#ifndef NDEBUG
12848 // No constraints if this object currently does not describe a loop.
12849 if (!isValid())
12850 return;
12851
12852 BasicBlock *Preheader = getPreheader();
12853 BasicBlock *Body = getBody();
12854 BasicBlock *After = getAfter();
12855
12856 // Verify standard control-flow we use for OpenMP loops.
12857 assert(Preheader);
12858 assert(isa<UncondBrInst>(Preheader->getTerminator()) &&
12859 "Preheader must terminate with unconditional branch");
12860 assert(Preheader->getSingleSuccessor() == Header &&
12861 "Preheader must jump to header");
12862
12863 assert(Header);
12864 assert(isa<UncondBrInst>(Header->getTerminator()) &&
12865 "Header must terminate with unconditional branch");
12866 assert(Header->getSingleSuccessor() == Cond &&
12867 "Header must jump to exiting block");
12868
12869 assert(Cond);
12870 assert(Cond->getSinglePredecessor() == Header &&
12871 "Exiting block only reachable from header");
12872
12873 assert(isa<CondBrInst>(Cond->getTerminator()) &&
12874 "Exiting block must terminate with conditional branch");
12875 assert(cast<CondBrInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
12876 "Exiting block's first successor jump to the body");
12877 assert(cast<CondBrInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
12878 "Exiting block's second successor must exit the loop");
12879
12880 assert(Body);
12881 assert(Body->getSinglePredecessor() == Cond &&
12882 "Body only reachable from exiting block");
12883 assert(!isa<PHINode>(Body->front()));
12884
12885 assert(Latch);
12886 assert(isa<UncondBrInst>(Latch->getTerminator()) &&
12887 "Latch must terminate with unconditional branch");
12888 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
12889 // TODO: To support simple redirecting of the end of the body code that has
12890 // multiple; introduce another auxiliary basic block like preheader and after.
12891 assert(Latch->getSinglePredecessor() != nullptr);
12892 assert(!isa<PHINode>(Latch->front()));
12893
12894 assert(Exit);
12895 assert(isa<UncondBrInst>(Exit->getTerminator()) &&
12896 "Exit block must terminate with unconditional branch");
12897 assert(Exit->getSingleSuccessor() == After &&
12898 "Exit block must jump to after block");
12899
12900 assert(After);
12901 assert(After->getSinglePredecessor() == Exit &&
12902 "After block only reachable from exit block");
12903 assert(After->empty() || !isa<PHINode>(After->front()));
12904
12905 Instruction *IndVar = getIndVar();
12906 assert(IndVar && "Canonical induction variable not found?");
12907 assert(isa<IntegerType>(IndVar->getType()) &&
12908 "Induction variable must be an integer");
12909 assert(cast<PHINode>(IndVar)->getParent() == Header &&
12910 "Induction variable must be a PHI in the loop header");
12911 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
12912 assert(
12913 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
12914 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
12915
12916 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
12917 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
12918 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
12919 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
12920 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
12921 ->isOne());
12922
12923 Value *TripCount = getTripCount();
12924 assert(TripCount && "Loop trip count not found?");
12925 assert(IndVar->getType() == TripCount->getType() &&
12926 "Trip count and induction variable must have the same type");
12927
12928 auto *CmpI = cast<CmpInst>(&Cond->front());
12929 assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
12930 "Exit condition must be a signed less-than comparison");
12931 assert(CmpI->getOperand(0) == IndVar &&
12932 "Exit condition must compare the induction variable");
12933 assert(CmpI->getOperand(1) == TripCount &&
12934 "Exit condition must compare with the trip count");
12935#endif
12936}
12937
12938void CanonicalLoopInfo::invalidate() {
12939 Header = nullptr;
12940 Cond = nullptr;
12941 Latch = nullptr;
12942 Exit = nullptr;
12943}
12944