summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp')
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp33
1 files changed, 23 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 931d4d42f56e..a09d2037e97b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
- [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
- VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+ [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+ VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
+ bool Negated = false) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
- InstructionCost MulAccCost =
- Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+ InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+ IsZExt, RedTy, SrcVecTy, Negated, CostKind);
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
@@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
};
VPValue *VecOp = Red->getVecOp();
+ VPValue *Mul = nullptr;
+ VPValue *Sub = nullptr;
VPValue *A, *B;
+ // Sub reductions will have a sub between the add reduction and vec op.
+ if (match(VecOp,
+ m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
+ Sub = VecOp;
+ else
+ Mul = VecOp;
// Try to match reduce.add(mul(...)).
- if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+ if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
auto *RecipeA =
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
- auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+ auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
@@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
- Mul, RecipeA, RecipeB, nullptr)) {
- return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
+ MulR, RecipeA, RecipeB, nullptr, Sub)) {
+ if (Sub)
+ return new VPExpressionRecipe(
+ RecipeA, RecipeB, MulR,
+ cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
+ return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
}
// Match reduce.add(mul).
- if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
- return new VPExpressionRecipe(Mul, Red);
+ if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
+ return new VPExpressionRecipe(MulR, Red);
}
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B