summary refs log tree commit diff
path: root/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r-- llvm/lib/Target/RISCV/RISCVISelLowering.cpp 111
1 files changed, 69 insertions, 42 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e670567bd184..0c54101a1156 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3724,14 +3724,14 @@ static SDValue lowerBuildVectorViaVID(SDValue Op, SelectionDAG &DAG,
SplatStepVal = Log2_64(std::abs(StepNumerator));
}
- // Only emit VIDs with suitably-small steps/addends. We use imm5 is a
- // threshold since it's the immediate value many RVV instructions accept.
- // There is no vmul.vi instruction so ensure multiply constant can fit in
- // a single addi instruction.
+ // Only emit VIDs with suitably-small steps. We use imm5 as a threshold
+ // since it's the immediate value many RVV instructions accept. There is
+ // no vmul.vi instruction so ensure multiply constant can fit in a
+ // single addi instruction. For the addend, we allow up to 32 bits.
if (((StepOpcode == ISD::MUL && isInt<12>(SplatStepVal)) ||
(StepOpcode == ISD::SHL && isUInt<5>(SplatStepVal))) &&
isPowerOf2_32(StepDenominator) &&
- (SplatStepVal >= 0 || StepDenominator == 1) && isInt<5>(Addend)) {
+ (SplatStepVal >= 0 || StepDenominator == 1) && isInt<32>(Addend)) {
MVT VIDVT =
VT.isFloatingPoint() ? VT.changeVectorElementTypeToInteger() : VT;
MVT VIDContainerVT =
@@ -16190,10 +16190,6 @@ combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC,
return SDValue();
unsigned OpSize = OpVT.getSizeInBits();
- // TODO: Support non-power-of-2 types.
- if (!isPowerOf2_32(OpSize))
- return SDValue();
-
// The size should be larger than XLen and smaller than the maximum vector
// size.
if (OpSize <= Subtarget.getXLen() ||
@@ -16214,14 +16210,25 @@ combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC,
Attribute::NoImplicitFloat))
return SDValue();
+ // Bail out for non-byte-sized types.
+ if (!OpVT.isByteSized())
+ return SDValue();
+
unsigned VecSize = OpSize / 8;
- EVT VecVT = MVT::getVectorVT(MVT::i8, VecSize);
- EVT CmpVT = MVT::getVectorVT(MVT::i1, VecSize);
+ EVT VecVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8, VecSize);
+ EVT CmpVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VecSize);
SDValue VecX = DAG.getBitcast(VecVT, X);
SDValue VecY = DAG.getBitcast(VecVT, Y);
- SDValue Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
- return DAG.getSetCC(DL, VT, DAG.getNode(ISD::VECREDUCE_OR, DL, XLenVT, Cmp),
+ SDValue Mask = DAG.getAllOnesConstant(DL, CmpVT);
+ SDValue VL = DAG.getConstant(VecSize, DL, XLenVT);
+
+ SDValue Cmp = DAG.getNode(ISD::VP_SETCC, DL, CmpVT, VecX, VecY,
+ DAG.getCondCode(ISD::SETNE), Mask, VL);
+ return DAG.getSetCC(DL, VT,
+ DAG.getNode(ISD::VP_REDUCE_OR, DL, XLenVT,
+ DAG.getConstant(0, DL, XLenVT), Cmp, Mask,
+ VL),
DAG.getConstant(0, DL, XLenVT), CC);
}
@@ -16309,7 +16316,12 @@ namespace {
// apply a combine.
struct CombineResult;
-enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
+enum ExtKind : uint8_t {
+ ZExt = 1 << 0,
+ SExt = 1 << 1,
+ FPExt = 1 << 2,
+ BF16Ext = 1 << 3
+};
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16344,8 +16356,10 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
- /// Records if this operand is like being floating-Point extended.
+ /// Records if this operand is like being floating point extended.
bool SupportsFPExt;
+ /// Records if this operand is extended from bf16.
+ bool SupportsBF16Ext;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
@@ -16381,6 +16395,7 @@ struct NodeExtensionHelper {
case ExtKind::ZExt:
return RISCVISD::VZEXT_VL;
case ExtKind::FPExt:
+ case ExtKind::BF16Ext:
return RISCVISD::FP_EXTEND_VL;
}
llvm_unreachable("Unknown ExtKind enum");
@@ -16402,13 +16417,6 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;
- // vfmadd_vl -> vfwmadd_vl can take bf16 operands
- if (Source.getValueType().getVectorElementType() == MVT::bf16) {
- assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
- Root->getOpcode() == RISCVISD::VFMADD_VL);
- return Source;
- }
-
unsigned ExtOpc = getExtOpc(*SupportsExt);
// If we need an extension, we should be changing the type.
@@ -16451,7 +16459,8 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- MVT EltVT = SupportsExt == ExtKind::FPExt
+ MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
+ : SupportsExt == ExtKind::FPExt
? MVT::getFloatingPointVT(NarrowSize)
: MVT::getIntegerVT(NarrowSize);
@@ -16628,17 +16637,13 @@ struct NodeExtensionHelper {
EnforceOneUse = false;
}
- bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
- const RISCVSubtarget &Subtarget) {
- // Any f16 extension will need zvfh
- if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
- return false;
- // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
- // zvfbfwma
- if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
- Root->getOpcode() != RISCVISD::VFMADD_VL))
- return false;
- return true;
+ bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+ return (NarrowEltVT == MVT::f32 ||
+ (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()));
+ }
+
+ bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+ return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
}
/// Helper method to set the various fields of this struct based on the
@@ -16648,6 +16653,7 @@ struct NodeExtensionHelper {
SupportsZExt = false;
SupportsSExt = false;
SupportsFPExt = false;
+ SupportsBF16Ext = false;
EnforceOneUse = true;
unsigned Opc = OrigOperand.getOpcode();
// For the nodes we handle below, we end up using their inputs directly: see
@@ -16679,9 +16685,11 @@ struct NodeExtensionHelper {
case RISCVISD::FP_EXTEND_VL: {
MVT NarrowEltVT =
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
- if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
- break;
- SupportsFPExt = true;
+ if (isSupportedFPExtend(NarrowEltVT, Subtarget))
+ SupportsFPExt = true;
+ if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
+ SupportsBF16Ext = true;
+
break;
}
case ISD::SPLAT_VECTOR:
@@ -16698,16 +16706,16 @@ struct NodeExtensionHelper {
if (Op.getOpcode() != ISD::FP_EXTEND)
break;
- if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
- Subtarget))
- break;
-
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
if (NarrowSize != ScalarBits)
break;
- SupportsFPExt = true;
+ if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
+ SupportsFPExt = true;
+ if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
+ Subtarget))
+ SupportsBF16Ext = true;
break;
}
default:
@@ -16940,6 +16948,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
/*RHSExt=*/{ExtKind::FPExt});
+ if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
+ RHS.SupportsBF16Ext)
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
+ /*RHSExt=*/{ExtKind::BF16Ext});
return std::nullopt;
}
@@ -17022,6 +17035,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
Subtarget);
}
+/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
+ Subtarget);
+}
+
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17061,6 +17086,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
+ if (Root->getOpcode() == RISCVISD::VFMADD_VL)
+ Strategies.push_back(canFoldToVWWithBF16EXT);
break;
case ISD::MUL:
case RISCVISD::MUL_VL: