//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

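// Return true if every lane of \p Mask is a compile-time ConstantInt, i.e. the
// mask is fully known at compile time. The lowerings below use this to emit
// straight-line code for the enabled lanes instead of a branch per lane.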
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

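// Map a vector lane index to the corresponding bit position in the scalar
// value produced by bitcasting an <N x i1> mask to an iN integer. On
// little-endian targets lane Idx is bit Idx; on big-endian targets the lane
// order is reversed, so for example with VectorWidth = 4, lane 0 is tested
// via bit 3.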
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
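//
// Before falling back to the per-element branch chain above, the lowering
// below recognizes three cheaper cases: an all-true mask becomes a plain
// unconditional vector load, a compile-time-constant mask becomes
// straight-line scalar loads for the enabled lanes, and a mask that is the
// splat of a single non-constant i1 becomes one conditional vector load.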
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
                                CallInst *CI, DomTreeUpdater *DTU,
                                bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    NewI->copyMetadata(*CI);
    NewI->takeName(CI);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked load" is a predicated load - that is,
  // where the mask is the splat of a non-constant scalar boolean. In that
  // case, use that splatted value as the guard on a conditional vector load.
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
                                               CI->getName() + ".cond.load");
    Load->copyMetadata(*CI);

    BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
    Builder.SetInsertPoint(PostLoad, PostLoad->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
    Phi->addIncoming(Load, CondBlock);
    Phi->addIncoming(Src0, IfBlock);
    Phi->takeName(CI);

    CI->replaceAllUsesWith(Phi);
    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }
  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs and other
  // machines with divergence, as there each i1 needs a vector register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//   . . .
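//
// As in the load case, the code below first tries the cheaper lowerings: an
// all-true mask becomes an unconditional vector store, a constant mask becomes
// straight-line scalar stores for the enabled lanes, and a splat mask becomes
// a single conditional vector store guarded by the splatted i1.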
static void scalarizeMaskedStore(const DataLayout &DL,
                                 bool HasBranchDivergence, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->takeName(CI);
    Store->copyMetadata(*CI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked store" is a predicated store - that
  // is, when the mask is the splat of a non-constant scalar boolean. In that
  // case, optimize to a conditional store.
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);
    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");
    Builder.SetInsertPoint(CondBlock->getTerminator());

    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->takeName(CI);
    Store->copyMetadata(*CI);

    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                        <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL,
                                  bool HasBranchDivergence, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead

    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL,
                                   bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

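// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(ptr %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks that load the elements from consecutive memory
// locations, advancing the pointer only past lanes whose mask bit is set;
// disabled lanes keep the corresponding element of %passthru.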
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
                                      bool HasBranchDivergence, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(0).valueOrOne();

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/poisons as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = PoisonValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead

    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

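// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, ptr %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks that store the enabled elements to consecutive
// memory locations, advancing the pointer only past lanes whose mask bit is
// set.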
static void scalarizeMaskedCompressStore(const DataLayout &DL,
                                         bool HasBranchDivergence,
                                         CallInst *CI, DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(1).valueOrOne();

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

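// Translate a histogram intrinsic, like
// void @llvm.experimental.vector.histogram.add(<16 x ptr> %ptrs, i32 %inc,
//                                              <16 x i1> %mask)
// to a chain of basic blocks that, for each enabled lane, load the current
// bucket value, apply the update operation (add, uadd.sat, umin or umax with
// %inc), and store the result back through the lane's pointer.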
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // If we extend histogram to return a result someday (like the updated
  // vector) then we'll need to support it here.
  assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
  Value *Ptrs = CI->getArgOperand(0);
  Value *Inc = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  Type *EltTy = Inc->getType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // FIXME: Do we need to add an alignment parameter to the intrinsic?
  unsigned VectorWidth = AddrType->getNumElements();
  auto CreateHistogramUpdateValue = [&](IntrinsicInst *CI, Value *Load,
                                        Value *Inc) -> Value * {
    Value *UpdateOp;
    switch (CI->getIntrinsicID()) {
    case Intrinsic::experimental_vector_histogram_add:
      UpdateOp = Builder.CreateAdd(Load, Inc);
      break;
    case Intrinsic::experimental_vector_histogram_uadd_sat:
      UpdateOp =
          Builder.CreateIntrinsic(Intrinsic::uadd_sat, {EltTy}, {Load, Inc});
      break;
    case Intrinsic::experimental_vector_histogram_umin:
      UpdateOp = Builder.CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
      break;
    case Intrinsic::experimental_vector_histogram_umax:
      UpdateOp = Builder.CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
      break;

    default:
      llvm_unreachable("Unexpected histogram intrinsic");
    }
    return UpdateOp;
  };

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Update =
          CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
      Builder.CreateStore(Update, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.histogram.update");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
    Value *UpdateOp =
        CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
    Builder.CreateStore(UpdateOp, Ptr);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }

  CI->eraseFromParent();
  ModifiedDT = true;
}

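// Drive the scalarization over the whole function: walk the blocks, lower
// every illegal masked memory intrinsic, and restart the walk whenever a
// lowering changed the dominator tree, repeating until no more changes are
// made.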
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getDataLayout();
  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  HasBranchDivergence, DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

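// Scan a basic block and scalarize every masked memory intrinsic call that the
// target cannot handle natively. Returns true if any call was scalarized;
// returns early (and leaves ModifiedDT set) when a lowering introduced new
// basic blocks.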
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |=
          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

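// Dispatch a single call instruction: if it is a masked load, store, gather,
// scatter, expandload, compressstore or histogram intrinsic with fixed-width
// vector operands that the target reports as illegal (or, for gathers and
// scatters, forces to be scalarized), lower it with the matching helper above.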
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
    case Intrinsic::experimental_vector_histogram_uadd_sat:
    case Intrinsic::experimental_vector_histogram_umin:
    case Intrinsic::experimental_vector_histogram_umax:
      if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
                                           CI->getArgOperand(1)->getType()))
        return false;
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(),
              cast<PointerType>(CI->getArgOperand(0)->getType())
                  ->getAddressSpace()))
        return false;
      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
              cast<PointerType>(CI->getArgOperand(1)->getType())
                  ->getAddressSpace()))
        return false;
      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
                                   ModifiedDT);
      return true;
    }
  }

  return false;
}