1//===-- NVPTXSetByValParamAlign.cpp - Set byval param alignment -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Set explicit alignment on byval parameter attributes in the NVPTX backend.
10// Without this, the alignment is left unspecified and IR-level analyses (e.g.,
11// computeKnownBits via Value::getPointerAlignment) conservatively assume
12// Align(1), since the actual alignment is a target-specific codegen detail not
13// visible at the IR level.
14//
15// The alignment is chosen as follows:
16// - Externally-visible functions: ABI type alignment (capped at 128).
17// - Internal/private functions: max(16, ABI align) to enable 128-bit
18// vectorized param loads. The compiler can _increase_ alignment beyond ABI
19// in this case because it has control over all of the call sites and byval
20// parameters are copies allocated by the caller in .param space.
21//
22// After updating the attribute, the pass propagates the improved alignment to
23// all loads from the byval pointer that use a known constant offset.
24//
25// TODO: Consider removing the load propagation in favor of infer-alignment,
26// which should be able to pick up the improved alignment from the attribute.
27//
28//===----------------------------------------------------------------------===//
29
30#include "NVPTX.h"
31#include "NVPTXUtilities.h"
32#include "llvm/IR/Function.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/InitializePasses.h"
35#include "llvm/Pass.h"
36#include "llvm/Support/Debug.h"
37#include <queue>
38
39#define DEBUG_TYPE "nvptx-set-byval-param-align"
40
41using namespace llvm;
42
43namespace {
44class NVPTXSetByValParamAlignLegacyPass : public FunctionPass {
45 bool runOnFunction(Function &F) override;
46
47public:
48 static char ID;
49 NVPTXSetByValParamAlignLegacyPass() : FunctionPass(ID) {}
50 StringRef getPassName() const override {
51 return "Set alignment of byval parameters (NVPTX)";
52 }
53};
54} // namespace
55
// Storage for the pass identifier that the constructor forwards to
// FunctionPass.
char NVPTXSetByValParamAlignLegacyPass::ID = 0;

// Register the legacy pass under the argument string
// "nvptx-set-byval-param-align" together with its descriptive name.
// NOTE(review): the two trailing `false` values are presumably the macro's
// "CFG-only" and "is analysis" flags -- confirm against INITIALIZE_PASS.
INITIALIZE_PASS(NVPTXSetByValParamAlignLegacyPass,
                "nvptx-set-byval-param-align",
                "Set alignment of byval parameters (NVPTX)", false, false)
61
62static Align setByValParamAlign(Argument *Arg) {
63 Function *F = Arg->getParent();
64 Type *ByValType = Arg->getParamByValType();
65 const DataLayout &DL = F->getDataLayout();
66
67 const Align OptimizedAlign = getFunctionParamOptimizedAlign(F, ArgTy: ByValType, DL);
68 const Align CurrentAlign = Arg->getParamAlign().valueOrOne();
69
70 if (CurrentAlign >= OptimizedAlign)
71 return CurrentAlign;
72
73 LLVM_DEBUG(dbgs() << "Try to use alignment " << OptimizedAlign.value()
74 << " instead of " << CurrentAlign.value() << " for " << *Arg
75 << '\n');
76
77 Arg->removeAttr(Kind: Attribute::Alignment);
78 Arg->addAttr(Attr: Attribute::getWithAlignment(Context&: F->getContext(), Alignment: OptimizedAlign));
79
80 return OptimizedAlign;
81}
82
83// Adjust alignment of arguments passed byval in .param address space. We can
84// increase alignment of such arguments in a way that ensures that we can
85// effectively vectorize their loads. We should also traverse all loads from
86// byval pointer and adjust their alignment, if those were using known offset.
87// Such alignment changes must be conformed with parameter store and load in
88// NVPTXTargetLowering::LowerCall.
89static void propagateAlignmentToLoads(Value *Val, Align NewAlign,
90 const DataLayout &DL) {
91 struct Load {
92 LoadInst *Inst;
93 uint64_t Offset;
94 };
95
96 struct LoadContext {
97 Value *InitialVal;
98 uint64_t Offset;
99 };
100
101 SmallVector<Load> Loads;
102 std::queue<LoadContext> Worklist;
103 Worklist.push(x: {.InitialVal: Val, .Offset: 0});
104
105 while (!Worklist.empty()) {
106 LoadContext Ctx = Worklist.front();
107 Worklist.pop();
108
109 for (User *CurUser : Ctx.InitialVal->users()) {
110 if (auto *I = dyn_cast<LoadInst>(Val: CurUser))
111 Loads.push_back(Elt: {.Inst: I, .Offset: Ctx.Offset});
112 else if (isa<BitCastInst>(Val: CurUser) || isa<AddrSpaceCastInst>(Val: CurUser))
113 Worklist.push(x: {.InitialVal: cast<Instruction>(Val: CurUser), .Offset: Ctx.Offset});
114 else if (auto *I = dyn_cast<GetElementPtrInst>(Val: CurUser)) {
115 APInt OffsetAccumulated =
116 APInt::getZero(numBits: DL.getIndexTypeSizeInBits(Ty: I->getType()));
117
118 if (!I->accumulateConstantOffset(DL, Offset&: OffsetAccumulated))
119 continue;
120
121 uint64_t OffsetLimit = -1;
122 uint64_t Offset = OffsetAccumulated.getLimitedValue(Limit: OffsetLimit);
123 assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
124
125 Worklist.push(x: {.InitialVal: I, .Offset: Ctx.Offset + Offset});
126 }
127 }
128 }
129
130 for (Load &CurLoad : Loads) {
131 Align NewLoadAlign = commonAlignment(A: NewAlign, Offset: CurLoad.Offset);
132 Align CurLoadAlign = CurLoad.Inst->getAlign();
133 CurLoad.Inst->setAlignment(std::max(a: NewLoadAlign, b: CurLoadAlign));
134 }
135}
136
137static bool setByValParamAlignment(Function &F) {
138 const DataLayout &DL = F.getDataLayout();
139 bool Changed = false;
140 for (Argument &Arg : F.args()) {
141 if (!Arg.hasByValAttr())
142 continue;
143 const Align NewArgAlign = setByValParamAlign(&Arg);
144 propagateAlignmentToLoads(Val: &Arg, NewAlign: NewArgAlign, DL);
145 Changed = true;
146 }
147 return Changed;
148}
149
150bool NVPTXSetByValParamAlignLegacyPass::runOnFunction(Function &F) {
151 return setByValParamAlignment(F);
152}
153
154FunctionPass *llvm::createNVPTXSetByValParamAlignPass() {
155 return new NVPTXSetByValParamAlignLegacyPass();
156}
157
158PreservedAnalyses
159NVPTXSetByValParamAlignPass::run(Function &F, FunctionAnalysisManager &AM) {
160 return setByValParamAlignment(F) ? PreservedAnalyses::none()
161 : PreservedAnalyses::all();
162}
163