//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression graphs post-dominated by a TruncInst
// and, for each eligible graph, creates a reduced bit-width expression,
// replaces the old expression with the new one and removes the old expression.
// An eligible expression graph is one that:
// 1. Contains only supported instructions.
// 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
// 3. Can be evaluated into a type with a reduced legal bit-width.
// 4. All instructions in the graph must not have users outside the graph.
//    The only exception is for {ZExt, SExt}Inst with operand type equal to
//    the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// a smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
//===----------------------------------------------------------------------===//

#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "
                           "width of expression graph");
STATISTIC(NumInstrsReduced,
          "Number of instructions whose bit width was reduced");
/// Given an instruction and a container, fill the container with all the
/// relevant operands of that instruction, with respect to the Trunc expression
/// graph optimization.
static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
  unsigned Opc = I->getOpcode();
  switch (Opc) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
    // These CastInst are considered leaves of the evaluated expression, thus,
    // their operands are not relevant.
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::InsertElement:
    Ops.push_back(I->getOperand(0));
    Ops.push_back(I->getOperand(1));
    break;
  case Instruction::ExtractElement:
    Ops.push_back(I->getOperand(0));
    break;
  case Instruction::Select:
    Ops.push_back(I->getOperand(1));
    Ops.push_back(I->getOperand(2));
    break;
  case Instruction::PHI:
    llvm::append_range(Ops, cast<PHINode>(I)->incoming_values());
    break;
  default:
    llvm_unreachable("Unreachable!");
  }
}

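// Walk the expression graph feeding CurrentTruncInst with an explicit
// Worklist/Stack DFS: an instruction is finalized (added to InstInfoMap) once
// it shows up on top of the Stack again, i.e. after all of its relevant
// operands have been visited. Returns false as soon as an unsupported
// instruction, or a value that is neither a Constant nor an Instruction, is
// encountered.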
bool TruncInstCombine::buildTruncExpressionGraph() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;
  // Clear old instructions info.
  InstInfoMap.clear();

  Worklist.push_back(CurrentTruncInst->getOperand(0));

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    auto *I = dyn_cast<Instruction>(Curr);
    if (!I)
      return false;

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and add it to the instruction info map.
      Worklist.pop_back();
      Stack.pop_back();
      // Insert I to the Info map.
      InstInfoMap.try_emplace(I);
      continue;
    }

    if (InstInfoMap.count(I)) {
      Worklist.pop_back();
      continue;
    }

    // Add the instruction to the stack before we start handling its operands.
    Stack.push_back(I);

    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
      // trunc(trunc(x)) -> trunc(x)
      // trunc(ext(x)) -> ext(x) if the source type is smaller than the new
      // dest
      // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
      // dest
      break;
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::UDiv:
    case Instruction::URem:
    case Instruction::InsertElement:
    case Instruction::ExtractElement:
    case Instruction::Select: {
      SmallVector<Value *, 2> Operands;
      getRelevantOperands(I, Operands);
      append_range(Worklist, Operands);
      break;
    }
    case Instruction::PHI: {
      SmallVector<Value *, 2> Operands;
      getRelevantOperands(I, Operands);
      // Add only operands that are not already on the Stack, to prevent a
      // cycle.
      for (auto *Op : Operands)
        if (!llvm::is_contained(Stack, Op))
          Worklist.push_back(Op);
      break;
    }
    default:
      // TODO: Can handle more cases here:
      // 1. shufflevector
      // 2. sdiv, srem
      // ...
      return false;
    }
  }
  return true;
}

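// Compute the smallest bit-width in which the graph rooted at
// CurrentTruncInst's operand can be evaluated. ValidBitWidth is the width
// demanded from an instruction by its users (seeded with the trunc's
// destination width) and doubles as a visited marker; MinBitWidth is the
// bottom-up answer: the maximum of the instruction's own ValidBitWidth and
// its operands' MinBitWidth (which getBestTruncatedType may have pre-seeded
// for shift, udiv and urem instructions).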
unsigned TruncInstCombine::getMinBitWidth() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;

  Value *Src = CurrentTruncInst->getOperand(0);
  Type *DstTy = CurrentTruncInst->getType();
  unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  if (isa<Constant>(Src))
    return TruncBitWidth;

  Worklist.push_back(Src);
  InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    // Otherwise, it must be an instruction.
    auto *I = cast<Instruction>(Curr);

    auto &Info = InstInfoMap[I];

    SmallVector<Value *, 2> Operands;
    getRelevantOperands(I, Operands);

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and update MinBitWidth.
      Worklist.pop_back();
      Stack.pop_back();
      for (auto *Operand : Operands)
        if (auto *IOp = dyn_cast<Instruction>(Operand))
          Info.MinBitWidth =
              std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
      continue;
    }

    // Add the instruction to the stack before we start handling its operands.
    Stack.push_back(I);
    unsigned ValidBitWidth = Info.ValidBitWidth;

    // Update minimum bit-width before handling its operands. This is required
    // when the instruction is part of a loop.
    Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);

    for (auto *Operand : Operands)
      if (auto *IOp = dyn_cast<Instruction>(Operand)) {
        // If we already calculated the minimum bit-width for this valid
        // bit-width, or for a smaller valid bit-width, then just keep the
        // answer we already calculated.
        unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
        if (IOpBitwidth >= ValidBitWidth)
          continue;
        InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
        Worklist.push_back(IOp);
      }
  }
  unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
  assert(MinBitWidth >= TruncBitWidth);

  if (MinBitWidth > TruncBitWidth) {
    // In this case reducing an expression with a vector type might generate a
    // new vector type, which is not preferable as it might result in
    // generating sub-optimal code.
    if (DstTy->isVectorTy())
      return OrigBitWidth;
    // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
    Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
    // Update minimum bit-width with the new destination type bit-width if we
    // succeeded in finding such a type; otherwise, with the original
    // bit-width.
    MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
  } else { // MinBitWidth == TruncBitWidth
    // In this case the expression can be evaluated with the trunc
    // instruction's destination type, and the trunc instruction can be
    // omitted. However, we should not perform the evaluation if the original
    // type is a legal scalar type and the target type is illegal.
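    // For instance (an illustrative case, assuming the target's only legal
    // integer widths are 32 and 64): a graph feeding `trunc i32 %x to i8`
    // stays in i32 even if it could be evaluated in i8, because i32 is legal
    // while i8 is not.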
    bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
    bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
    if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
      return OrigBitWidth;
  }
  return MinBitWidth;
}

Type *TruncInstCombine::getBestTruncatedType() {
  if (!buildTruncExpressionGraph())
    return nullptr;

  // We don't want to duplicate instructions, which isn't profitable. Thus, we
  // can't shrink something that has multiple users, unless all users are
  // post-dominated by the trunc instruction, i.e., were visited during the
  // expression evaluation.
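  // As an illustrative sketch, in
  //   %s = sext i16 %x to i64
  //   %a = add i64 %s, %s
  //   %t = trunc i64 %a to i16
  //   %u = xor i64 %s, 1        ; user outside the graph
  // %s has a user outside the graph, but since it is an extension from the
  // eventual destination type (i16) it may simply be bypassed in the reduced
  // graph and kept around for %u, so the reduction is still allowed.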
  unsigned DesiredBitWidth = 0;
  for (auto Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->hasOneUse())
      continue;
    bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
    for (auto *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U))
        if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
          if (!IsExtInst)
            return nullptr;
          // If this is an extension from the dest type, we can eliminate it,
          // even if it has multiple users. Thus, update the DesiredBitWidth
          // and validate that all extension instructions agree on the same
          // DesiredBitWidth.
          unsigned ExtInstBitWidth =
              I->getOperand(0)->getType()->getScalarSizeInBits();
          if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
            return nullptr;
          DesiredBitWidth = ExtInstBitWidth;
        }
  }

  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  // Initialize MinBitWidth for shift instructions with the minimum number
  // that is greater than the shift amount (i.e. shift amount + 1).
  // For `lshr` adjust MinBitWidth so that all potentially truncated
  // bits of the value-to-be-shifted are zeros.
  // For `ashr` adjust MinBitWidth so that all potentially truncated
  // bits of the value-to-be-shifted are sign bits (all zeros or ones)
  // and even the first untruncated bit is a sign bit.
  // Exit early if MinBitWidth is not less than the original bit-width.
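  // For instance (illustrative numbers, with OrigBitWidth = 32): a shift
  // amount known to be at most 3 seeds MinBitWidth with 4; an `lshr` whose
  // shifted value has at most 6 active bits raises that to max(4, 6) = 6; an
  // `ashr` whose shifted value has 20 known sign bits raises it to
  // max(4, 32 - 20 + 1) = 13.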
  for (auto &Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->isShift()) {
      KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
      unsigned MinBitWidth = KnownRHS.getMaxValue()
                                 .uadd_sat(APInt(OrigBitWidth, 1))
                                 .getLimitedValue(OrigBitWidth);
      if (MinBitWidth == OrigBitWidth)
        return nullptr;
      if (I->getOpcode() == Instruction::LShr) {
        KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
        MinBitWidth =
            std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
      }
      if (I->getOpcode() == Instruction::AShr) {
        unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
        MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
      }
      if (MinBitWidth >= OrigBitWidth)
        return nullptr;
      Itr.second.MinBitWidth = MinBitWidth;
    }
    if (I->getOpcode() == Instruction::UDiv ||
        I->getOpcode() == Instruction::URem) {
      unsigned MinBitWidth = 0;
      for (const auto &Op : I->operands()) {
        KnownBits Known = computeKnownBits(Op);
        MinBitWidth =
            std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
        if (MinBitWidth >= OrigBitWidth)
          return nullptr;
      }
      Itr.second.MinBitWidth = MinBitWidth;
    }
  }

  // Calculate the minimum bit-width allowed for shrinking the currently
  // visited truncate's operand.
  unsigned MinBitWidth = getMinBitWidth();

  // Check that we can shrink to a smaller bit-width than the original one and
  // that it matches the DesiredBitWidth, if such exists.
  if (MinBitWidth >= OrigBitWidth ||
      (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
    return nullptr;

  return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
}

/// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
/// for \p V: if it is a vector type, return the vector version of \p Ty,
/// otherwise return \p Ty itself.
static Type *getReducedType(Value *V, Type *Ty) {
  assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
  if (auto *VTy = dyn_cast<VectorType>(V->getType()))
    return VectorType::get(Ty, VTy->getElementCount());
  return Ty;
}

Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
  Type *Ty = getReducedType(V, SclTy);
  if (auto *C = dyn_cast<Constant>(V)) {
    C = ConstantExpr::getTrunc(C, Ty);
    // If we got a constantexpr back, try to simplify it with DL info.
    return ConstantFoldConstant(C, DL, &TLI);
  }

  auto *I = cast<Instruction>(V);
  Info Entry = InstInfoMap.lookup(I);
  assert(Entry.NewValue);
  return Entry.NewValue;
}

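// Rebuild the expression graph in the reduced type: instructions are visited
// in the order they were recorded in InstInfoMap, so reduced operands already
// exist when their users are rebuilt; PHI incoming values are wired up
// afterwards to cope with cycles; finally the trunc is replaced and the old
// graph is erased in reverse order, so users are removed before their
// operands.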
void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {
  NumInstrsReduced += InstInfoMap.size();
  // Pairs of old and new phi-nodes.
  SmallVector<std::pair<PHINode *, PHINode *>, 2> OldNewPHINodes;
  for (auto &Itr : InstInfoMap) { // Forward
    Instruction *I = Itr.first;
    TruncInstCombine::Info &NodeInfo = Itr.second;

    assert(!NodeInfo.NewValue && "Instruction has been evaluated");

    IRBuilder<> Builder(I);
    Value *Res = nullptr;
    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt: {
      Type *Ty = getReducedType(I, SclTy);
      // If the source type of the cast is the type we're trying for then we
      // can just return the source. There's no need to insert it because it
      // is not new.
      if (I->getOperand(0)->getType() == Ty) {
        assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");
        NodeInfo.NewValue = I->getOperand(0);
        continue;
      }
      // Otherwise, must be the same type of cast, so just reinsert a new one.
      // This also handles the case of zext(trunc(x)) -> zext(x).
      Res = Builder.CreateIntCast(I->getOperand(0), Ty,
                                  Opc == Instruction::SExt);

      // Update Worklist entries with the new value if needed.
      // There are three possible changes to the Worklist:
      // 1. Update Old-TruncInst -> New-TruncInst.
      // 2. Remove Old-TruncInst (if the new node is not a TruncInst).
      // 3. Add New-TruncInst (if the old node was not a TruncInst).
      auto *Entry = find(Worklist, I);
      if (Entry != Worklist.end()) {
        if (auto *NewCI = dyn_cast<TruncInst>(Res))
          *Entry = NewCI;
        else
          Worklist.erase(Entry);
      } else if (auto *NewCI = dyn_cast<TruncInst>(Res))
        Worklist.push_back(NewCI);
      break;
    }
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::UDiv:
    case Instruction::URem: {
      Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
      Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
      // Preserve the `exact` flag since truncation doesn't change exactness.
      if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
        if (auto *ResI = dyn_cast<Instruction>(Res))
          ResI->setIsExact(PEO->isExact());
      break;
    }
    case Instruction::ExtractElement: {
      Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
      Value *Idx = I->getOperand(1);
      Res = Builder.CreateExtractElement(Vec, Idx);
      break;
    }
    case Instruction::InsertElement: {
      Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
      Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
      Value *Idx = I->getOperand(2);
      Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
      break;
    }
    case Instruction::Select: {
      Value *Op0 = I->getOperand(0);
      Value *LHS = getReducedOperand(I->getOperand(1), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(2), SclTy);
      Res = Builder.CreateSelect(Op0, LHS, RHS);
      break;
    }
    case Instruction::PHI: {
      Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());
      OldNewPHINodes.push_back(
          std::make_pair(cast<PHINode>(I), cast<PHINode>(Res)));
      break;
    }
    default:
      llvm_unreachable("Unhandled instruction");
    }

    NodeInfo.NewValue = Res;
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(I);
  }

  for (auto &Node : OldNewPHINodes) {
    PHINode *OldPN = Node.first;
    PHINode *NewPN = Node.second;
    for (auto Incoming : zip(OldPN->incoming_values(), OldPN->blocks()))
      NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),
                         std::get<1>(Incoming));
  }

  Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
  Type *DstTy = CurrentTruncInst->getType();
  if (Res->getType() != DstTy) {
    IRBuilder<> Builder(CurrentTruncInst);
    Res = Builder.CreateIntCast(Res, DstTy, false);
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(CurrentTruncInst);
  }
  CurrentTruncInst->replaceAllUsesWith(Res);

  // Erase the old expression graph, which was replaced by the reduced
  // expression graph.
  CurrentTruncInst->eraseFromParent();
  // First, erase the old phi-nodes and their uses.
  for (auto &Node : OldNewPHINodes) {
    PHINode *OldPN = Node.first;
    OldPN->replaceAllUsesWith(PoisonValue::get(OldPN->getType()));
    InstInfoMap.erase(OldPN);
    OldPN->eraseFromParent();
  }
  // Now the expression graph has been turned into a DAG.
  // We iterate backward, which means we visit an instruction before we visit
  // any of its operands; this way, when we get to an operand, we have already
  // removed the instructions (from the expression DAG) that use it.
  for (auto &I : llvm::reverse(InstInfoMap)) {
    // We still need to check that the instruction has no users before we erase
    // it, because a {SExt, ZExt}Inst might have other users that were not
    // reduced; in such a case, we need to keep that instruction.
    if (I.first->use_empty())
      I.first->eraseFromParent();
    else
      assert((isa<SExtInst>(I.first) || isa<ZExtInst>(I.first)) &&
             "Only {SExt, ZExt}Inst might have unreduced users");
  }
}

bool TruncInstCombine::run(Function &F) {
  bool MadeIRChange = false;

  // Collect all TruncInsts in the function into the Worklist for evaluating.
  for (auto &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    for (auto &I : BB)
      if (auto *CI = dyn_cast<TruncInst>(&I))
        Worklist.push_back(CI);
  }

  // Process all TruncInsts in the Worklist; for each instruction:
  // 1. Check if it dominates an eligible expression graph to be reduced.
  // 2. Create a reduced expression graph and replace the old one with it.
  while (!Worklist.empty()) {
    CurrentTruncInst = Worklist.pop_back_val();

    if (Type *NewDstSclTy = getBestTruncatedType()) {
      LLVM_DEBUG(
          dbgs() << "ICE: TruncInstCombine reducing type of expression graph "
                    "dominated by: "
                 << CurrentTruncInst << '\n');
      ReduceExpressionGraph(NewDstSclTy);
      ++NumExprsReduced;
      MadeIRChange = true;
    }
  }

  return MadeIRChange;
}