diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 6 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlan.h | 11 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 35 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 33 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 4 |
5 files changed, 71 insertions, 18 deletions
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 1cfbcf133662..0adff8d957e9 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I, TTI::CastContextHint::None, CostKind, RedOp); InstructionCost RedCost = TTI.getMulAccReductionCost( - IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); + IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind); if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost) @@ -5583,7 +5583,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I, TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); InstructionCost RedCost = TTI.getMulAccReductionCost( - IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); + IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind); InstructionCost ExtraExtCost = 0; if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) { Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1; @@ -5602,7 +5602,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I, TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); InstructionCost RedCost = TTI.getMulAccReductionCost( - true, RdxDesc.getRecurrenceType(), VectorTy, CostKind); + true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind); if (RedCost.isValid() && RedCost < MulCost + BaseCost) return I == RetI ? RedCost : 0; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index d460573f5bec..1bc926db301d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2757,6 +2757,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe { /// vector operands, performing a reduction.add on the result, and adding /// the scalar result to a chain. MulAccReduction, + /// Represent an inloop multiply-accumulate reduction, multiplying the + /// extended vector operands, negating the multiplication, performing a + /// reduction.add + /// on the result, and adding + /// the scalar result to a chain. + ExtNegatedMulAccReduction, }; /// Type of the expression. @@ -2780,6 +2786,11 @@ public: VPWidenRecipe *Mul, VPReductionRecipe *Red) : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction, {Ext0, Ext1, Mul, Red}) {} + VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1, + VPWidenRecipe *Mul, VPWidenRecipe *Sub, + VPReductionRecipe *Red) + : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction, + {Ext0, Ext1, Mul, Sub, Red}) {} ~VPExpressionRecipe() override { for (auto *R : reverse(ExpressionRecipes)) diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 318e8171e098..c20b1920c379 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2672,13 +2672,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, RedTy, SrcVecTy, std::nullopt, Ctx.CostKind); } case ExpressionTypes::MulAccReduction: - return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind); + return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false, + Ctx.CostKind); - case ExpressionTypes::ExtMulAccReduction: + case ExpressionTypes::ExtNegatedMulAccReduction: + case ExpressionTypes::ExtMulAccReduction: { + bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction; return Ctx.TTI.getMulAccReductionCost( cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, - RedTy, SrcVecTy, Ctx.CostKind); + RedTy, SrcVecTy, Negated, Ctx.CostKind); + } } llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum"); } @@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, O << ")"; break; } + case ExpressionTypes::ExtNegatedMulAccReduction: { + getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); + O << " + "; + O << "reduce." + << Instruction::getOpcodeName( + RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) + << " (sub (0, mul"; + auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); + Mul->printFlags(O); + O << "("; + getOperand(0)->printAsOperand(O, SlotTracker); + auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to " + << *Ext0->getResultType() << "), ("; + getOperand(1)->printAsOperand(O, SlotTracker); + auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); + O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to " + << *Ext1->getResultType() << ")"; + if (Red->isConditional()) { + O << ", "; + Red->getCondOp()->printAsOperand(O, SlotTracker); + } + O << "))"; + break; + } case ExpressionTypes::MulAccReduction: case ExpressionTypes::ExtMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 931d4d42f56e..a09d2037e97b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, // Clamp the range if using multiply-accumulate-reduction is profitable. auto IsMulAccValidAndClampRange = - [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, - VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool { + [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, + VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt, + bool Negated = false) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *SrcTy = Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy; auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); - InstructionCost MulAccCost = - Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind); + InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost( + IsZExt, RedTy, SrcVecTy, Negated, CostKind); InstructionCost MulCost = Mul->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); InstructionCost ExtCost = 0; @@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, }; VPValue *VecOp = Red->getVecOp(); + VPValue *Mul = nullptr; + VPValue *Sub = nullptr; VPValue *A, *B; + // Sub reductions will have a sub between the add reduction and vec op. + if (match(VecOp, + m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul)))) + Sub = VecOp; + else + Mul = VecOp; // Try to match reduce.add(mul(...)). - if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) { + if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) { auto *RecipeA = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe()); auto *RecipeB = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe()); - auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); + auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe()); // Match reduce.add(mul(ext, ext)). if (RecipeA && RecipeB && @@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, match(RecipeB, m_ZExtOrSExt(m_VPValue())) && IsMulAccValidAndClampRange(RecipeA->getOpcode() == Instruction::CastOps::ZExt, - Mul, RecipeA, RecipeB, nullptr)) { - return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red); + MulR, RecipeA, RecipeB, nullptr, Sub)) { + if (Sub) + return new VPExpressionRecipe( + RecipeA, RecipeB, MulR, + cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red); + return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red); } // Match reduce.add(mul). - if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) - return new VPExpressionRecipe(Mul, Red); + if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub)) + return new VPExpressionRecipe(MulR, Red); } // Match reduce.add(ext(mul(ext(A), ext(B)))). // All extend recipes must have same opcode or A == B diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index b2fced47b952..7953aec48c8b 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1401,8 +1401,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II, TTI::CastContextHint::None, CostKind, RedOp); CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost; - CostAfterReduction = - TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind); + CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(), + ExtType, false, CostKind); return; } CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy, |
