diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 111 |
1 files changed, 69 insertions, 42 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index e670567bd184..0c54101a1156 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -3724,14 +3724,14 @@ static SDValue lowerBuildVectorViaVID(SDValue Op, SelectionDAG &DAG, SplatStepVal = Log2_64(std::abs(StepNumerator)); } - // Only emit VIDs with suitably-small steps/addends. We use imm5 is a - // threshold since it's the immediate value many RVV instructions accept. - // There is no vmul.vi instruction so ensure multiply constant can fit in - // a single addi instruction. + // Only emit VIDs with suitably-small steps. We use imm5 as a threshold + // since it's the immediate value many RVV instructions accept. There is + // no vmul.vi instruction so ensure multiply constant can fit in a + // single addi instruction. For the addend, we allow up to 32 bits.. if (((StepOpcode == ISD::MUL && isInt<12>(SplatStepVal)) || (StepOpcode == ISD::SHL && isUInt<5>(SplatStepVal))) && isPowerOf2_32(StepDenominator) && - (SplatStepVal >= 0 || StepDenominator == 1) && isInt<5>(Addend)) { + (SplatStepVal >= 0 || StepDenominator == 1) && isInt<32>(Addend)) { MVT VIDVT = VT.isFloatingPoint() ? VT.changeVectorElementTypeToInteger() : VT; MVT VIDContainerVT = @@ -16190,10 +16190,6 @@ combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC, return SDValue(); unsigned OpSize = OpVT.getSizeInBits(); - // TODO: Support non-power-of-2 types. - if (!isPowerOf2_32(OpSize)) - return SDValue(); - // The size should be larger than XLen and smaller than the maximum vector // size. if (OpSize <= Subtarget.getXLen() || @@ -16214,14 +16210,25 @@ combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC, Attribute::NoImplicitFloat)) return SDValue(); + // Bail out for non-byte-sized types. + if (!OpVT.isByteSized()) + return SDValue(); + unsigned VecSize = OpSize / 8; - EVT VecVT = MVT::getVectorVT(MVT::i8, VecSize); - EVT CmpVT = MVT::getVectorVT(MVT::i1, VecSize); + EVT VecVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8, VecSize); + EVT CmpVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VecSize); SDValue VecX = DAG.getBitcast(VecVT, X); SDValue VecY = DAG.getBitcast(VecVT, Y); - SDValue Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE); - return DAG.getSetCC(DL, VT, DAG.getNode(ISD::VECREDUCE_OR, DL, XLenVT, Cmp), + SDValue Mask = DAG.getAllOnesConstant(DL, CmpVT); + SDValue VL = DAG.getConstant(VecSize, DL, XLenVT); + + SDValue Cmp = DAG.getNode(ISD::VP_SETCC, DL, CmpVT, VecX, VecY, + DAG.getCondCode(ISD::SETNE), Mask, VL); + return DAG.getSetCC(DL, VT, + DAG.getNode(ISD::VP_REDUCE_OR, DL, XLenVT, + DAG.getConstant(0, DL, XLenVT), Cmp, Mask, + VL), DAG.getConstant(0, DL, XLenVT), CC); } @@ -16309,7 +16316,12 @@ namespace { // apply a combine. struct CombineResult; -enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 }; +enum ExtKind : uint8_t { + ZExt = 1 << 0, + SExt = 1 << 1, + FPExt = 1 << 2, + BF16Ext = 1 << 3 +}; /// Helper class for folding sign/zero extensions. /// In particular, this class is used for the following combines: /// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w @@ -16344,8 +16356,10 @@ struct NodeExtensionHelper { /// instance, a splat constant (e.g., 3), would support being both sign and /// zero extended. bool SupportsSExt; - /// Records if this operand is like being floating-Point extended. + /// Records if this operand is like being floating point extended. bool SupportsFPExt; + /// Records if this operand is extended from bf16. + bool SupportsBF16Ext; /// This boolean captures whether we care if this operand would still be /// around after the folding happens. bool EnforceOneUse; @@ -16381,6 +16395,7 @@ struct NodeExtensionHelper { case ExtKind::ZExt: return RISCVISD::VZEXT_VL; case ExtKind::FPExt: + case ExtKind::BF16Ext: return RISCVISD::FP_EXTEND_VL; } llvm_unreachable("Unknown ExtKind enum"); @@ -16402,13 +16417,6 @@ struct NodeExtensionHelper { if (Source.getValueType() == NarrowVT) return Source; - // vfmadd_vl -> vfwmadd_vl can take bf16 operands - if (Source.getValueType().getVectorElementType() == MVT::bf16) { - assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 && - Root->getOpcode() == RISCVISD::VFMADD_VL); - return Source; - } - unsigned ExtOpc = getExtOpc(*SupportsExt); // If we need an extension, we should be changing the type. @@ -16451,7 +16459,8 @@ struct NodeExtensionHelper { // Determine the narrow size. unsigned NarrowSize = VT.getScalarSizeInBits() / 2; - MVT EltVT = SupportsExt == ExtKind::FPExt + MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16 + : SupportsExt == ExtKind::FPExt ? MVT::getFloatingPointVT(NarrowSize) : MVT::getIntegerVT(NarrowSize); @@ -16628,17 +16637,13 @@ struct NodeExtensionHelper { EnforceOneUse = false; } - bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT, - const RISCVSubtarget &Subtarget) { - // Any f16 extension will need zvfh - if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16()) - return false; - // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with - // zvfbfwma - if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() || - Root->getOpcode() != RISCVISD::VFMADD_VL)) - return false; - return true; + bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) { + return (NarrowEltVT == MVT::f32 || + (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16())); + } + + bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) { + return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma(); } /// Helper method to set the various fields of this struct based on the @@ -16648,6 +16653,7 @@ struct NodeExtensionHelper { SupportsZExt = false; SupportsSExt = false; SupportsFPExt = false; + SupportsBF16Ext = false; EnforceOneUse = true; unsigned Opc = OrigOperand.getOpcode(); // For the nodes we handle below, we end up using their inputs directly: see @@ -16679,9 +16685,11 @@ struct NodeExtensionHelper { case RISCVISD::FP_EXTEND_VL: { MVT NarrowEltVT = OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType(); - if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget)) - break; - SupportsFPExt = true; + if (isSupportedFPExtend(NarrowEltVT, Subtarget)) + SupportsFPExt = true; + if (isSupportedBF16Extend(NarrowEltVT, Subtarget)) + SupportsBF16Ext = true; + break; } case ISD::SPLAT_VECTOR: @@ -16698,16 +16706,16 @@ struct NodeExtensionHelper { if (Op.getOpcode() != ISD::FP_EXTEND) break; - if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(), - Subtarget)) - break; - unsigned NarrowSize = VT.getScalarSizeInBits() / 2; unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits(); if (NarrowSize != ScalarBits) break; - SupportsFPExt = true; + if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget)) + SupportsFPExt = true; + if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(), + Subtarget)) + SupportsBF16Ext = true; break; } default: @@ -16940,6 +16948,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS, /*RHSExt=*/{ExtKind::FPExt}); + if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext && + RHS.SupportsBF16Ext) + return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS, + /*RHSExt=*/{ExtKind::BF16Ext}); return std::nullopt; } @@ -17022,6 +17035,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS, Subtarget); } +/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)) +/// +/// \returns std::nullopt if the pattern doesn't match or a CombineResult that +/// can be used to apply the pattern. +static std::optional<CombineResult> +canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG, + Subtarget); +} + /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that @@ -17061,6 +17086,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { case RISCVISD::VFNMADD_VL: case RISCVISD::VFNMSUB_VL: Strategies.push_back(canFoldToVWWithSameExtension); + if (Root->getOpcode() == RISCVISD::VFMADD_VL) + Strategies.push_back(canFoldToVWWithBF16EXT); break; case ISD::MUL: case RISCVISD::MUL_VL: |
