diff options
Diffstat (limited to 'llvm/lib/Target/WebAssembly')
7 files changed, 171 insertions, 49 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def index 378ef2c8f250..1eae3586d16b 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def @@ -27,6 +27,7 @@ HANDLE_NODETYPE(WrapperREL) HANDLE_NODETYPE(BR_IF) HANDLE_NODETYPE(BR_TABLE) HANDLE_NODETYPE(DOT) +HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U) HANDLE_NODETYPE(SHUFFLE) HANDLE_NODETYPE(SWIZZLE) HANDLE_NODETYPE(VEC_SHL) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index c6c2d0cfccb6..fe100dab427e 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -2183,13 +2183,10 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS); SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS); - SDValue LowLow = DAG.getNode(LowOpc, DL, MVT::v4i32, MulLow); - SDValue LowHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, MulHigh); - SDValue HighLow = DAG.getNode(HighOpc, DL, MVT::v4i32, MulLow); - SDValue HighHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, MulHigh); - - SDValue AddLow = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowLow, HighLow); - SDValue AddHigh = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowHigh, HighHigh); + SDValue AddLow = + DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow); + SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, + MVT::v4i32, MulHigh); SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh); return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); } @@ -3588,34 +3585,53 @@ static SDValue performMulCombine(SDNode *N, if (auto Res = TryWideExtMulCombine(N, DCI.DAG)) return Res; - // We don't natively support v16i8 mul, but we do support v8i16 so split the - // inputs and extend them to v8i16. Only do this before legalization in case - // a narrow vector is widened and may be simplified later. - if (!DCI.isBeforeLegalize() || VT != MVT::v16i8) + // We don't natively support v16i8 or v8i8 mul, but we do support v8i16. So, + // extend them to v8i16. Only do this before legalization in case a narrow + // vector is widened and may be simplified later. + if (!DCI.isBeforeLegalize() || (VT != MVT::v8i8 && VT != MVT::v16i8)) return SDValue(); SDLoc DL(N); SelectionDAG &DAG = DCI.DAG; SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); - SDValue LowLHS = - DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MVT::v8i16, LHS); - SDValue HighLHS = - DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MVT::v8i16, LHS); - SDValue LowRHS = - DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MVT::v8i16, RHS); - SDValue HighRHS = - DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MVT::v8i16, RHS); - - SDValue MulLow = - DAG.getBitcast(VT, DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS)); - SDValue MulHigh = DAG.getBitcast( - VT, DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS)); - - // Take the low byte of each lane. - return DAG.getVectorShuffle( - VT, DL, MulLow, MulHigh, - {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); + EVT MulVT = MVT::v8i16; + + if (VT == MVT::v8i8) { + SDValue PromotedLHS = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, LHS, + DAG.getUNDEF(MVT::v8i8)); + SDValue PromotedRHS = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, RHS, + DAG.getUNDEF(MVT::v8i8)); + SDValue LowLHS = + DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, PromotedLHS); + SDValue LowRHS = + DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, PromotedRHS); + SDValue MulLow = DAG.getBitcast( + MVT::v16i8, DAG.getNode(ISD::MUL, DL, MulVT, LowLHS, LowRHS)); + // Take the low byte of each lane. + SDValue Shuffle = DAG.getVectorShuffle( + MVT::v16i8, DL, MulLow, DAG.getUNDEF(MVT::v16i8), + {0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1}); + return extractSubVector(Shuffle, 0, DAG, DL, 64); + } else { + assert(VT == MVT::v16i8 && "Expected v16i8"); + SDValue LowLHS = DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, LHS); + SDValue LowRHS = DAG.getNode(WebAssemblyISD::EXTEND_LOW_U, DL, MulVT, RHS); + SDValue HighLHS = + DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MulVT, LHS); + SDValue HighRHS = + DAG.getNode(WebAssemblyISD::EXTEND_HIGH_U, DL, MulVT, RHS); + + SDValue MulLow = + DAG.getBitcast(VT, DAG.getNode(ISD::MUL, DL, MulVT, LowLHS, LowRHS)); + SDValue MulHigh = + DAG.getBitcast(VT, DAG.getNode(ISD::MUL, DL, MulVT, HighLHS, HighRHS)); + + // Take the low byte of each lane. + return DAG.getVectorShuffle( + VT, DL, MulLow, MulHigh, + {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); + } } SDValue diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp index a934853ff9f4..feac04a17068 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.cpp @@ -34,7 +34,7 @@ using namespace llvm; #include "WebAssemblyGenInstrInfo.inc" WebAssemblyInstrInfo::WebAssemblyInstrInfo(const WebAssemblySubtarget &STI) - : WebAssemblyGenInstrInfo(WebAssembly::ADJCALLSTACKDOWN, + : WebAssemblyGenInstrInfo(STI, WebAssembly::ADJCALLSTACKDOWN, WebAssembly::ADJCALLSTACKUP, WebAssembly::CATCHRET), RI(STI.getTargetTriple()) {} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index f06f8d5174e3..3c26b453c448 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1453,15 +1453,22 @@ if !ne(t1, t2) then def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>; // Extended pairwise addition +def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>; + defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed, "extadd_pairwise_i8x16_s", 0x7c>; -defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_unsigned, +defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u, "extadd_pairwise_i8x16_u", 0x7d>; defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed, "extadd_pairwise_i16x8_s", 0x7e>; -defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_unsigned, +defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u, "extadd_pairwise_i16x8_u", 0x7f>; +def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))), + (extadd_pairwise_u_I32x4 V128:$in)>; +def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))), + (extadd_pairwise_u_I16x8 V128:$in)>; + // f64x2 <-> f32x4 conversions def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>; def demote_zero : SDNode<"WebAssemblyISD::DEMOTE_ZERO", demote_t>; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp index bc91c6424b63..08ca20b5eef6 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegStackify.cpp @@ -247,7 +247,8 @@ static void query(const MachineInstr &MI, bool &Read, bool &Write, // Check for writes to __stack_pointer global. if ((MI.getOpcode() == WebAssembly::GLOBAL_SET_I32 || MI.getOpcode() == WebAssembly::GLOBAL_SET_I64) && - strcmp(MI.getOperand(0).getSymbolName(), "__stack_pointer") == 0) + MI.getOperand(0).isSymbol() && + !strcmp(MI.getOperand(0).getSymbolName(), "__stack_pointer")) StackPointer = true; // Analyze calls. diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp index 08fb7586d215..0eefd3e2b350 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp @@ -166,12 +166,6 @@ InstructionCost WebAssemblyTTIImpl::getMemoryOpCost( CostKind); } - int ISD = TLI->InstructionOpcodeToISD(Opcode); - if (ISD != ISD::LOAD) { - return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, - CostKind); - } - EVT VT = TLI->getValueType(DL, Ty, true); // Type legalization can't handle structs if (VT == MVT::Other) @@ -182,22 +176,121 @@ InstructionCost WebAssemblyTTIImpl::getMemoryOpCost( if (!LT.first.isValid()) return InstructionCost::getInvalid(); - // 128-bit loads are a single instruction. 32-bit and 64-bit vector loads can - // be lowered to load32_zero and load64_zero respectively. Assume SIMD loads - // are twice as expensive as scalar. + int ISD = TLI->InstructionOpcodeToISD(Opcode); unsigned width = VT.getSizeInBits(); - switch (width) { - default: - break; - case 32: - case 64: - case 128: - return 2; + if (ISD == ISD::LOAD) { + // 128-bit loads are a single instruction. 32-bit and 64-bit vector loads + // can be lowered to load32_zero and load64_zero respectively. Assume SIMD + // loads are twice as expensive as scalar. + switch (width) { + default: + break; + case 32: + case 64: + case 128: + return 2; + } + } else if (ISD == ISD::STORE) { + // For stores, we can use store lane operations. + switch (width) { + default: + break; + case 8: + case 16: + case 32: + case 64: + case 128: + return 2; + } } return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, CostKind); } +InstructionCost WebAssemblyTTIImpl::getInterleavedMemoryOpCost( + unsigned Opcode, Type *Ty, unsigned Factor, ArrayRef<unsigned> Indices, + Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, + bool UseMaskForCond, bool UseMaskForGaps) const { + assert(Factor >= 2 && "Invalid interleave factor"); + + auto *VecTy = cast<VectorType>(Ty); + if (!ST->hasSIMD128() || !isa<FixedVectorType>(VecTy)) { + return InstructionCost::getInvalid(); + } + + if (UseMaskForCond || UseMaskForGaps) + return BaseT::getInterleavedMemoryOpCost(Opcode, Ty, Factor, Indices, + Alignment, AddressSpace, CostKind, + UseMaskForCond, UseMaskForGaps); + + constexpr unsigned MaxInterleaveFactor = 4; + if (Factor <= MaxInterleaveFactor) { + unsigned MinElts = VecTy->getElementCount().getKnownMinValue(); + // Ensure the number of vector elements is greater than 1. + if (MinElts < 2 || MinElts % Factor != 0) + return InstructionCost::getInvalid(); + + unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType()); + // Ensure the element type is legal. + if (ElSize != 8 && ElSize != 16 && ElSize != 32 && ElSize != 64) + return InstructionCost::getInvalid(); + + auto *SubVecTy = + VectorType::get(VecTy->getElementType(), + VecTy->getElementCount().divideCoefficientBy(Factor)); + InstructionCost MemCost = + getMemoryOpCost(Opcode, SubVecTy, Alignment, AddressSpace, CostKind); + + unsigned VecSize = DL.getTypeSizeInBits(SubVecTy); + unsigned MaxVecSize = 128; + unsigned NumAccesses = + std::max<unsigned>(1, (MinElts * ElSize + MaxVecSize - 1) / VecSize); + + // A stride of two is commonly supported via dedicated instructions, so it + // should be relatively cheap for all element sizes. A stride of four is + // more expensive as it will likely require more shuffles. Using two + // simd128 inputs is considered more expensive and we mainly account for + // shuffling two inputs (32 bytes), but we do model 4 x v4i32 to enable + // arithmetic kernels. + static const CostTblEntry ShuffleCostTbl[] = { + // One reg. + {2, MVT::v2i8, 1}, // interleave 2 x 2i8 into 4i8 + {2, MVT::v4i8, 1}, // interleave 2 x 4i8 into 8i8 + {2, MVT::v8i8, 1}, // interleave 2 x 8i8 into 16i8 + {2, MVT::v2i16, 1}, // interleave 2 x 2i16 into 4i16 + {2, MVT::v4i16, 1}, // interleave 2 x 4i16 into 8i16 + {2, MVT::v2i32, 1}, // interleave 2 x 2i32 into 4i32 + + // Two regs. + {2, MVT::v16i8, 2}, // interleave 2 x 16i8 into 32i8 + {2, MVT::v8i16, 2}, // interleave 2 x 8i16 into 16i16 + {2, MVT::v4i32, 2}, // interleave 2 x 4i32 into 8i32 + + // One reg. + {4, MVT::v2i8, 4}, // interleave 4 x 2i8 into 8i8 + {4, MVT::v4i8, 4}, // interleave 4 x 4i8 into 16i8 + {4, MVT::v2i16, 4}, // interleave 4 x 2i16 into 8i16 + + // Two regs. + {4, MVT::v8i8, 16}, // interleave 4 x 8i8 into 32i8 + {4, MVT::v4i16, 8}, // interleave 4 x 4i16 into 16i16 + {4, MVT::v2i32, 4}, // interleave 4 x 2i32 into 8i32 + + // Four regs. + {4, MVT::v4i32, 16}, // interleave 4 x 4i32 into 16i32 + }; + + EVT ETy = TLI->getValueType(DL, SubVecTy); + if (const auto *Entry = + CostTableLookup(ShuffleCostTbl, Factor, ETy.getSimpleVT())) + return Entry->Cost + (NumAccesses * MemCost); + } + + return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, + Alignment, AddressSpace, CostKind, + UseMaskForCond, UseMaskForGaps); +} + InstructionCost WebAssemblyTTIImpl::getVectorInstrCost( unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index, const Value *Op0, const Value *Op1) const { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h index c915eeb07d4f..2573066cd5d6 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h @@ -82,6 +82,10 @@ public: TTI::TargetCostKind CostKind, TTI::OperandValueInfo OpInfo = {TTI::OK_AnyValue, TTI::OP_None}, const Instruction *I = nullptr) const override; + InstructionCost getInterleavedMemoryOpCost( + unsigned Opcode, Type *Ty, unsigned Factor, ArrayRef<unsigned> Indices, + Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, + bool UseMaskForCond, bool UseMaskForGaps) const override; using BaseT::getVectorInstrCost; InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, |
