diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp | 38 |
1 files changed, 22 insertions, 16 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp index 90bbf2d5d99f..eca5d1d4c5e1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp @@ -9,6 +9,7 @@ #include "VPlanAnalysis.h" #include "VPlan.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Instruction.h" using namespace llvm; @@ -26,7 +27,24 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) { } Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { - switch (R->getOpcode()) { + // Set the result type from the first operand, check if the types for all + // other operands match and cache them. + auto SetResultTyFromOp = [this, R]() { + Type *ResTy = inferScalarType(R->getOperand(0)); + for (unsigned Op = 1; Op != R->getNumOperands(); ++Op) { + VPValue *OtherV = R->getOperand(Op); + assert(inferScalarType(OtherV) == ResTy && + "different types inferred for different operands"); + CachedTypes[OtherV] = ResTy; + } + return ResTy; + }; + + unsigned Opcode = R->getOpcode(); + if (Instruction::isBinaryOp(Opcode) || Instruction::isUnaryOp(Opcode)) + return SetResultTyFromOp(); + + switch (Opcode) { case Instruction::Select: { Type *ResTy = inferScalarType(R->getOperand(1)); VPValue *OtherV = R->getOperand(2); @@ -35,28 +53,16 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { CachedTypes[OtherV] = ResTy; return ResTy; } - case Instruction::Or: case Instruction::ICmp: - case VPInstruction::FirstOrderRecurrenceSplice: { - Type *ResTy = inferScalarType(R->getOperand(0)); - VPValue *OtherV = R->getOperand(1); - assert(inferScalarType(OtherV) == ResTy && - "different types inferred for different operands"); - CachedTypes[OtherV] = ResTy; - return ResTy; - } + case VPInstruction::FirstOrderRecurrenceSplice: + case VPInstruction::Not: + return SetResultTyFromOp(); case VPInstruction::ExtractFromEnd: { Type *BaseTy = inferScalarType(R->getOperand(0)); if (auto *VecTy = dyn_cast<VectorType>(BaseTy)) return VecTy->getElementType(); return BaseTy; } - case VPInstruction::Not: { - Type *ResTy = inferScalarType(R->getOperand(0)); - assert(IntegerType::get(Ctx, 1) == ResTy && - "unexpected scalar type inferred for operand"); - return ResTy; - } case VPInstruction::LogicalAnd: return IntegerType::get(Ctx, 1); case VPInstruction::PtrAdd: |
