diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 931d4d42f56e..a09d2037e97b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, // Clamp the range if using multiply-accumulate-reduction is profitable. auto IsMulAccValidAndClampRange = - [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, - VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool { + [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, + VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt, + bool Negated = false) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *SrcTy = Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy; auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); - InstructionCost MulAccCost = - Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind); + InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost( + IsZExt, RedTy, SrcVecTy, Negated, CostKind); InstructionCost MulCost = Mul->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); InstructionCost ExtCost = 0; @@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, }; VPValue *VecOp = Red->getVecOp(); + VPValue *Mul = nullptr; + VPValue *Sub = nullptr; VPValue *A, *B; + // Sub reductions will have a sub between the add reduction and vec op. + if (match(VecOp, + m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul)))) + Sub = VecOp; + else + Mul = VecOp; // Try to match reduce.add(mul(...)). - if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) { + if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) { auto *RecipeA = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe()); auto *RecipeB = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe()); - auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); + auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe()); // Match reduce.add(mul(ext, ext)). if (RecipeA && RecipeB && @@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, match(RecipeB, m_ZExtOrSExt(m_VPValue())) && IsMulAccValidAndClampRange(RecipeA->getOpcode() == Instruction::CastOps::ZExt, - Mul, RecipeA, RecipeB, nullptr)) { - return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red); + MulR, RecipeA, RecipeB, nullptr, Sub)) { + if (Sub) + return new VPExpressionRecipe( + RecipeA, RecipeB, MulR, + cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red); + return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red); } // Match reduce.add(mul). - if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) - return new VPExpressionRecipe(Mul, Red); + if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub)) + return new VPExpressionRecipe(MulR, Red); } // Match reduce.add(ext(mul(ext(A), ext(B)))). // All extend recipes must have same opcode or A == B |
