diff options
| author | Mingming Liu <mingmingl@google.com> | 2025-09-10 15:25:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-10 15:25:31 -0700 |
| commit | 1417dafa1db9cb1b2b09438aa9f53ea5ab6e36e2 (patch) | |
| tree | 57f4b1f313c8cf74eed8819870f39c36ea263c68 /llvm/lib/Target/NVPTX | |
| parent | 898b813bc8a6d0276bf0f4769f5f2f64b34e632d (diff) | |
| parent | b8cefcb601ddaa18482555c4ff363c01a270c2fe (diff) | |
Merge branch 'main' into users/mingmingl-llvm/samplefdo-profile-formatusers/mingmingl-llvm/samplefdo-profile-format
Diffstat (limited to 'llvm/lib/Target/NVPTX')
| -rw-r--r-- | llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 3 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTX.td | 10 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 13 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 58 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 503 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 15 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.h | 3 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 20 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 8 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 32 |
15 files changed, 423 insertions, 302 deletions
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index ee1ca4538554..f9bdc0993533 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -290,7 +290,8 @@ void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum, O << ".acq_rel"; return; case NVPTX::Ordering::SequentiallyConsistent: - O << ".seq_cst"; + report_fatal_error( + "NVPTX AtomicCode Printer does not support \"seq_cst\" ordering."); return; case NVPTX::Ordering::Volatile: O << ".volatile"; diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td index 8a445f82e700..31c117a8c0fe 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.td +++ b/llvm/lib/Target/NVPTX/NVPTX.td @@ -80,9 +80,9 @@ class FeaturePTX<int version>: // + Compare within the family by comparing FullSMVersion, given both belongs to // the same family. // + Detect 'a' variants by checking FullSMVersion & 1. -foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53, - 60, 61, 62, 70, 72, 75, 80, 86, 87, - 89, 90, 100, 101, 103, 120, 121] in { +foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53, 60, + 61, 62, 70, 72, 75, 80, 86, 87, 88, 89, + 90, 100, 101, 103, 110, 120, 121] in { // Base SM version (e.g. FullSMVersion for sm_100 is 1000) def SM#sm : FeatureSM<""#sm, !mul(sm, 10)>; @@ -127,6 +127,7 @@ def : Proc<"sm_75", [SM75, PTX63]>; def : Proc<"sm_80", [SM80, PTX70]>; def : Proc<"sm_86", [SM86, PTX71]>; def : Proc<"sm_87", [SM87, PTX74]>; +def : Proc<"sm_88", [SM88, PTX90]>; def : Proc<"sm_89", [SM89, PTX78]>; def : Proc<"sm_90", [SM90, PTX78]>; def : Proc<"sm_90a", [SM90a, PTX80]>; @@ -139,6 +140,9 @@ def : Proc<"sm_101f", [SM101f, PTX88]>; def : Proc<"sm_103", [SM103, PTX88]>; def : Proc<"sm_103a", [SM103a, PTX88]>; def : Proc<"sm_103f", [SM103f, PTX88]>; +def : Proc<"sm_110", [SM110, PTX90]>; +def : Proc<"sm_110a", [SM110a, PTX90]>; +def : Proc<"sm_110f", [SM110f, PTX90]>; def : Proc<"sm_120", [SM120, PTX87]>; def : Proc<"sm_120a", [SM120a, PTX87]>; def : Proc<"sm_120f", [SM120f, PTX88]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 7391c2d488b5..14ca867023e2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -432,7 +432,7 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, // .maxclusterrank directive requires SM_90 or higher, make sure that we // filter it out for lower SM versions, as it causes a hard ptxas crash. const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); - const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); + const NVPTXSubtarget *STI = &NTM.getSubtarget<NVPTXSubtarget>(F); if (STI->getSmVersion() >= 90) { const auto ClusterDim = getClusterDim(F); @@ -669,7 +669,7 @@ void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) { // rest of NVPTX isn't friendly to change subtargets per function and // so the default TargetMachine will have all of the options. const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); - const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl()); + const NVPTXSubtarget *STI = NTM.getSubtargetImpl(); SmallString<128> Str1; raw_svector_ostream OS1(Str1); @@ -680,8 +680,7 @@ void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) { bool NVPTXAsmPrinter::doInitialization(Module &M) { const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); - const NVPTXSubtarget &STI = - *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); + const NVPTXSubtarget &STI = *NTM.getSubtargetImpl(); if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30)) report_fatal_error(".alias requires PTX version >= 6.3 and sm_30"); @@ -716,8 +715,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) { assert(GVVisiting.size() == 0 && "Did not fully process a global variable"); const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); - const NVPTXSubtarget &STI = - *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); + const NVPTXSubtarget &STI = *NTM.getSubtargetImpl(); // Print out module-level global variables in proper order for (const GlobalVariable *GV : Globals) @@ -1178,8 +1176,7 @@ void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) { ArrayRef<const GlobalVariable *> GVars = It->second; const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); - const NVPTXSubtarget &STI = - *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); + const NVPTXSubtarget &STI = *NTM.getSubtargetImpl(); for (const GlobalVariable *GV : GVars) { O << "\t// demoted variable\n\t"; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 3300ed9a5a81..c70f48af33cf 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -170,6 +170,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { } break; } + case NVPTXISD::ATOMIC_CMP_SWAP_B128: + case NVPTXISD::ATOMIC_SWAP_B128: + selectAtomicSwap128(N); + return; case ISD::FADD: case ISD::FMUL: case ISD::FSUB: @@ -1097,11 +1101,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { if (PlainLoad && PlainLoad->isIndexed()) return false; - const EVT LoadedEVT = LD->getMemoryVT(); - if (!LoadedEVT.isSimple()) - return false; - const MVT LoadedVT = LoadedEVT.getSimpleVT(); - // Address Space Setting const auto CodeAddrSpace = getAddrSpace(LD); if (canLowerToLDG(*LD, *Subtarget, CodeAddrSpace)) @@ -1111,7 +1110,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { SDValue Chain = N->getOperand(0); const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD); - const unsigned FromTypeWidth = LoadedVT.getSizeInBits(); + const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits(); // Vector Setting const unsigned FromType = @@ -1165,9 +1164,6 @@ static unsigned getStoreVectorNumElts(SDNode *N) { bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { MemSDNode *LD = cast<MemSDNode>(N); - const EVT MemEVT = LD->getMemoryVT(); - if (!MemEVT.isSimple()) - return false; // Address Space Setting const auto CodeAddrSpace = getAddrSpace(LD); @@ -1237,10 +1233,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { } bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { - const EVT LoadedEVT = LD->getMemoryVT(); - if (!LoadedEVT.isSimple()) - return false; - SDLoc DL(LD); unsigned ExtensionType; @@ -1357,10 +1349,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { if (PlainStore && PlainStore->isIndexed()) return false; - const EVT StoreVT = ST->getMemoryVT(); - if (!StoreVT.isSimple()) - return false; - // Address Space Setting const auto CodeAddrSpace = getAddrSpace(ST); @@ -1369,7 +1357,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST); // Vector Setting - const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits(); + const unsigned ToTypeWidth = ST->getMemoryVT().getSizeInBits(); // Create the machine instruction DAG SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal(); @@ -1406,8 +1394,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { MemSDNode *ST = cast<MemSDNode>(N); - const EVT StoreVT = ST->getMemoryVT(); - assert(StoreVT.isSimple() && "Store value is not simple"); + const unsigned TotalWidth = ST->getMemoryVT().getSizeInBits(); // Address Space Setting const auto CodeAddrSpace = getAddrSpace(ST); @@ -1420,10 +1407,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { SDValue Chain = ST->getChain(); const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST); - // Type Setting: toType + toTypeWidth - // - for integer type, always use 'u' - const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits(); - const unsigned NumElts = getStoreVectorNumElts(ST); SmallVector<SDValue, 16> Ops; @@ -2337,3 +2320,30 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) { } } } + +void NVPTXDAGToDAGISel::selectAtomicSwap128(SDNode *N) { + MemSDNode *AN = cast<MemSDNode>(N); + SDLoc dl(N); + + const SDValue Chain = N->getOperand(0); + const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG); + SmallVector<SDValue, 5> Ops{Base, Offset}; + Ops.append(N->op_begin() + 2, N->op_end()); + Ops.append({ + getI32Imm(getMemOrder(AN), dl), + getI32Imm(getAtomicScope(AN), dl), + getI32Imm(getAddrSpace(AN), dl), + Chain, + }); + + assert(N->getOpcode() == NVPTXISD::ATOMIC_CMP_SWAP_B128 || + N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128); + unsigned Opcode = N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128 + ? NVPTX::ATOM_EXCH_B128 + : NVPTX::ATOM_CAS_B128; + + auto *ATOM = CurDAG->getMachineNode(Opcode, dl, N->getVTList(), Ops); + CurDAG->setNodeMemRefs(ATOM, AN->getMemOperand()); + + ReplaceNode(N, ATOM); +} diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index e2ad55bc1796..8dcd5362c451 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -90,6 +90,7 @@ private: bool IsIm2Col = false); void SelectTcgen05Ld(SDNode *N, bool hasOffset = false); void SelectTcgen05St(SDNode *N, bool hasOffset = false); + void selectAtomicSwap128(SDNode *N); inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) { return CurDAG->getTargetConstant(Imm, DL, MVT::i32); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index bb4bb1195f78..d3fb657851fe 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -198,6 +198,12 @@ static bool IsPTXVectorType(MVT VT) { static std::optional<std::pair<unsigned int, MVT>> getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI, unsigned AddressSpace) { + const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace); + + if (CanLowerTo256Bit && VectorEVT.isScalarInteger() && + VectorEVT.getSizeInBits() == 256) + return {{4, MVT::i64}}; + if (!VectorEVT.isSimple()) return std::nullopt; const MVT VectorVT = VectorEVT.getSimpleVT(); @@ -214,8 +220,6 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI, // The size of the PTX virtual register that holds a packed type. unsigned PackRegSize; - bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace); - // We only handle "native" vector sizes for now, e.g. <4 x double> is not // legal. We can (and should) split that into 2 stores of <2 x double> here // but I'm leaving that as a TODO for now. @@ -539,6 +543,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, case ISD::FMINNUM_IEEE: case ISD::FMAXIMUM: case ISD::FMINIMUM: + case ISD::FMAXIMUMNUM: + case ISD::FMINIMUMNUM: IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70; break; case ISD::FEXP2: @@ -702,57 +708,66 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // intrinsics. setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom); - // Turn FP extload into load/fpextend - setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand); - // Turn FP truncstore into trunc + store. - // FIXME: vector types should also be expanded - setTruncStoreAction(MVT::f32, MVT::f16, Expand); - setTruncStoreAction(MVT::f64, MVT::f16, Expand); - setTruncStoreAction(MVT::f32, MVT::bf16, Expand); - setTruncStoreAction(MVT::f64, MVT::bf16, Expand); - setTruncStoreAction(MVT::f64, MVT::f32, Expand); - setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand); - setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand); + // FP extload/truncstore is not legal in PTX. We need to expand all these. + for (auto FloatVTs : + {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) { + for (MVT ValVT : FloatVTs) { + for (MVT MemVT : FloatVTs) { + setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Expand); + setTruncStoreAction(ValVT, MemVT, Expand); + } + } + } - // PTX does not support load / store predicate registers - setOperationAction(ISD::LOAD, MVT::i1, Custom); - setOperationAction(ISD::STORE, MVT::i1, Custom); + // To improve CodeGen we'll legalize any-extend loads to zext loads. This is + // how they'll be lowered in ISel anyway, and by doing this a little earlier + // we allow for more DAG combine opportunities. + for (auto IntVTs : + {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()}) + for (MVT ValVT : IntVTs) + for (MVT MemVT : IntVTs) + if (isTypeLegal(ValVT)) + setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Custom); + // PTX does not support load / store predicate registers + setOperationAction({ISD::LOAD, ISD::STORE}, MVT::i1, Custom); for (MVT VT : MVT::integer_valuetypes()) { - setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote); - setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote); - setLoadExtAction(ISD::EXTLOAD, VT, MVT::i1, Promote); + setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MVT::i1, + Promote); setTruncStoreAction(VT, MVT::i1, Expand); } + // Disable generations of extload/truncstore for v2i16/v2i8. The generic + // expansion for these nodes when they are unaligned is incorrect if the + // type is a vector. + // + // TODO: Fix the generic expansion for these nodes found in + // TargetLowering::expandUnalignedLoad/Store. + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16, + MVT::v2i8, Expand); + setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand); + + // Register custom handling for illegal type loads/stores. We'll try to custom + // lower almost all illegal types and logic in the lowering will discard cases + // we can't handle. + setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom); + for (MVT VT : MVT::fixedlen_vector_valuetypes()) + if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256) + setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom); + + // Custom legalization for LDU intrinsics. + // TODO: The logic to lower these is not very robust and we should rewrite it. + // Perhaps LDU should not be represented as an intrinsic at all. + setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom); + for (MVT VT : MVT::fixedlen_vector_valuetypes()) + if (IsPTXVectorType(VT)) + setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom); + setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE, ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT, ISD::SETGE, ISD::SETLE}, MVT::i1, Expand); - // expand extload of vector of integers. - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16, - MVT::v2i8, Expand); - setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand); - // This is legal in NVPTX setOperationAction(ISD::ConstantFP, MVT::f64, Legal); setOperationAction(ISD::ConstantFP, MVT::f32, Legal); @@ -767,24 +782,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // DEBUGTRAP can be lowered to PTX brkpt setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal); - // Register custom handling for vector loads/stores - for (MVT VT : MVT::fixedlen_vector_valuetypes()) - if (IsPTXVectorType(VT)) - setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT, - Custom); - - setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, - {MVT::i128, MVT::f128}, Custom); - // Support varargs. setOperationAction(ISD::VASTART, MVT::Other, Custom); setOperationAction(ISD::VAARG, MVT::Other, Custom); setOperationAction(ISD::VACOPY, MVT::Other, Expand); setOperationAction(ISD::VAEND, MVT::Other, Expand); - // Custom handling for i8 intrinsics - setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom); - setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX}, {MVT::i16, MVT::i32, MVT::i64}, Legal); @@ -988,7 +991,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, if (getOperationAction(ISD::FABS, MVT::bf16) == Promote) AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32); - for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { + for (const auto &Op : + {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) { setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setFP16OperationAction(Op, MVT::f16, Legal, Promote); @@ -1039,7 +1043,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom); setOperationAction(ISD::ATOMIC_LOAD_SUB, {MVT::i32, MVT::i64}, Expand); - // No FPOW or FREM in PTX. + + // atom.b128 is legal in PTX but since we don't represent i128 as a legal + // type, we need to custom lower it. + setOperationAction({ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, MVT::i128, + Custom); // Now deduce the information based on the above mentioned // actions @@ -1047,7 +1055,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // PTX support for 16-bit CAS is emulated. Only use 32+ setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits()); - setMaxAtomicSizeInBitsSupported(64); + setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64); setMaxDivRemBitWidthSupported(64); // Custom lowering for tcgen05.ld vector operands @@ -1080,6 +1088,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { case NVPTXISD::FIRST_NUMBER: break; + MAKE_CASE(NVPTXISD::ATOMIC_CMP_SWAP_B128) + MAKE_CASE(NVPTXISD::ATOMIC_SWAP_B128) MAKE_CASE(NVPTXISD::RET_GLUE) MAKE_CASE(NVPTXISD::DeclareArrayParam) MAKE_CASE(NVPTXISD::DeclareScalarParam) @@ -3088,29 +3098,112 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const { MachinePointerInfo(SV)); } -static void replaceLoadVector(SDNode *N, SelectionDAG &DAG, - SmallVectorImpl<SDValue> &Results, - const NVPTXSubtarget &STI); +/// replaceLoadVector - Convert vector loads into multi-output scalar loads. +static std::optional<std::pair<SDValue, SDValue>> +replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) { + LoadSDNode *LD = cast<LoadSDNode>(N); + const EVT ResVT = LD->getValueType(0); + const EVT MemVT = LD->getMemoryVT(); -SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { - if (Op.getValueType() == MVT::i1) - return LowerLOADi1(Op, DAG); + // If we're doing sign/zero extension as part of the load, avoid lowering to + // a LoadV node. TODO: consider relaxing this restriction. + if (ResVT != MemVT) + return std::nullopt; - EVT VT = Op.getValueType(); + const auto NumEltsAndEltVT = + getVectorLoweringShape(ResVT, STI, LD->getAddressSpace()); + if (!NumEltsAndEltVT) + return std::nullopt; + const auto [NumElts, EltVT] = NumEltsAndEltVT.value(); + + Align Alignment = LD->getAlign(); + const auto &TD = DAG.getDataLayout(); + Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext())); + if (Alignment < PrefAlign) { + // This load is not sufficiently aligned, so bail out and let this vector + // load be scalarized. Note that we may still be able to emit smaller + // vector loads. For example, if we are loading a <4 x float> with an + // alignment of 8, this check will fail but the legalizer will try again + // with 2 x <2 x float>, which will succeed with an alignment of 8. + return std::nullopt; + } + + // Since LoadV2 is a target node, we cannot rely on DAG type legalization. + // Therefore, we must ensure the type is legal. For i1 and i8, we set the + // loaded type to i16 and propagate the "real" type as the memory type. + const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT; + + unsigned Opcode; + switch (NumElts) { + default: + return std::nullopt; + case 2: + Opcode = NVPTXISD::LoadV2; + break; + case 4: + Opcode = NVPTXISD::LoadV4; + break; + case 8: + Opcode = NVPTXISD::LoadV8; + break; + } + auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT); + ListVTs.push_back(MVT::Other); + SDVTList LdResVTs = DAG.getVTList(ListVTs); - if (NVPTX::isPackedVectorTy(VT)) { - // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to - // handle unaligned loads and have to handle it here. - LoadSDNode *Load = cast<LoadSDNode>(Op); - EVT MemVT = Load->getMemoryVT(); - if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), - MemVT, *Load->getMemOperand())) { - SDValue Ops[2]; - std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG); - return DAG.getMergeValues(Ops, SDLoc(Op)); + SDLoc DL(LD); + + // Copy regular operands + SmallVector<SDValue, 8> OtherOps(LD->ops()); + + // The select routine does not have access to the LoadSDNode instance, so + // pass along the extension information + OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL)); + + SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT, + LD->getMemOperand()); + + SmallVector<SDValue> ScalarRes; + if (EltVT.isVector()) { + assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType()); + assert(NumElts * EltVT.getVectorNumElements() == + ResVT.getVectorNumElements()); + // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back + // into individual elements. + for (const unsigned I : llvm::seq(NumElts)) { + SDValue SubVector = NewLD.getValue(I); + DAG.ExtractVectorElements(SubVector, ScalarRes); + } + } else { + for (const unsigned I : llvm::seq(NumElts)) { + SDValue Res = NewLD.getValue(I); + if (LoadEltVT != EltVT) + Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res); + ScalarRes.push_back(Res); } } + SDValue LoadChain = NewLD.getValue(NumElts); + + const MVT BuildVecVT = + MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size()); + SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes); + SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec); + + return {{LoadValue, LoadChain}}; +} + +static void replaceLoadVector(SDNode *N, SelectionDAG &DAG, + SmallVectorImpl<SDValue> &Results, + const NVPTXSubtarget &STI) { + if (auto Res = replaceLoadVector(N, DAG, STI)) + Results.append({Res->first, Res->second}); +} + +static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG, + const NVPTXSubtarget &STI) { + if (auto Res = replaceLoadVector(N, DAG, STI)) + return DAG.getMergeValues({Res->first, Res->second}, SDLoc(N)); return SDValue(); } @@ -3118,13 +3211,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { // => // v1 = ld i8* addr (-> i16) // v = trunc i16 to i1 -SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { - SDNode *Node = Op.getNode(); - LoadSDNode *LD = cast<LoadSDNode>(Node); - SDLoc dl(Node); +static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) { + SDLoc dl(LD); assert(LD->getExtensionType() == ISD::NON_EXTLOAD); - assert(Node->getValueType(0) == MVT::i1 && - "Custom lowering for i1 load only"); + assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only"); SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(), LD->getBasePtr(), LD->getPointerInfo(), MVT::i8, LD->getAlign(), @@ -3133,35 +3223,31 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { // The legalizer (the caller) is expecting two values from the legalized // load, so we build a MergeValues node for it. See ExpandUnalignedLoad() // in LegalizeDAG.cpp which also uses MergeValues. - SDValue Ops[] = { result, LD->getChain() }; - return DAG.getMergeValues(Ops, dl); + return DAG.getMergeValues({result, LD->getChain()}, dl); } -SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { - StoreSDNode *Store = cast<StoreSDNode>(Op); - EVT VT = Store->getMemoryVT(); - - if (VT == MVT::i1) - return LowerSTOREi1(Op, DAG); +SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { + LoadSDNode *LD = cast<LoadSDNode>(Op); - // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to - // handle unaligned stores and have to handle it here. - if (NVPTX::isPackedVectorTy(VT) && - !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), - VT, *Store->getMemOperand())) - return expandUnalignedStore(Store, DAG); + if (Op.getValueType() == MVT::i1) + return lowerLOADi1(LD, DAG); - // v2f16/v2bf16/v2i16 don't need special handling. - if (NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()) - return SDValue(); + // To improve CodeGen we'll legalize any-extend loads to zext loads. This is + // how they'll be lowered in ISel anyway, and by doing this a little earlier + // we allow for more DAG combine opportunities. + if (LD->getExtensionType() == ISD::EXTLOAD) { + assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() && + "Unexpected fpext-load"); + return DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Op), Op.getValueType(), + LD->getChain(), LD->getBasePtr(), LD->getMemoryVT(), + LD->getMemOperand()); + } - // Lower store of any other vector type, including v2f32 as we want to break - // it apart since this is not a widely-supported type. - return LowerSTOREVector(Op, DAG); + llvm_unreachable("Unexpected custom lowering for load"); } -SDValue -NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { +static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG, + const NVPTXSubtarget &STI) { MemSDNode *N = cast<MemSDNode>(Op.getNode()); SDValue Val = N->getOperand(1); SDLoc DL(N); @@ -3253,6 +3339,18 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { return NewSt; } +SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { + StoreSDNode *Store = cast<StoreSDNode>(Op); + EVT VT = Store->getMemoryVT(); + + if (VT == MVT::i1) + return LowerSTOREi1(Op, DAG); + + // Lower store of any other vector type, including v2f32 as we want to break + // it apart since this is not a widely-supported type. + return lowerSTOREVector(Op, DAG, STI); +} + // st i1 v, addr // => // v1 = zxt v to i16 @@ -4010,14 +4108,8 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_ldu_global_i: case Intrinsic::nvvm_ldu_global_f: case Intrinsic::nvvm_ldu_global_p: { - auto &DL = I.getDataLayout(); Info.opc = ISD::INTRINSIC_W_CHAIN; - if (Intrinsic == Intrinsic::nvvm_ldu_global_i) - Info.memVT = getValueType(DL, I.getType()); - else if(Intrinsic == Intrinsic::nvvm_ldu_global_p) - Info.memVT = getPointerTy(DL); - else - Info.memVT = getValueType(DL, I.getType()); + Info.memVT = getValueType(I.getDataLayout(), I.getType()); Info.ptrVal = I.getArgOperand(0); Info.offset = 0; Info.flags = MachineMemOperand::MOLoad; @@ -5152,11 +5244,34 @@ static SDValue combinePackingMovIntoStore(SDNode *N, ST->getMemoryVT(), ST->getMemOperand()); } -static SDValue PerformStoreCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { +static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + const NVPTXSubtarget &STI) { + + if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) { + // Here is our chance to custom lower a store with a non-simple type. + // Unfortunately, we can't do this in the legalizer because there is no + // way to setOperationAction for an non-simple type. + StoreSDNode *ST = cast<StoreSDNode>(N); + if (!ST->getValue().getValueType().isSimple()) + return lowerSTOREVector(SDValue(ST, 0), DCI.DAG, STI); + } + return combinePackingMovIntoStore(N, DCI, 1, 2); } +static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + const NVPTXSubtarget &STI) { + if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) { + // Here is our chance to custom lower a load with a non-simple type. + // Unfortunately, we can't do this in the legalizer because there is no + // way to setOperationAction for an non-simple type. + if (!N->getValueType(0).isSimple()) + return lowerLoadVector(N, DCI.DAG, STI); + } + + return combineUnpackingMovIntoLoad(N, DCI); +} + /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD. /// static SDValue PerformADDCombine(SDNode *N, @@ -5884,7 +5999,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, case ISD::LOAD: case NVPTXISD::LoadV2: case NVPTXISD::LoadV4: - return combineUnpackingMovIntoLoad(N, DCI); + return combineLOAD(N, DCI, STI); case ISD::MUL: return PerformMULCombine(N, DCI, OptLevel); case NVPTXISD::PRMT: @@ -5901,7 +6016,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, case ISD::STORE: case NVPTXISD::StoreV2: case NVPTXISD::StoreV4: - return PerformStoreCombine(N, DCI); + return combineSTORE(N, DCI, STI); case ISD::VSELECT: return PerformVSELECTCombine(N, DCI); } @@ -5930,103 +6045,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1})); } -/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads. -static void replaceLoadVector(SDNode *N, SelectionDAG &DAG, - SmallVectorImpl<SDValue> &Results, - const NVPTXSubtarget &STI) { - LoadSDNode *LD = cast<LoadSDNode>(N); - const EVT ResVT = LD->getValueType(0); - const EVT MemVT = LD->getMemoryVT(); - - // If we're doing sign/zero extension as part of the load, avoid lowering to - // a LoadV node. TODO: consider relaxing this restriction. - if (ResVT != MemVT) - return; - - const auto NumEltsAndEltVT = - getVectorLoweringShape(ResVT, STI, LD->getAddressSpace()); - if (!NumEltsAndEltVT) - return; - const auto [NumElts, EltVT] = NumEltsAndEltVT.value(); - - Align Alignment = LD->getAlign(); - const auto &TD = DAG.getDataLayout(); - Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext())); - if (Alignment < PrefAlign) { - // This load is not sufficiently aligned, so bail out and let this vector - // load be scalarized. Note that we may still be able to emit smaller - // vector loads. For example, if we are loading a <4 x float> with an - // alignment of 8, this check will fail but the legalizer will try again - // with 2 x <2 x float>, which will succeed with an alignment of 8. - return; - } - - // Since LoadV2 is a target node, we cannot rely on DAG type legalization. - // Therefore, we must ensure the type is legal. For i1 and i8, we set the - // loaded type to i16 and propagate the "real" type as the memory type. - const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT; - - unsigned Opcode; - switch (NumElts) { - default: - return; - case 2: - Opcode = NVPTXISD::LoadV2; - break; - case 4: - Opcode = NVPTXISD::LoadV4; - break; - case 8: - Opcode = NVPTXISD::LoadV8; - break; - } - auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT); - ListVTs.push_back(MVT::Other); - SDVTList LdResVTs = DAG.getVTList(ListVTs); - - SDLoc DL(LD); - - // Copy regular operands - SmallVector<SDValue, 8> OtherOps(LD->ops()); - - // The select routine does not have access to the LoadSDNode instance, so - // pass along the extension information - OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL)); - - SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, - LD->getMemoryVT(), - LD->getMemOperand()); - - SmallVector<SDValue> ScalarRes; - if (EltVT.isVector()) { - assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType()); - assert(NumElts * EltVT.getVectorNumElements() == - ResVT.getVectorNumElements()); - // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back - // into individual elements. - for (const unsigned I : llvm::seq(NumElts)) { - SDValue SubVector = NewLD.getValue(I); - DAG.ExtractVectorElements(SubVector, ScalarRes); - } - } else { - for (const unsigned I : llvm::seq(NumElts)) { - SDValue Res = NewLD.getValue(I); - if (LoadEltVT != EltVT) - Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res); - ScalarRes.push_back(Res); - } - } - - SDValue LoadChain = NewLD.getValue(NumElts); - - const MVT BuildVecVT = - MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size()); - SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes); - SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec); - - Results.append({LoadValue, LoadChain}); -} - // Lower vector return type of tcgen05.ld intrinsics static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG, SmallVectorImpl<SDValue> &Results, @@ -6262,6 +6280,49 @@ static void replaceProxyReg(SDNode *N, SelectionDAG &DAG, Results.push_back(Res); } +static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG, + const NVPTXSubtarget &STI, + SmallVectorImpl<SDValue> &Results) { + assert(N->getValueType(0) == MVT::i128 && + "Custom lowering for atomic128 only supports i128"); + + AtomicSDNode *AN = cast<AtomicSDNode>(N); + SDLoc dl(N); + + if (!STI.hasAtomSwap128()) { + DAG.getContext()->diagnose(DiagnosticInfoUnsupported( + DAG.getMachineFunction().getFunction(), + "Support for b128 atomics introduced in PTX ISA version 8.3 and " + "requires target sm_90.", + dl.getDebugLoc())); + + Results.push_back(DAG.getUNDEF(MVT::i128)); + Results.push_back(AN->getOperand(0)); // Chain + return; + } + + SmallVector<SDValue, 6> Ops; + Ops.push_back(AN->getOperand(0)); // Chain + Ops.push_back(AN->getOperand(1)); // Ptr + for (const auto &Op : AN->ops().drop_front(2)) { + // Low part + Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op, + DAG.getIntPtrConstant(0, dl))); + // High part + Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op, + DAG.getIntPtrConstant(1, dl))); + } + unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP + ? NVPTXISD::ATOMIC_SWAP_B128 + : NVPTXISD::ATOMIC_CMP_SWAP_B128; + SDVTList Tys = DAG.getVTList(MVT::i64, MVT::i64, MVT::Other); + SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, Tys, Ops, MVT::i128, + AN->getMemOperand()); + Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i128, + {Result.getValue(0), Result.getValue(1)})); + Results.push_back(Result.getValue(2)); +} + void NVPTXTargetLowering::ReplaceNodeResults( SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { switch (N->getOpcode()) { @@ -6282,6 +6343,10 @@ void NVPTXTargetLowering::ReplaceNodeResults( case NVPTXISD::ProxyReg: replaceProxyReg(N, DAG, *this, Results); return; + case ISD::ATOMIC_CMP_SWAP: + case ISD::ATOMIC_SWAP: + replaceAtomicSwap128(N, DAG, STI, Results); + return; } } @@ -6306,16 +6371,19 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { } assert(Ty->isIntegerTy() && "Ty should be integer at this point"); - auto ITy = cast<llvm::IntegerType>(Ty); + const unsigned BitWidth = cast<IntegerType>(Ty)->getBitWidth(); switch (AI->getOperation()) { default: return AtomicExpansionKind::CmpXChg; + case AtomicRMWInst::BinOp::Xchg: + if (BitWidth == 128) + return AtomicExpansionKind::None; + LLVM_FALLTHROUGH; case AtomicRMWInst::BinOp::And: case AtomicRMWInst::BinOp::Or: case AtomicRMWInst::BinOp::Xor: - case AtomicRMWInst::BinOp::Xchg: - switch (ITy->getBitWidth()) { + switch (BitWidth) { case 8: case 16: return AtomicExpansionKind::CmpXChg; @@ -6325,6 +6393,8 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { if (STI.hasAtomBitwise64()) return AtomicExpansionKind::None; return AtomicExpansionKind::CmpXChg; + case 128: + return AtomicExpansionKind::CmpXChg; default: llvm_unreachable("unsupported width encountered"); } @@ -6334,7 +6404,7 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { case AtomicRMWInst::BinOp::Min: case AtomicRMWInst::BinOp::UMax: case AtomicRMWInst::BinOp::UMin: - switch (ITy->getBitWidth()) { + switch (BitWidth) { case 8: case 16: return AtomicExpansionKind::CmpXChg; @@ -6344,17 +6414,20 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { if (STI.hasAtomMinMax64()) return AtomicExpansionKind::None; return AtomicExpansionKind::CmpXChg; + case 128: + return AtomicExpansionKind::CmpXChg; default: llvm_unreachable("unsupported width encountered"); } case AtomicRMWInst::BinOp::UIncWrap: case AtomicRMWInst::BinOp::UDecWrap: - switch (ITy->getBitWidth()) { + switch (BitWidth) { case 32: return AtomicExpansionKind::None; case 8: case 16: case 64: + case 128: return AtomicExpansionKind::CmpXChg; default: llvm_unreachable("unsupported width encountered"); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 27f099e22097..03b3edc902e5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -81,7 +81,17 @@ enum NodeType : unsigned { CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z, FIRST_MEMORY_OPCODE, - LoadV2 = FIRST_MEMORY_OPCODE, + + /// These nodes are used to lower atomic instructions with i128 type. They are + /// similar to the generic nodes, but the input and output values are split + /// into two 64-bit values. + /// ValLo, ValHi, OUTCHAIN = ATOMIC_CMP_SWAP_B128(INCHAIN, ptr, cmpLo, cmpHi, + /// swapLo, swapHi) + /// ValLo, ValHi, OUTCHAIN = ATOMIC_SWAP_B128(INCHAIN, ptr, amtLo, amtHi) + ATOMIC_CMP_SWAP_B128 = FIRST_MEMORY_OPCODE, + ATOMIC_SWAP_B128, + + LoadV2, LoadV4, LoadV8, LDUV2, // LDU.v2 @@ -309,11 +319,8 @@ private: SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const; SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const; - SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const; - SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const; - SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const; SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const; SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index 34fe467c9456..6840c7ae8faf 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -12,6 +12,7 @@ #include "NVPTXInstrInfo.h" #include "NVPTX.h" +#include "NVPTXSubtarget.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" @@ -24,7 +25,8 @@ using namespace llvm; // Pin the vtable to this file. void NVPTXInstrInfo::anchor() {} -NVPTXInstrInfo::NVPTXInstrInfo() : RegInfo() {} +NVPTXInstrInfo::NVPTXInstrInfo(const NVPTXSubtarget &STI) + : NVPTXGenInstrInfo(STI), RegInfo() {} void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator I, @@ -190,4 +192,4 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB, BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB); BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB); return 2; -}
\ No newline at end of file +} diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h index 4e9dc9d3b468..23889531431e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h @@ -21,12 +21,13 @@ #include "NVPTXGenInstrInfo.inc" namespace llvm { +class NVPTXSubtarget; class NVPTXInstrInfo : public NVPTXGenInstrInfo { const NVPTXRegisterInfo RegInfo; virtual void anchor(); public: - explicit NVPTXInstrInfo(); + explicit NVPTXInstrInfo(const NVPTXSubtarget &STI); const NVPTXRegisterInfo &getRegisterInfo() const { return RegInfo; } diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 7b135098bd4c..4e38e026e6bd 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -104,6 +104,7 @@ def hasAtomAddF64 : Predicate<"Subtarget->hasAtomAddF64()">; def hasAtomScope : Predicate<"Subtarget->hasAtomScope()">; def hasAtomBitwise64 : Predicate<"Subtarget->hasAtomBitwise64()">; def hasAtomMinMax64 : Predicate<"Subtarget->hasAtomMinMax64()">; +def hasAtomSwap128 : Predicate<"Subtarget->hasAtomSwap128()">; def hasClusters : Predicate<"Subtarget->hasClusters()">; def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">; def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">; @@ -294,7 +295,7 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> { // // Also defines ftz (flush subnormal inputs and results to sign-preserving // zero) variants for fp32 functions. -multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> { +multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDPatternOperator OpNode> { defvar nan_str = !if(NaN, ".NaN", ""); if !not(NaN) then { def _f64_rr : @@ -898,10 +899,8 @@ let Predicates = [hasOptEnabled] in { defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>; defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>; - defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>; - defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>; - defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>; - defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>; + // Generating mad.wide causes a regression: + // https://github.com/llvm/llvm-project/pull/150477#issuecomment-3191367837 } //----------------------------------- @@ -912,8 +911,15 @@ defm FADD : F3_fma_component<"add", fadd>; defm FSUB : F3_fma_component<"sub", fsub>; defm FMUL : F3_fma_component<"mul", fmul>; -defm MIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum>; -defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>; +def fminnum_or_fminimumnum : PatFrags<(ops node:$a, node:$b), + [(fminnum node:$a, node:$b), + (fminimumnum node:$a, node:$b)]>; +def fmaxnum_or_fmaximumnum : PatFrags<(ops node:$a, node:$b), + [(fmaxnum node:$a, node:$b), + (fmaximumnum node:$a, node:$b)]>; + +defm MIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum_or_fminimumnum>; +defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum_or_fmaximumnum>; defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>; defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 4ab30a5b5f5e..c544911bdf1e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1990,19 +1990,23 @@ multiclass F_ATOMIC_3<RegTyInfo t, string op_str, SDPatternOperator op, SDNode a let mayLoad = 1, mayStore = 1, hasSideEffects = 1 in { def _rr : BasicFlagsNVPTXInst<(outs t.RC:$dst), - (ins ADDR:$addr, t.RC:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), + (ins ADDR:$addr, t.RC:$b, t.RC:$c), + (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), asm_str>; def _ir : BasicFlagsNVPTXInst<(outs t.RC:$dst), - (ins ADDR:$addr, t.Imm:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), + (ins ADDR:$addr, t.Imm:$b, t.RC:$c), + (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), asm_str>; def _ri : BasicFlagsNVPTXInst<(outs t.RC:$dst), - (ins ADDR:$addr, t.RC:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), + (ins ADDR:$addr, t.RC:$b, t.Imm:$c), + (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), asm_str>; def _ii : BasicFlagsNVPTXInst<(outs t.RC:$dst), - (ins ADDR:$addr, t.Imm:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), + (ins ADDR:$addr, t.Imm:$b, t.Imm:$c), + (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), asm_str>; } @@ -2200,6 +2204,37 @@ defm INT_PTX_SATOM_MIN : ATOM2_minmax_impl<"min">; defm INT_PTX_SATOM_OR : ATOM2_bitwise_impl<"or">; defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">; +// atom.*.b128 + +let mayLoad = true, mayStore = true, hasSideEffects = true, + Predicates = [hasAtomSwap128] in { + def ATOM_CAS_B128 : + NVPTXInst< + (outs B64:$dst0, B64:$dst1), + (ins ADDR:$addr, B64:$cmp0, B64:$cmp1, B64:$swap0, B64:$swap1, + AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), + "{{\n\t" + ".reg .b128 cmp, swap, dst;\n\t" + "mov.b128 cmp, {$cmp0, $cmp1};\n\t" + "mov.b128 swap, {$swap0, $swap1};\n\t" + "atom${sem:sem}${scope:scope}${addsp:addsp}.cas.b128 dst, [$addr], cmp, swap;\n\t" + "mov.b128 {$dst0, $dst1}, dst;\n\t" + "}}">; + + def ATOM_EXCH_B128 : + NVPTXInst< + (outs B64:$dst0, B64:$dst1), + (ins ADDR:$addr, B64:$amt0, B64:$amt1, + AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp), + "{{\n\t" + ".reg .b128 amt, dst;\n\t" + "mov.b128 amt, {$amt0, $amt1};\n\t" + "atom${sem:sem}${scope:scope}${addsp:addsp}.exch.b128 dst, [$addr], amt;\n\t" + "mov.b128 {$dst0, $dst1}, dst;\n\t" + "}}">; +} + + //----------------------------------- // Support for ldu on sm_20 or later //----------------------------------- @@ -4358,10 +4393,12 @@ let hasSideEffects = 1 in { def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; + def SREG_GLOBALTIMER_LO : PTX_READ_SREG_R32<"globaltimer_lo", int_nvvm_read_ptx_sreg_globaltimer_lo>; } def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>; def: Pat <(i64 (readsteadycounter)), (SREG_GLOBALTIMER)>; +def: Pat <(i32 (readsteadycounter)), (SREG_GLOBALTIMER_LO)>; def INT_PTX_SREG_PM0 : PTX_READ_SREG_R32<"pm0", int_nvvm_read_ptx_sreg_pm0>; def INT_PTX_SREG_PM1 : PTX_READ_SREG_R32<"pm1", int_nvvm_read_ptx_sreg_pm1>; diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp index a84ceaba991c..c5489670bd24 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -62,7 +62,7 @@ NVPTXSubtarget::NVPTXSubtarget(const Triple &TT, const std::string &CPU, const NVPTXTargetMachine &TM) : NVPTXGenSubtargetInfo(TT, CPU, /*TuneCPU*/ CPU, FS), PTXVersion(0), FullSmVersion(200), SmVersion(getSmVersion()), - TLInfo(TM, initializeSubtargetDependencies(CPU, FS)) { + InstrInfo(initializeSubtargetDependencies(CPU, FS)), TLInfo(TM, *this) { TSInfo = std::make_unique<NVPTXSelectionDAGInfo>(); } diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index acf025b70ce3..0a77a633cb25 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -82,6 +82,7 @@ public: bool hasAtomBitwise64() const { return SmVersion >= 32; } bool hasAtomMinMax64() const { return SmVersion >= 32; } bool hasAtomCas16() const { return SmVersion >= 70 && PTXVersion >= 63; } + bool hasAtomSwap128() const { return SmVersion >= 90 && PTXVersion >= 83; } bool hasClusters() const { return SmVersion >= 90 && PTXVersion >= 78; } bool hasLDG() const { return SmVersion >= 32; } bool hasHWROT32() const { return SmVersion >= 32; } @@ -105,6 +106,7 @@ public: // Tcgen05 instructions in Blackwell family bool hasTcgen05Instructions() const { bool HasTcgen05 = false; + unsigned MinPTXVersion = 86; switch (FullSmVersion) { default: break; @@ -112,9 +114,13 @@ public: case 1013: // sm_101a HasTcgen05 = true; break; + case 1033: // sm_103a + HasTcgen05 = true; + MinPTXVersion = 88; + break; } - return HasTcgen05 && PTXVersion >= 86; + return HasTcgen05 && PTXVersion >= MinPTXVersion; } // f32x2 instructions in Blackwell family bool hasF32x2Instructions() const; diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index 0603994606d7..833f014a4c87 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -126,12 +126,12 @@ static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) { // (addrspace:3). if (!is64Bit) Ret += "-p:32:32-p6:32:32-p7:32:32"; - else if (UseShortPointers) { + else if (UseShortPointers) Ret += "-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32"; - } else + else Ret += "-p6:32:32"; - Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64"; + Ret += "-i64:64-i128:128-i256:256-v16:16-v32:32-n16:32:64"; return Ret; } diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp index 274b04fdd30b..8e97b422218f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -55,15 +55,6 @@ void clearAnnotationCache(const Module *Mod) { AC.Cache.erase(Mod); } -static void readIntVecFromMDNode(const MDNode *MetadataNode, - std::vector<unsigned> &Vec) { - for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) { - ConstantInt *Val = - mdconst::extract<ConstantInt>(MetadataNode->getOperand(i)); - Vec.push_back(Val->getZExtValue()); - } -} - static void cacheAnnotationFromMD(const MDNode *MetadataNode, key_val_pair_t &retval) { auto &AC = getAnnotationCache(); @@ -83,19 +74,8 @@ static void cacheAnnotationFromMD(const MDNode *MetadataNode, if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>( MetadataNode->getOperand(i + 1))) { retval[Key].push_back(Val->getZExtValue()); - } else if (MDNode *VecMd = - dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) { - // note: only "grid_constant" annotations support vector MDNodes. - // assert: there can only exist one unique key value pair of - // the form (string key, MDNode node). Operands of such a node - // shall always be unsigned ints. - auto [It, Inserted] = retval.try_emplace(Key); - if (Inserted) { - readIntVecFromMDNode(VecMd, It->second); - continue; - } } else { - llvm_unreachable("Value operand not a constant int or an mdnode"); + llvm_unreachable("Value operand not a constant int"); } } } @@ -179,16 +159,13 @@ static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) { } static bool argHasNVVMAnnotation(const Value &Val, - const std::string &Annotation, - const bool StartArgIndexAtOne = false) { + const std::string &Annotation) { if (const Argument *Arg = dyn_cast<Argument>(&Val)) { const Function *Func = Arg->getParent(); std::vector<unsigned> Annot; if (findAllNVVMAnnotation(Func, Annotation, Annot)) { - const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0; - if (is_contained(Annot, BaseOffset + Arg->getArgNo())) { + if (is_contained(Annot, Arg->getArgNo())) return true; - } } } return false; @@ -250,8 +227,7 @@ bool isParamGridConstant(const Argument &Arg) { } // "grid_constant" counts argument indices starting from 1 - if (argHasNVVMAnnotation(Arg, "grid_constant", - /*StartArgIndexAtOne*/ true)) + if (Arg.hasAttribute("nvvm.grid_constant")) return true; return false; |
