summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
diff options
context:
space:
mode:
authorSamuel Tebbs <samuel.tebbs@arm.com>2025-06-30 14:29:54 +0100
committerSamuel Tebbs <samuel.tebbs@arm.com>2025-07-06 22:21:12 +0100
commit1a5f4e42e4f9d1eae0222302dcabdf08492f67c3 (patch)
tree8963470cd17d362b5cb546c2e7492b40f0cab1d3 /llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
parent39f3dab6bb392a79e1888fc1c9aa82d9259c5576 (diff)
[LV] Bundle sub reductions into VPExpressionRecipeusers/SamTebbs33/expression-recipe-sub
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp')
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp35
1 files changed, 32 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 318e8171e098..c20b1920c379 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2672,13 +2672,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
- return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
+ return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
+ Ctx.CostKind);
- case ExpressionTypes::ExtMulAccReduction:
+ case ExpressionTypes::ExtNegatedMulAccReduction:
+ case ExpressionTypes::ExtMulAccReduction: {
+ bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
- RedTy, SrcVecTy, Ctx.CostKind);
+ RedTy, SrcVecTy, Negated, Ctx.CostKind);
+ }
}
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
}
@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << ")";
break;
}
+ case ExpressionTypes::ExtNegatedMulAccReduction: {
+ getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
+ O << " + ";
+ O << "reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
+ << " (sub (0, mul";
+ auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+ Mul->printFlags(O);
+ O << "(";
+ getOperand(0)->printAsOperand(O, SlotTracker);
+ auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
+ << *Ext0->getResultType() << "), (";
+ getOperand(1)->printAsOperand(O, SlotTracker);
+ auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+ O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
+ << *Ext1->getResultType() << ")";
+ if (Red->isConditional()) {
+ O << ", ";
+ Red->getCondOp()->printAsOperand(O, SlotTracker);
+ }
+ O << "))";
+ break;
+ }
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);