summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/WebAssembly
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/WebAssembly')
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyISD.def1
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp74
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp2
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td11
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp3
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp125
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h4
7 files changed, 171 insertions, 49 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 378ef2c8f250..1eae3586d16b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -27,6 +27,7 @@ HANDLE_NODETYPE(WrapperREL)
HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
HANDLE_NODETYPE(DOT)
+HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index c6c2d0cfccb6..fe100dab427e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -2183,13 +2183,10 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
- SDValue LowLow = DAG.getNode(LowOpc, DL, MVT::v4i32, MulLow);
- SDValue LowHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, MulHigh);
- SDValue HighLow = DAG.getNode(HighOpc, DL, MVT::v4i32, MulLow);
- SDValue HighHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, MulHigh);
-
- SDValue AddLow = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowLow, HighLow);
- SDValue AddHigh = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowHigh, HighHigh);
+ SDValue AddLow =
+ DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
+ SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+ MVT::v4i32, MulHigh);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
@@ -3588,34 +3585,53 @@ static SDValue performMulCombine(SDNode *N,
if (auto Res = TryWideExtMulCombine(N, DCI.DAG))
return Res;
- // We don't natively support v16i8 mul, but we do support v8i16 so split the
- // inputs and extend them to v8i16. Only do this before legalization in case
- // a narrow vector is widened and may be simplified later.
- if (!DCI.isBeforeLegalize() || VT != MVT::v16i8)
+ // We don't natively support v16i8 or v8i8 mul, but we do support v8i16. So,
+ // extend them to v8i16. Only do this before legalization in case a narrow
+ // vector is widened and may be simplified later.
+ if (!DCI.isBeforeLegalize() || (VT != MVT::v8i8 && VT != MVT::v16i8))
return SDValue();
SDLoc DL(N);
SelectionDAG &DAG = DCI.DAG;
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
- SDValue LowLHS =
- DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MVT::v8i16, LHS);
- SDValue HighLHS =
- DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MVT::v8i16, LHS);
- SDValue LowRHS =
- DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MVT::v8i16, RHS);
- SDValue HighRHS =
- DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MVT::v8i16, RHS);
-
- SDValue MulLow =
- DAG.getBitcast(VT, DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS));
- SDValue MulHigh = DAG.getBitcast(
- VT, DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS));
-
- // Take the low byte of each lane.
- return DAG.getVectorShuffle(
- VT, DL, MulLow, MulHigh,
- {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30});
+ EVT MulVT = MVT::v8i16;
+
+ if (VT == MVT::v8i8) {
+ SDValue PromotedLHS = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, LHS,
+ DAG.getUNDEF(MVT::v8i8));
+ SDValue PromotedRHS = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, RHS,
+ DAG.getUNDEF(MVT::v8i8));
+ SDValue LowLHS =
+ DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, PromotedLHS);
+ SDValue LowRHS =
+ DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, PromotedRHS);
+ SDValue MulLow = DAG.getBitcast(
+ MVT::v16i8, DAG.getNode(ISD::MUL, DL, MulVT, LowLHS, LowRHS));
+ // Take the low byte of each lane.
+ SDValue Shuffle = DAG.getVectorShuffle(
+ MVT::v16i8, DL, MulLow, DAG.getUNDEF(MVT::v16i8),
+ {0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1});
+ return extractSubVector(Shuffle, 0, DAG, DL, 64);
+ } else {
+ assert(VT == MVT::v16i8 && "Expected v16i8");
+ SDValue LowLHS = DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, LHS);
+ SDValue LowRHS = DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, RHS);
+ SDValue HighLHS =
+ DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MulVT, LHS);
+ SDValue HighRHS =
+ DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MulVT, RHS);
+
+ SDValue MulLow =
+ DAG.getBitcast(VT, DAG.getNode(ISD::MUL, DL, MulVT, LowLHS, LowRHS));
+ SDValue MulHigh =
+ DAG.getBitcast(VT, DAG.getNode(ISD::MUL, DL, MulVT, HighLHS, HighRHS));
+
+ // Take the low byte of each lane.
+ return DAG.getVectorShuffle(
+ VT, DL, MulLow, MulHigh,
+ {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30});
+ }
}
SDValue
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp
index a934853ff9f4..feac04a17068 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp
@@ -34,7 +34,7 @@ using namespace llvm;
#include "WebAssemblyGenInstrInfo.inc"
WebAssemblyInstrInfo::WebAssemblyInstrInfo(const WebAssemblySubtarget &STI)
- : WebAssemblyGenInstrInfo(WebAssembly::ADJCALLSTACKDOWN,
+ : WebAssemblyGenInstrInfo(STI, WebAssembly::ADJCALLSTACKDOWN,
WebAssembly::ADJCALLSTACKUP,
WebAssembly::CATCHRET),
RI(STI.getTargetTriple()) {}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index f06f8d5174e3..3c26b453c448 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1453,15 +1453,22 @@ if !ne(t1, t2) then
def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>;
// Extended pairwise addition
+def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>;
+
defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed,
"extadd_pairwise_i8x16_s", 0x7c>;
-defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_unsigned,
+defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u,
"extadd_pairwise_i8x16_u", 0x7d>;
defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed,
"extadd_pairwise_i16x8_s", 0x7e>;
-defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_unsigned,
+defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u,
"extadd_pairwise_i16x8_u", 0x7f>;
+def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))),
+ (extadd_pairwise_u_I32x4 V128:$in)>;
+def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))),
+ (extadd_pairwise_u_I16x8 V128:$in)>;
+
// f64x2 <-> f32x4 conversions
def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
def demote_zero : SDNode<"WebAssemblyISD::DEMOTE_ZERO", demote_t>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp
index bc91c6424b63..08ca20b5eef6 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp
@@ -247,7 +247,8 @@ static void query(const MachineInstr &MI, bool &Read, bool &Write,
// Check for writes to __stack_pointer global.
if ((MI.getOpcode() == WebAssembly::GLOBAL_SET_I32 ||
MI.getOpcode() == WebAssembly::GLOBAL_SET_I64) &&
- strcmp(MI.getOperand(0).getSymbolName(), "__stack_pointer") == 0)
+ MI.getOperand(0).isSymbol() &&
+ !strcmp(MI.getOperand(0).getSymbolName(), "__stack_pointer"))
StackPointer = true;
// Analyze calls.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 08fb7586d215..0eefd3e2b350 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -166,12 +166,6 @@ InstructionCost WebAssemblyTTIImpl::getMemoryOpCost(
CostKind);
}
- int ISD = TLI->InstructionOpcodeToISD(Opcode);
- if (ISD != ISD::LOAD) {
- return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
- CostKind);
- }
-
EVT VT = TLI->getValueType(DL, Ty, true);
// Type legalization can't handle structs
if (VT == MVT::Other)
@@ -182,22 +176,121 @@ InstructionCost WebAssemblyTTIImpl::getMemoryOpCost(
if (!LT.first.isValid())
return InstructionCost::getInvalid();
- // 128-bit loads are a single instruction. 32-bit and 64-bit vector loads can
- // be lowered to load32_zero and load64_zero respectively. Assume SIMD loads
- // are twice as expensive as scalar.
+ int ISD = TLI->InstructionOpcodeToISD(Opcode);
unsigned width = VT.getSizeInBits();
- switch (width) {
- default:
- break;
- case 32:
- case 64:
- case 128:
- return 2;
+ if (ISD == ISD::LOAD) {
+ // 128-bit loads are a single instruction. 32-bit and 64-bit vector loads
+ // can be lowered to load32_zero and load64_zero respectively. Assume SIMD
+ // loads are twice as expensive as scalar.
+ switch (width) {
+ default:
+ break;
+ case 32:
+ case 64:
+ case 128:
+ return 2;
+ }
+ } else if (ISD == ISD::STORE) {
+ // For stores, we can use store lane operations.
+ switch (width) {
+ default:
+ break;
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ case 128:
+ return 2;
+ }
}
return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, CostKind);
}
+InstructionCost WebAssemblyTTIImpl::getInterleavedMemoryOpCost(
+ unsigned Opcode, Type *Ty, unsigned Factor, ArrayRef<unsigned> Indices,
+ Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
+ bool UseMaskForCond, bool UseMaskForGaps) const {
+ assert(Factor >= 2 && "Invalid interleave factor");
+
+ auto *VecTy = cast<VectorType>(Ty);
+ if (!ST->hasSIMD128() || !isa<FixedVectorType>(VecTy)) {
+ return InstructionCost::getInvalid();
+ }
+
+ if (UseMaskForCond || UseMaskForGaps)
+ return BaseT::getInterleavedMemoryOpCost(Opcode, Ty, Factor, Indices,
+ Alignment, AddressSpace, CostKind,
+ UseMaskForCond, UseMaskForGaps);
+
+ constexpr unsigned MaxInterleaveFactor = 4;
+ if (Factor <= MaxInterleaveFactor) {
+ unsigned MinElts = VecTy->getElementCount().getKnownMinValue();
+ // Ensure the number of vector elements is greater than 1.
+ if (MinElts < 2 || MinElts % Factor != 0)
+ return InstructionCost::getInvalid();
+
+ unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType());
+ // Ensure the element type is legal.
+ if (ElSize != 8 && ElSize != 16 && ElSize != 32 && ElSize != 64)
+ return InstructionCost::getInvalid();
+
+ auto *SubVecTy =
+ VectorType::get(VecTy->getElementType(),
+ VecTy->getElementCount().divideCoefficientBy(Factor));
+ InstructionCost MemCost =
+ getMemoryOpCost(Opcode, SubVecTy, Alignment, AddressSpace, CostKind);
+
+ unsigned VecSize = DL.getTypeSizeInBits(SubVecTy);
+ unsigned MaxVecSize = 128;
+ unsigned NumAccesses =
+ std::max<unsigned>(1, (MinElts * ElSize + MaxVecSize - 1) / VecSize);
+
+ // A stride of two is commonly supported via dedicated instructions, so it
+ // should be relatively cheap for all element sizes. A stride of four is
+ // more expensive as it will likely require more shuffles. Using two
+ // simd128 inputs is considered more expensive and we mainly account for
+ // shuffling two inputs (32 bytes), but we do model 4 x v4i32 to enable
+ // arithmetic kernels.
+ static const CostTblEntry ShuffleCostTbl[] = {
+ // One reg.
+ {2, MVT::v2i8, 1}, // interleave 2 x 2i8 into 4i8
+ {2, MVT::v4i8, 1}, // interleave 2 x 4i8 into 8i8
+ {2, MVT::v8i8, 1}, // interleave 2 x 8i8 into 16i8
+ {2, MVT::v2i16, 1}, // interleave 2 x 2i16 into 4i16
+ {2, MVT::v4i16, 1}, // interleave 2 x 4i16 into 8i16
+ {2, MVT::v2i32, 1}, // interleave 2 x 2i32 into 4i32
+
+ // Two regs.
+ {2, MVT::v16i8, 2}, // interleave 2 x 16i8 into 32i8
+ {2, MVT::v8i16, 2}, // interleave 2 x 8i16 into 16i16
+ {2, MVT::v4i32, 2}, // interleave 2 x 4i32 into 8i32
+
+ // One reg.
+ {4, MVT::v2i8, 4}, // interleave 4 x 2i8 into 8i8
+ {4, MVT::v4i8, 4}, // interleave 4 x 4i8 into 16i8
+ {4, MVT::v2i16, 4}, // interleave 4 x 2i16 into 8i16
+
+ // Two regs.
+ {4, MVT::v8i8, 16}, // interleave 4 x 8i8 into 32i8
+ {4, MVT::v4i16, 8}, // interleave 4 x 4i16 into 16i16
+ {4, MVT::v2i32, 4}, // interleave 4 x 2i32 into 8i32
+
+ // Four regs.
+ {4, MVT::v4i32, 16}, // interleave 4 x 4i32 into 16i32
+ };
+
+ EVT ETy = TLI->getValueType(DL, SubVecTy);
+ if (const auto *Entry =
+ CostTableLookup(ShuffleCostTbl, Factor, ETy.getSimpleVT()))
+ return Entry->Cost + (NumAccesses * MemCost);
+ }
+
+ return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
+ Alignment, AddressSpace, CostKind,
+ UseMaskForCond, UseMaskForGaps);
+}
+
InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
const Value *Op0, const Value *Op1) const {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index c915eeb07d4f..2573066cd5d6 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -82,6 +82,10 @@ public:
TTI::TargetCostKind CostKind,
TTI::OperandValueInfo OpInfo = {TTI::OK_AnyValue, TTI::OP_None},
const Instruction *I = nullptr) const override;
+ InstructionCost getInterleavedMemoryOpCost(
+ unsigned Opcode, Type *Ty, unsigned Factor, ArrayRef<unsigned> Indices,
+ Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
+ bool UseMaskForCond, bool UseMaskForGaps) const override;
using BaseT::getVectorInstrCost;
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,