1//===- NVVMIntrRange.cpp - Set range attributes for NVVM intrinsics -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass adds appropriate range attributes for calls to NVVM
10// intrinsics that return a limited range of values.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTX.h"
15#include "NVVMProperties.h"
16#include "llvm/IR/InstIterator.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/IntrinsicInst.h"
19#include "llvm/IR/Intrinsics.h"
20#include "llvm/IR/IntrinsicsNVPTX.h"
21#include "llvm/IR/PassManager.h"
22#include <cstdint>
23
24using namespace llvm;
25
26#define DEBUG_TYPE "nvvm-intr-range"
27
28namespace {
29class NVVMIntrRange : public FunctionPass {
30public:
31 static char ID;
32 NVVMIntrRange() : FunctionPass(ID) {}
33
34 bool runOnFunction(Function &) override;
35};
36} // namespace
37
38FunctionPass *llvm::createNVVMIntrRangePass() { return new NVVMIntrRange(); }
39
// Storage for the pass identifier that the NVVMIntrRange constructor passes to
// the FunctionPass base class.
char NVVMIntrRange::ID = 0;

// Registers the pass under the name "nvvm-intr-range".
// NOTE(review): the description string still mentions "!range metadata", while
// addRangeAttr() in this file attaches a range *return attribute*; the string
// is runtime text and is intentionally left untouched here.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only / is-analysis flags of INITIALIZE_PASS -- confirm against the macro
// definition.
INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
                "Add !range metadata to NVVM intrinsics.", false, false)
