//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks, that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

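// Returns true if Mask is a fixed-width vector constant whose elements are all
// ConstantInts, i.e. every mask lane is known at compile time.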
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

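// Map the logical lane index Idx onto the bit position it occupies once the
// <VectorWidth x i1> mask has been bitcast to an integer. On big-endian
// targets lane 0 ends up in the most significant bit, so the index is
// mirrored.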
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

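// Translate a masked expand load intrinsic, like
// <4 x i32> @llvm.masked.expandload(ptr %ptr, <4 x i1> %mask,
//                                   <4 x i32> %passthru)
// to a chain of basic blocks that load from consecutive memory locations only
// for the enabled lanes, advancing the pointer after each enabled lane
// (illustrative sketch of the first lane only):
//
//  %Mask0 = extractelement <4 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Load0 = load i32, ptr %ptr
//  %Res0 = insertelement <4 x i32> %passthru, i32 %Load0, i32 0
//  %NextPtr = getelementptr i32, ptr %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <4 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
//  %ptr.phi.else = phi ptr [ %NextPtr, %cond.load ], [ %ptr, %0 ]
//  . . .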
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(0).valueOrOne();

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/poisons as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = PoisonValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

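// Translate a masked compress store intrinsic, like
// void @llvm.masked.compressstore(<4 x i32> %src, ptr %ptr, <4 x i1> %mask)
// to a chain of basic blocks that store the enabled elements to consecutive
// memory locations, advancing the pointer only after an enabled lane
// (illustrative sketch of the first lane only):
//
//  %Mask0 = extractelement <4 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <4 x i32> %src, i32 0
//  store i32 %Elt0, ptr %ptr
//  %NextPtr = getelementptr i32, ptr %ptr, i32 1
//  br label %else
//
// else:
//  %ptr.phi.else = phi ptr [ %NextPtr, %cond.store ], [ %ptr, %0 ]
//  . . .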
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(1).valueOrOne();

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

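// Translate a vector histogram intrinsic, like
// void @llvm.experimental.vector.histogram.add(<4 x ptr> %ptrs, i32 %inc,
//                                              <4 x i1> %mask)
// to a chain of basic blocks that, for every enabled lane, load the current
// bucket value through the lane's pointer, add %inc and store the result back.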
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // If we extend histogram to return a result someday (like the updated vector)
  // then we'll need to support it here.
  assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
  Value *Ptrs = CI->getArgOperand(0);
  Value *Inc = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  Type *EltTy = Inc->getType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // FIXME: Do we need to add an alignment parameter to the intrinsic?
  unsigned VectorWidth = AddrType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Add = Builder.CreateAdd(Load, Inc);
      Builder.CreateStore(Add, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.histogram.update");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
    Value *Add = Builder.CreateAdd(Load, Inc);
    Builder.CreateStore(Add, Ptr);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }

  CI->eraseFromParent();
  ModifiedDT = true;
}

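// Scalarize every unsupported masked memory intrinsic in F. Whenever a
// scalarization changes the CFG, the block scan is restarted so the newly
// created blocks are visited as well.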
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

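// Scan BB and scalarize any masked memory intrinsic call the target does not
// support natively. Returns true if the block was changed; ModifiedDT is set
// and the scan stops early once the scalarization introduced new blocks.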
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

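// If CI is a masked memory intrinsic that the target cannot handle natively
// (as reported by TTI), replace it with scalar code and return true. Scalable
// vectors are left untouched since the expansion needs a fixed element count.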
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
      if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
                                           CI->getArgOperand(1)->getType()))
        return false;
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}