1//===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements a demanded bits analysis. A demanded bit is one that
10// contributes to a result; bits that are not demanded can be either zero or
11// one without affecting control or data flow. For example in this sequence:
12//
13// %1 = add i32 %x, %y
14// %2 = trunc i32 %1 to i16
15//
16// Only the lowest 16 bits of %1 are demanded; the rest are removed by the
17// trunc.
18//
19//===----------------------------------------------------------------------===//
20
21#include "llvm/Analysis/DemandedBits.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SetVector.h"
24#include "llvm/Analysis/AssumptionCache.h"
25#include "llvm/Analysis/ValueTracking.h"
26#include "llvm/IR/DataLayout.h"
27#include "llvm/IR/Dominators.h"
28#include "llvm/IR/InstIterator.h"
29#include "llvm/IR/Instruction.h"
30#include "llvm/IR/IntrinsicInst.h"
31#include "llvm/IR/Operator.h"
32#include "llvm/IR/PassManager.h"
33#include "llvm/IR/PatternMatch.h"
34#include "llvm/IR/Type.h"
35#include "llvm/IR/Use.h"
36#include "llvm/Support/Casting.h"
37#include "llvm/Support/Debug.h"
38#include "llvm/Support/KnownBits.h"
39#include "llvm/Support/raw_ostream.h"
40#include <algorithm>
41#include <cstdint>
42
43using namespace llvm;
44using namespace llvm::PatternMatch;
45
46#define DEBUG_TYPE "demanded-bits"
47
48static bool isAlwaysLive(Instruction *I) {
49 return I->isTerminator() || I->isEHPad() || I->mayHaveSideEffects();
50}
51
52void DemandedBits::determineLiveOperandBits(
53 const Instruction *UserI, const Value *Val, unsigned OperandNo,
54 const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
55 bool &KnownBitsComputed) {
56 unsigned BitWidth = AB.getBitWidth();
57
58 // We're called once per operand, but for some instructions, we need to
59 // compute known bits of both operands in order to determine the live bits of
60 // either (when both operands are instructions themselves). We don't,
61 // however, want to do this twice, so we cache the result in APInts that live
62 // in the caller. For the two-relevant-operands case, both operand values are
63 // provided here.
64 auto ComputeKnownBits =
65 [&](unsigned BitWidth, const Value *V1, const Value *V2) {
66 if (KnownBitsComputed)
67 return;
68 KnownBitsComputed = true;
69
70 const DataLayout &DL = UserI->getDataLayout();
71 Known = KnownBits(BitWidth);
72 computeKnownBits(V: V1, Known, DL, AC: &AC, CxtI: UserI, DT: &DT);
73
74 if (V2) {
75 Known2 = KnownBits(BitWidth);
76 computeKnownBits(V: V2, Known&: Known2, DL, AC: &AC, CxtI: UserI, DT: &DT);
77 }
78 };
79
80 switch (UserI->getOpcode()) {
81 default: break;
82 case Instruction::Call:
83 case Instruction::Invoke:
84 if (const auto *II = dyn_cast<IntrinsicInst>(Val: UserI)) {
85 switch (II->getIntrinsicID()) {
86 default: break;
87 case Intrinsic::bswap:
88 // The alive bits of the input are the swapped alive bits of
89 // the output.
90 AB = AOut.byteSwap();
91 break;
92 case Intrinsic::bitreverse:
93 // The alive bits of the input are the reversed alive bits of
94 // the output.
95 AB = AOut.reverseBits();
96 break;
97 case Intrinsic::ctlz:
98 if (OperandNo == 0) {
99 // We need some output bits, so we need all bits of the
100 // input to the left of, and including, the leftmost bit
101 // known to be one.
102 ComputeKnownBits(BitWidth, Val, nullptr);
103 AB = APInt::getHighBitsSet(numBits: BitWidth,
104 hiBitsSet: std::min(a: BitWidth, b: Known.countMaxLeadingZeros()+1));
105 }
106 break;
107 case Intrinsic::cttz:
108 if (OperandNo == 0) {
109 // We need some output bits, so we need all bits of the
110 // input to the right of, and including, the rightmost bit
111 // known to be one.
112 ComputeKnownBits(BitWidth, Val, nullptr);
113 AB = APInt::getLowBitsSet(numBits: BitWidth,
114 loBitsSet: std::min(a: BitWidth, b: Known.countMaxTrailingZeros()+1));
115 }
116 break;
117 case Intrinsic::fshl:
118 case Intrinsic::fshr: {
119 const APInt *SA;
120 if (OperandNo == 2) {
121 // Shift amount is modulo the bitwidth. For powers of two we have
122 // SA % BW == SA & (BW - 1).
123 if (isPowerOf2_32(Value: BitWidth))
124 AB = BitWidth - 1;
125 } else if (match(V: II->getOperand(i_nocapture: 2), P: m_APInt(Res&: SA))) {
126 // Normalize to funnel shift left. APInt shifts of BitWidth are well-
127 // defined, so no need to special-case zero shifts here.
128 uint64_t ShiftAmt = SA->urem(RHS: BitWidth);
129 if (II->getIntrinsicID() == Intrinsic::fshr)
130 ShiftAmt = BitWidth - ShiftAmt;
131
132 if (OperandNo == 0)
133 AB = AOut.lshr(shiftAmt: ShiftAmt);
134 else if (OperandNo == 1)
135 AB = AOut.shl(shiftAmt: BitWidth - ShiftAmt);
136 }
137 break;
138 }
139 case Intrinsic::umax:
140 case Intrinsic::umin:
141 case Intrinsic::smax:
142 case Intrinsic::smin:
143 // If low bits of result are not demanded, they are also not demanded
144 // for the min/max operands.
145 AB = APInt::getBitsSetFrom(numBits: BitWidth, loBit: AOut.countr_zero());
146 break;
147 }
148 }
149 break;
150 case Instruction::Add:
151 if (AOut.isMask()) {
152 AB = AOut;
153 } else {
154 ComputeKnownBits(BitWidth, UserI->getOperand(i: 0), UserI->getOperand(i: 1));
155 AB = determineLiveOperandBitsAdd(OperandNo, AOut, LHS: Known, RHS: Known2);
156 }
157 break;
158 case Instruction::Sub:
159 if (AOut.isMask()) {
160 AB = AOut;
161 } else {
162 ComputeKnownBits(BitWidth, UserI->getOperand(i: 0), UserI->getOperand(i: 1));
163 AB = determineLiveOperandBitsSub(OperandNo, AOut, LHS: Known, RHS: Known2);
164 }
165 break;
166 case Instruction::Mul:
167 // Find the highest live output bit. We don't need any more input
168 // bits than that (adds, and thus subtracts, ripple only to the
169 // left).
170 AB = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: AOut.getActiveBits());
171 break;
172 case Instruction::Shl:
173 if (OperandNo == 0) {
174 const APInt *ShiftAmtC;
175 if (match(V: UserI->getOperand(i: 1), P: m_APInt(Res&: ShiftAmtC))) {
176 uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(Limit: BitWidth - 1);
177 AB = AOut.lshr(shiftAmt: ShiftAmt);
178
179 // If the shift is nuw/nsw, then the high bits are not dead
180 // (because we've promised that they *must* be zero).
181 const auto *S = cast<ShlOperator>(Val: UserI);
182 if (S->hasNoSignedWrap())
183 AB |= APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: ShiftAmt+1);
184 else if (S->hasNoUnsignedWrap())
185 AB |= APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: ShiftAmt);
186 }
187 }
188 break;
189 case Instruction::LShr:
190 if (OperandNo == 0) {
191 const APInt *ShiftAmtC;
192 if (match(V: UserI->getOperand(i: 1), P: m_APInt(Res&: ShiftAmtC))) {
193 uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(Limit: BitWidth - 1);
194 AB = AOut.shl(shiftAmt: ShiftAmt);
195
196 // If the shift is exact, then the low bits are not dead
197 // (they must be zero).
198 if (cast<LShrOperator>(Val: UserI)->isExact())
199 AB |= APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: ShiftAmt);
200 }
201 }
202 break;
203 case Instruction::AShr:
204 if (OperandNo == 0) {
205 const APInt *ShiftAmtC;
206 if (match(V: UserI->getOperand(i: 1), P: m_APInt(Res&: ShiftAmtC))) {
207 uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(Limit: BitWidth - 1);
208 AB = AOut.shl(shiftAmt: ShiftAmt);
209 // Because the high input bit is replicated into the
210 // high-order bits of the result, if we need any of those
211 // bits, then we must keep the highest input bit.
212 if ((AOut & APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: ShiftAmt))
213 .getBoolValue())
214 AB.setSignBit();
215
216 // If the shift is exact, then the low bits are not dead
217 // (they must be zero).
218 if (cast<AShrOperator>(Val: UserI)->isExact())
219 AB |= APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: ShiftAmt);
220 }
221 }
222 break;
223 case Instruction::And:
224 AB = AOut;
225
226 // For bits that are known zero, the corresponding bits in the
227 // other operand are dead (unless they're both zero, in which
228 // case they can't both be dead, so just mark the LHS bits as
229 // dead).
230 ComputeKnownBits(BitWidth, UserI->getOperand(i: 0), UserI->getOperand(i: 1));
231 if (OperandNo == 0)
232 AB &= ~Known2.Zero;
233 else
234 AB &= ~(Known.Zero & ~Known2.Zero);
235 break;
236 case Instruction::Or:
237 AB = AOut;
238
239 // For bits that are known one, the corresponding bits in the
240 // other operand are dead (unless they're both one, in which
241 // case they can't both be dead, so just mark the LHS bits as
242 // dead).
243 ComputeKnownBits(BitWidth, UserI->getOperand(i: 0), UserI->getOperand(i: 1));
244 if (OperandNo == 0)
245 AB &= ~Known2.One;
246 else
247 AB &= ~(Known.One & ~Known2.One);
248 break;
249 case Instruction::Xor:
250 case Instruction::PHI:
251 AB = AOut;
252 break;
253 case Instruction::Trunc:
254 AB = AOut.zext(width: BitWidth);
255 break;
256 case Instruction::ZExt:
257 AB = AOut.trunc(width: BitWidth);
258 break;
259 case Instruction::SExt:
260 AB = AOut.trunc(width: BitWidth);
261 // Because the high input bit is replicated into the
262 // high-order bits of the result, if we need any of those
263 // bits, then we must keep the highest input bit.
264 if ((AOut & APInt::getHighBitsSet(numBits: AOut.getBitWidth(),
265 hiBitsSet: AOut.getBitWidth() - BitWidth))
266 .getBoolValue())
267 AB.setSignBit();
268 break;
269 case Instruction::Select:
270 if (OperandNo != 0)
271 AB = AOut;
272 break;
273 case Instruction::ExtractElement:
274 if (OperandNo == 0)
275 AB = AOut;
276 break;
277 case Instruction::InsertElement:
278 case Instruction::ShuffleVector:
279 if (OperandNo == 0 || OperandNo == 1)
280 AB = AOut;
281 break;
282 }
283}
284
285void DemandedBits::performAnalysis() {
286 if (Analyzed)
287 // Analysis already completed for this function.
288 return;
289 Analyzed = true;
290
291 Visited.clear();
292 AliveBits.clear();
293 DeadUses.clear();
294
295 SmallSetVector<Instruction*, 16> Worklist;
296
297 // Collect the set of "root" instructions that are known live.
298 for (Instruction &I : instructions(F)) {
299 if (!isAlwaysLive(I: &I))
300 continue;
301
302 LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
303 // For integer-valued instructions, set up an initial empty set of alive
304 // bits and add the instruction to the work list. For other instructions
305 // add their operands to the work list (for integer values operands, mark
306 // all bits as live).
307 Type *T = I.getType();
308 if (T->isIntOrIntVectorTy()) {
309 if (AliveBits.try_emplace(Key: &I, Args: T->getScalarSizeInBits(), Args: 0).second)
310 Worklist.insert(X: &I);
311
312 continue;
313 }
314
315 // Non-integer-typed instructions...
316 for (Use &OI : I.operands()) {
317 if (auto *J = dyn_cast<Instruction>(Val&: OI)) {
318 Type *T = J->getType();
319 if (T->isIntOrIntVectorTy())
320 AliveBits[J] = APInt::getAllOnes(numBits: T->getScalarSizeInBits());
321 else
322 Visited.insert(Ptr: J);
323 Worklist.insert(X: J);
324 }
325 }
326 // To save memory, we don't add I to the Visited set here. Instead, we
327 // check isAlwaysLive on every instruction when searching for dead
328 // instructions later (we need to check isAlwaysLive for the
329 // integer-typed instructions anyway).
330 }
331
332 // Propagate liveness backwards to operands.
333 while (!Worklist.empty()) {
334 Instruction *UserI = Worklist.pop_back_val();
335
336 LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
337 APInt AOut;
338 bool InputIsKnownDead = false;
339 if (UserI->getType()->isIntOrIntVectorTy()) {
340 AOut = AliveBits[UserI];
341 LLVM_DEBUG(dbgs() << " Alive Out: 0x"
342 << Twine::utohexstr(AOut.getLimitedValue()));
343
344 // If all bits of the output are dead, then all bits of the input
345 // are also dead.
346 InputIsKnownDead = !AOut && !isAlwaysLive(I: UserI);
347 }
348 LLVM_DEBUG(dbgs() << "\n");
349
350 KnownBits Known, Known2;
351 bool KnownBitsComputed = false;
352 // Compute the set of alive bits for each operand. These are anded into the
353 // existing set, if any, and if that changes the set of alive bits, the
354 // operand is added to the work-list.
355 for (Use &OI : UserI->operands()) {
356 // We also want to detect dead uses of arguments, but will only store
357 // demanded bits for instructions.
358 auto *I = dyn_cast<Instruction>(Val&: OI);
359 if (!I && !isa<Argument>(Val: OI))
360 continue;
361
362 Type *T = OI->getType();
363 if (T->isIntOrIntVectorTy()) {
364 unsigned BitWidth = T->getScalarSizeInBits();
365 APInt AB = APInt::getAllOnes(numBits: BitWidth);
366 if (InputIsKnownDead) {
367 AB = APInt(BitWidth, 0);
368 } else {
369 // Bits of each operand that are used to compute alive bits of the
370 // output are alive, all others are dead.
371 determineLiveOperandBits(UserI, Val: OI, OperandNo: OI.getOperandNo(), AOut, AB,
372 Known, Known2, KnownBitsComputed);
373
374 // Keep track of uses which have no demanded bits.
375 if (AB.isZero())
376 DeadUses.insert(Ptr: &OI);
377 else
378 DeadUses.erase(Ptr: &OI);
379 }
380
381 if (I) {
382 // If we've added to the set of alive bits (or the operand has not
383 // been previously visited), then re-queue the operand to be visited
384 // again.
385 auto Res = AliveBits.try_emplace(Key: I);
386 if (Res.second || (AB |= Res.first->second) != Res.first->second) {
387 Res.first->second = std::move(AB);
388 Worklist.insert(X: I);
389 }
390 }
391 } else if (I && Visited.insert(Ptr: I).second) {
392 Worklist.insert(X: I);
393 }
394 }
395 }
396}
397
398APInt DemandedBits::getDemandedBits(Instruction *I) {
399 performAnalysis();
400
401 auto Found = AliveBits.find(Val: I);
402 if (Found != AliveBits.end())
403 return Found->second;
404
405 const DataLayout &DL = I->getDataLayout();
406 return APInt::getAllOnes(numBits: DL.getTypeSizeInBits(Ty: I->getType()->getScalarType()));
407}
408
409APInt DemandedBits::getDemandedBits(Use *U) {
410 Type *T = (*U)->getType();
411 auto *UserI = cast<Instruction>(Val: U->getUser());
412 const DataLayout &DL = UserI->getDataLayout();
413 unsigned BitWidth = DL.getTypeSizeInBits(Ty: T->getScalarType());
414
415 // We only track integer uses, everything else produces a mask with all bits
416 // set
417 if (!T->isIntOrIntVectorTy())
418 return APInt::getAllOnes(numBits: BitWidth);
419
420 if (isUseDead(U))
421 return APInt(BitWidth, 0);
422
423 performAnalysis();
424
425 APInt AOut = getDemandedBits(I: UserI);
426 APInt AB = APInt::getAllOnes(numBits: BitWidth);
427 KnownBits Known, Known2;
428 bool KnownBitsComputed = false;
429
430 determineLiveOperandBits(UserI, Val: *U, OperandNo: U->getOperandNo(), AOut, AB, Known,
431 Known2, KnownBitsComputed);
432
433 return AB;
434}
435
436bool DemandedBits::isInstructionDead(Instruction *I) {
437 performAnalysis();
438
439 return !Visited.count(Ptr: I) && !AliveBits.contains(Val: I) && !isAlwaysLive(I);
440}
441
442bool DemandedBits::isUseDead(Use *U) {
443 // We only track integer uses, everything else is assumed live.
444 if (!(*U)->getType()->isIntOrIntVectorTy())
445 return false;
446
447 // Uses by always-live instructions are never dead.
448 auto *UserI = cast<Instruction>(Val: U->getUser());
449 if (isAlwaysLive(I: UserI))
450 return false;
451
452 performAnalysis();
453 if (DeadUses.count(Ptr: U))
454 return true;
455
456 // If no output bits are demanded, no input bits are demanded and the use
457 // is dead. These uses might not be explicitly present in the DeadUses map.
458 if (UserI->getType()->isIntOrIntVectorTy()) {
459 auto Found = AliveBits.find(Val: UserI);
460 if (Found != AliveBits.end() && Found->second.isZero())
461 return true;
462 }
463
464 return false;
465}
466
467void DemandedBits::print(raw_ostream &OS) {
468 auto PrintDB = [&](const Instruction *I, const APInt &A, Value *V = nullptr) {
469 OS << "DemandedBits: 0x" << Twine::utohexstr(Val: A.getLimitedValue())
470 << " for ";
471 if (V) {
472 V->printAsOperand(O&: OS, PrintType: false);
473 OS << " in ";
474 }
475 OS << *I << '\n';
476 };
477
478 OS << "Printing analysis 'Demanded Bits Analysis' for function '" << F.getName() << "':\n";
479 performAnalysis();
480 for (auto &KV : AliveBits) {
481 Instruction *I = KV.first;
482 PrintDB(I, KV.second);
483
484 for (Use &OI : I->operands()) {
485 PrintDB(I, getDemandedBits(U: &OI), OI);
486 }
487 }
488}
489
490static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
491 const APInt &AOut,
492 const KnownBits &LHS,
493 const KnownBits &RHS,
494 bool CarryZero, bool CarryOne) {
495 assert(!(CarryZero && CarryOne) &&
496 "Carry can't be zero and one at the same time");
497
498 // The following check should be done by the caller, as it also indicates
499 // that LHS and RHS don't need to be computed.
500 //
501 // if (AOut.isMask())
502 // return AOut;
503
504 // Boundary bits' carry out is unaffected by their carry in.
505 APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
506
507 // First, the alive carry bits are determined from the alive output bits:
508 // Let demand ripple to the right but only up to any set bit in Bound.
509 // AOut = -1----
510 // Bound = ----1-
511 // ACarry&~AOut = --111-
512 APInt RBound = Bound.reverseBits();
513 APInt RAOut = AOut.reverseBits();
514 APInt RProp = RAOut + (RAOut | ~RBound);
515 APInt RACarry = RProp ^ ~RBound;
516 APInt ACarry = RACarry.reverseBits();
517
518 // Then, the alive input bits are determined from the alive carry bits:
519 APInt NeededToMaintainCarryZero;
520 APInt NeededToMaintainCarryOne;
521 if (OperandNo == 0) {
522 NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
523 NeededToMaintainCarryOne = LHS.One | ~RHS.One;
524 } else {
525 NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
526 NeededToMaintainCarryOne = RHS.One | ~LHS.One;
527 }
528
529 // As in computeForAddCarry
530 APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
531 APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
532
533 // The below is simplified from
534 //
535 // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
536 // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
537 // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
538 //
539 // APInt NeededToMaintainCarry =
540 // (CarryKnownZero & NeededToMaintainCarryZero) |
541 // (CarryKnownOne & NeededToMaintainCarryOne) |
542 // CarryUnknown;
543
544 APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
545 (PossibleSumOne | NeededToMaintainCarryOne);
546
547 APInt AB = AOut | (ACarry & NeededToMaintainCarry);
548 return AB;
549}
550
551APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
552 const APInt &AOut,
553 const KnownBits &LHS,
554 const KnownBits &RHS) {
555 return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, CarryZero: true,
556 CarryOne: false);
557}
558
559APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
560 const APInt &AOut,
561 const KnownBits &LHS,
562 const KnownBits &RHS) {
563 KnownBits NRHS;
564 NRHS.Zero = RHS.One;
565 NRHS.One = RHS.Zero;
566 return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS: NRHS, CarryZero: false,
567 CarryOne: true);
568}
569
570AnalysisKey DemandedBitsAnalysis::Key;
571
572DemandedBits DemandedBitsAnalysis::run(Function &F,
573 FunctionAnalysisManager &AM) {
574 auto &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
575 auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
576 return DemandedBits(F, AC, DT);
577}
578
579PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
580 FunctionAnalysisManager &AM) {
581 AM.getResult<DemandedBitsAnalysis>(IR&: F).print(OS);
582 return PreservedAnalyses::all();
583}
584