43
44// Adds the passed-in [Low,High) range information as metadata to the
45// passed-in call instruction.
46static bool addRangeAttr(uint64_t Low, uint64_t High, IntrinsicInst *II) {
47 if (II->getMetadata(KindID: LLVMContext::MD_range))
48 return false;
49
50 const uint64_t BitWidth = II->getType()->getIntegerBitWidth();
51 ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High));
52
53 if (auto CurrentRange = II->getRange())
54 Range = Range.intersectWith(CR: CurrentRange.value());
55
56 II->addRangeRetAttr(CR: Range);
57 return true;
58}
59
60static bool runNVVMIntrRange(Function &F) {
61 struct Vector3 {
62 unsigned X, Y, Z;
63 };
64
65 // All these annotations are only valid for kernel functions.
66 if (!isKernelFunction(F))
67 return false;
68
69 auto ReqNTID = getReqNTID(F);
70 const std::optional<uint64_t> OverallMaxNTID = getOverallMaxNTID(F);
71 auto ClusterDim = getClusterDim(F);
72 const std::optional<unsigned> MaxClusterRank = getMaxClusterRank(F);
73
74 // If this function lacks any range information, do nothing.
75 if (ReqNTID.empty() && !OverallMaxNTID && ClusterDim.empty() &&
76 !MaxClusterRank)
77 return false;
78
79 const uint64_t MaxNTID =
80 OverallMaxNTID.value_or(u: std::numeric_limits<uint64_t>::max());
81
82 // When reqntid is specified, block dimensions are exact compile-time
83 // constants. Otherwise, use maxntid (capped at hardware limits) as upper
84 // bounds.
85 Vector3 MinBlockDim, MaxBlockDim;
86 if (!ReqNTID.empty()) {
87 ReqNTID.resize(N: 3, NV: 1);
88 MinBlockDim = MaxBlockDim = {.X: ReqNTID[0], .Y: ReqNTID[1], .Z: ReqNTID[2]};
89 } else {
90 MinBlockDim = {.X: 1, .Y: 1, .Z: 1};
91 MaxBlockDim = {.X: static_cast<unsigned>(std::min(a: uint64_t{1024}, b: MaxNTID)),
92 .Y: static_cast<unsigned>(std::min(a: uint64_t{1024}, b: MaxNTID)),
93 .Z: static_cast<unsigned>(std::min(a: uint64_t{64}, b: MaxNTID))};
94 }
95
96 const bool HasClusterInfo = !ClusterDim.empty() || MaxClusterRank;
97
98 // When cluster_dim is specified, cluster dimensions are exact compile-time
99 // constants. Otherwise, use maxclusterrank (capped at hardware limits) as
100 // upper bounds.
101 Vector3 MinClusterDim, MaxClusterDim;
102 uint64_t MinClusterSize, MaxClusterSize;
103 if (!ClusterDim.empty()) {
104 ClusterDim.resize(N: 3, NV: 1);
105 MinClusterDim =
106 MaxClusterDim = {.X: ClusterDim[0], .Y: ClusterDim[1], .Z: ClusterDim[2]};
107 MinClusterSize = MaxClusterSize =
108 ClusterDim[0] * ClusterDim[1] * ClusterDim[2];
109 } else {
110 const unsigned MaxNctaPerCluster =
111 MaxClusterRank.value_or(u: std::numeric_limits<unsigned>::max());
112 MinClusterDim = {.X: 1, .Y: 1, .Z: 1};
113 MaxClusterDim = {.X: std::min(a: 0x7fffffffu, b: MaxNctaPerCluster),
114 .Y: std::min(a: 0xffffu, b: MaxNctaPerCluster),
115 .Z: std::min(a: 0xffffu, b: MaxNctaPerCluster)};
116 MinClusterSize = 1;
117 MaxClusterSize = MaxNctaPerCluster;
118 }
119
120 const auto ProcessIntrinsic = [&](IntrinsicInst *II) -> bool {
121 switch (II->getIntrinsicID()) {
122 // Index within block
123 case Intrinsic::nvvm_read_ptx_sreg_tid_x:
124 return addRangeAttr(Low: 0, High: MaxBlockDim.X, II);
125 case Intrinsic::nvvm_read_ptx_sreg_tid_y:
126 return addRangeAttr(Low: 0, High: MaxBlockDim.Y, II);
127 case Intrinsic::nvvm_read_ptx_sreg_tid_z:
128 return addRangeAttr(Low: 0, High: MaxBlockDim.Z, II);
129
130 // Block size: use single-value range when reqntid is specified;
131 // InstCombine will fold these to constants later.
132 case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
133 return addRangeAttr(Low: MinBlockDim.X, High: MaxBlockDim.X + 1, II);
134 case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
135 return addRangeAttr(Low: MinBlockDim.Y, High: MaxBlockDim.Y + 1, II);
136 case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
137 return addRangeAttr(Low: MinBlockDim.Z, High: MaxBlockDim.Z + 1, II);
138
139 // Cluster size: use single-value ranges when cluster_dim is specified;
140 // InstCombine will fold cluster_nctaid.* / cluster_nctarank to constants
141 // later.
142 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_x:
143 return addRangeAttr(Low: 0, High: MaxClusterDim.X, II);
144 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_y:
145 return addRangeAttr(Low: 0, High: MaxClusterDim.Y, II);
146 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_z:
147 return addRangeAttr(Low: 0, High: MaxClusterDim.Z, II);
148 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_x:
149 return addRangeAttr(Low: MinClusterDim.X, High: MaxClusterDim.X + 1, II);
150 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_y:
151 return addRangeAttr(Low: MinClusterDim.Y, High: MaxClusterDim.Y + 1, II);
152 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_z:
153 return addRangeAttr(Low: MinClusterDim.Z, High: MaxClusterDim.Z + 1, II);
154
155 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctarank:
156 return HasClusterInfo && addRangeAttr(Low: 0, High: MaxClusterSize, II);
157 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctarank:
158 return HasClusterInfo &&
159 addRangeAttr(Low: MinClusterSize, High: MaxClusterSize + 1, II);
160 default:
161 return false;
162 }
163 };
164
165 // Go through the calls in this function.
166 bool Changed = false;
167 for (Instruction &I : instructions(F))
168 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: &I))
169 Changed |= ProcessIntrinsic(II);
170
171 return Changed;
172}
173
174bool NVVMIntrRange::runOnFunction(Function &F) { return runNVVMIntrRange(F); }
175
176PreservedAnalyses NVVMIntrRangePass::run(Function &F,
177 FunctionAnalysisManager &AM) {
178 return runNVVMIntrRange(F) ? PreservedAnalyses::none()
179 : PreservedAnalyses::all();
180}
181