1//===- NVVMIntrRange.cpp - Set range attributes for NVVM intrinsics -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass adds appropriate range attributes for calls to NVVM
10// intrinsics that return a limited range of values.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTX.h"
15#include "NVPTXUtilities.h"
16#include "llvm/IR/InstIterator.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/IntrinsicInst.h"
19#include "llvm/IR/Intrinsics.h"
20#include "llvm/IR/IntrinsicsNVPTX.h"
21#include "llvm/IR/PassManager.h"
22#include <cstdint>
23
24using namespace llvm;
25
26#define DEBUG_TYPE "nvvm-intr-range"
27
28namespace {
29class NVVMIntrRange : public FunctionPass {
30public:
31 static char ID;
32 NVVMIntrRange() : FunctionPass(ID) {}
33
34 bool runOnFunction(Function &) override;
35};
36} // namespace
37
38FunctionPass *llvm::createNVVMIntrRangePass() { return new NVVMIntrRange(); }
39
// Storage for the pass identifier that the NVVMIntrRange constructor hands to
// its FunctionPass base class.
char NVVMIntrRange::ID = 0;
// Pass registration boilerplate under the name "nvvm-intr-range". The two
// trailing flags are presumably the CFG-only and is-analysis switches of
// INITIALIZE_PASS -- confirm against the macro's definition.
// NOTE(review): the description string says "!range metadata", whereas
// addRangeAttr() attaches the range through addRangeRetAttr(). The string is
// runtime-visible text and is intentionally left untouched here.
INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
                "Add !range metadata to NVVM intrinsics.", false, false)
43
44// Adds the passed-in [Low,High) range information as metadata to the
45// passed-in call instruction.
46static bool addRangeAttr(uint64_t Low, uint64_t High, IntrinsicInst *II) {
47 if (II->getMetadata(KindID: LLVMContext::MD_range))
48 return false;
49
50 const uint64_t BitWidth = II->getType()->getIntegerBitWidth();
51 ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High));
52
53 if (auto CurrentRange = II->getRange())
54 Range = Range.intersectWith(CR: CurrentRange.value());
55
56 II->addRangeRetAttr(CR: Range);
57 return true;
58}
59
60static bool runNVVMIntrRange(Function &F) {
61 struct Vector3 {
62 unsigned X, Y, Z;
63 };
64
65 // All these annotations are only valid for kernel functions.
66 if (!isKernelFunction(F))
67 return false;
68
69 const auto OverallReqNTID = getOverallReqNTID(F);
70 const auto OverallMaxNTID = getOverallMaxNTID(F);
71 const auto OverallClusterRank = getOverallClusterRank(F);
72
73 // If this function lacks any range information, do nothing.
74 if (!(OverallReqNTID || OverallMaxNTID || OverallClusterRank))
75 return false;
76
77 const unsigned FunctionNTID = OverallReqNTID.value_or(
78 u: OverallMaxNTID.value_or(u: std::numeric_limits<unsigned>::max()));
79
80 const unsigned FunctionClusterRank =
81 OverallClusterRank.value_or(u: std::numeric_limits<unsigned>::max());
82
83 const Vector3 MaxBlockSize{.X: std::min(a: 1024u, b: FunctionNTID),
84 .Y: std::min(a: 1024u, b: FunctionNTID),
85 .Z: std::min(a: 64u, b: FunctionNTID)};
86
87 // We conservatively use the maximum grid size as an upper bound for the
88 // cluster rank.
89 const Vector3 MaxClusterRank{.X: std::min(a: 0x7fffffffu, b: FunctionClusterRank),
90 .Y: std::min(a: 0xffffu, b: FunctionClusterRank),
91 .Z: std::min(a: 0xffffu, b: FunctionClusterRank)};
92
93 const auto ProccessIntrinsic = [&](IntrinsicInst *II) -> bool {
94 switch (II->getIntrinsicID()) {
95 // Index within block
96 case Intrinsic::nvvm_read_ptx_sreg_tid_x:
97 return addRangeAttr(Low: 0, High: MaxBlockSize.X, II);
98 case Intrinsic::nvvm_read_ptx_sreg_tid_y:
99 return addRangeAttr(Low: 0, High: MaxBlockSize.Y, II);
100 case Intrinsic::nvvm_read_ptx_sreg_tid_z:
101 return addRangeAttr(Low: 0, High: MaxBlockSize.Z, II);
102
103 // Block size
104 case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
105 return addRangeAttr(Low: 1, High: MaxBlockSize.X + 1, II);
106 case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
107 return addRangeAttr(Low: 1, High: MaxBlockSize.Y + 1, II);
108 case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
109 return addRangeAttr(Low: 1, High: MaxBlockSize.Z + 1, II);
110
111 // Cluster size
112 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_x:
113 return addRangeAttr(Low: 0, High: MaxClusterRank.X, II);
114 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_y:
115 return addRangeAttr(Low: 0, High: MaxClusterRank.Y, II);
116 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_z:
117 return addRangeAttr(Low: 0, High: MaxClusterRank.Z, II);
118 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_x:
119 return addRangeAttr(Low: 1, High: MaxClusterRank.X + 1, II);
120 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_y:
121 return addRangeAttr(Low: 1, High: MaxClusterRank.Y + 1, II);
122 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_z:
123 return addRangeAttr(Low: 1, High: MaxClusterRank.Z + 1, II);
124
125 case Intrinsic::nvvm_read_ptx_sreg_cluster_ctarank:
126 if (OverallClusterRank)
127 return addRangeAttr(Low: 0, High: FunctionClusterRank, II);
128 break;
129 case Intrinsic::nvvm_read_ptx_sreg_cluster_nctarank:
130 if (OverallClusterRank)
131 return addRangeAttr(Low: 1, High: FunctionClusterRank + 1, II);
132 break;
133 default:
134 return false;
135 }
136 return false;
137 };
138
139 // Go through the calls in this function.
140 bool Changed = false;
141 for (Instruction &I : instructions(F))
142 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: &I))
143 Changed |= ProccessIntrinsic(II);
144
145 return Changed;
146}
147
148bool NVVMIntrRange::runOnFunction(Function &F) { return runNVVMIntrRange(F); }
149
150PreservedAnalyses NVVMIntrRangePass::run(Function &F,
151 FunctionAnalysisManager &AM) {
152 return runNVVMIntrRange(F) ? PreservedAnalyses::none()
153 : PreservedAnalyses::all();
154}
155