//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"
#include <numeric>

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store-vectorizer, as that
// pass doesn't handle overlapping cases. In addition, this pass extends the
// widening to handle scalar sub-dword loads that are only naturally aligned
// but not dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to.
  Type *ConvertToScalar;
  /// The set of visited Instructions.
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value.
  ValueToValueMap ValMap;
  /// Per-BB map of conversions from Optimal Type -> Original Type.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to a compatible vector of
  /// legal type.
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
  /// fit (e.g. v2i32 -> v7i8).
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a problematic
  /// PHI node, we coerce all incoming values in a single invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element size is larger than the convert-to scalar size, we can't
    // do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic blocks
  // will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

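// The packed type is built from i32 pieces: a vector that fits in 32 bits is
// widened to a plain i32 (e.g. v2i8 or v2i16 -> i32), while anything larger
// becomes a vector of ceil(bits / 32) i32 elements (e.g. v7i8, 56 bits, ->
// <2 x i32>).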
Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

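// When the original vector does not fill the packed type exactly, it is first
// padded with poison lanes via a shufflevector and then bitcast. Roughly, for
// v7i8 packed into <2 x i32>: shuffle v7i8 up to v8i8 (the extra lane is
// poison), then bitcast v8i8 to <2 x i32>.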
Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector of
  // the desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

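// The reverse of convertToOptType. A packed scalar is truncated and bitcast
// back (e.g. i32 -> i16 -> v2i8); a packed vector is bitcast to the original
// element type and the padding lanes are shuffled away (e.g. <2 x i32> ->
// v8i8 -> v7i8).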
Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

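// optimizeLiveType walks the web of PHIs and cross-block values reachable from
// I. When every incoming value can be handled, the web is rewritten in the
// packed type: defs are converted right after their definition, PHIs are
// recreated with the packed type, and each use block converts back to the
// original type at its first insertion point. PHIs whose coercion chain turns
// out to be incomplete are unwound and left to dead-code elimination.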
bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHINodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

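  // From here on the collected web is rewritten: each def is coerced
  // immediately after its definition so that the packed value dominates every
  // cross-block use recorded above.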
  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal) && ValMap[IncVal])
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    if (MissingIncVal) {
      Value *DeadVal = ValMap[Phi];
      // The coercion chain of the PHI is broken. Delete the Phi
      // from the ValMap and any connected / user Phis.
      SmallVector<Value *, 4> PHIWorklist;
      SmallPtrSet<Value *, 4> VisitedPhis;
      PHIWorklist.push_back(DeadVal);
      while (!PHIWorklist.empty()) {
        Value *NextDeadValue = PHIWorklist.pop_back_val();
        VisitedPhis.insert(NextDeadValue);
        auto OriginalPhi =
            std::find_if(PhiNodes.begin(), PhiNodes.end(),
                         [this, &NextDeadValue](PHINode *CandPhi) {
                           return ValMap[CandPhi] == NextDeadValue;
                         });
        // This PHI may have already been removed from maps when
        // unwinding a previous Phi.
        if (OriginalPhi != PhiNodes.end())
          ValMap.erase(*OriginalPhi);

        DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));

        for (User *U : NextDeadValue->users()) {
          if (!VisitedPhis.contains(cast<PHINode>(U)))
            PHIWorklist.push_back(U);
        }
      }
    } else {
      DeadInsts.emplace_back(cast<Instruction>(Phi));
    }
  }
  // Coerce back to the original type and replace the uses.
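  // Converted-back values are cached per block in BBUseValMap so each block
  // materializes at most one conversion per packed def; an originally typed
  // def that already lives in the use block is reused directly.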
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op) && ValMap[Op]) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          // We may pick up ops that were previously converted for users in
          // other blocks. If there is an originally typed definition of the Op
          // already in this block, simply reuse it.
          if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
              U->getParent() == cast<Instruction>(Op)->getParent()) {
            NewVal = Op;
          } else {
            NewVal =
                convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                   InsertPt, U->getParent());
            BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
          }
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

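// A load may only be widened if it is a simple, uniform (i.e. scalar),
// sub-dword load from the constant address space and is at least naturally
// aligned.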
bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address spaces.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It must be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

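// For example, a uniform i16 load that is naturally aligned but not dword
// aligned, such as
//   %v = load i16, ptr addrspace(4) %p, align 2
// is rewritten (roughly, and only when the underlying base is known to be
// dword aligned) into
//   %w  = load i32, ptr addrspace(4) %p.dword, align 4
//   %sh = lshr i32 %w, 16
//   %v  = trunc i32 %sh to i16
// where %p.dword points at the dword containing the original value; the shift
// amount depends on the byte offset within that dword.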
bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it's not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, we can simply promote the original alignment to a
    // better one.
    LI.setAlignment(Align(4));
    return true;
  }

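  // Otherwise, load the whole dword containing the requested value from a
  // dword-aligned address and extract the value from it below.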
  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

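  // Adjust is the byte offset of the original value within the loaded dword
  // (AMDGPU is little endian), so shifting right by 8 * Adjust bits moves the
  // value into the low bits; truncate to the original width and bitcast back
  // to the loaded type.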
  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}