summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/NVPTX
diff options
context:
space:
mode:
authorMingming Liu <mingmingl@google.com>2025-09-10 15:25:31 -0700
committerGitHub <noreply@github.com>2025-09-10 15:25:31 -0700
commit1417dafa1db9cb1b2b09438aa9f53ea5ab6e36e2 (patch)
tree57f4b1f313c8cf74eed8819870f39c36ea263c68 /llvm/lib/Target/NVPTX
parent898b813bc8a6d0276bf0f4769f5f2f64b34e632d (diff)
parentb8cefcb601ddaa18482555c4ff363c01a270c2fe (diff)
Merge branch 'main' into users/mingmingl-llvm/samplefdo-profile-formatusers/mingmingl-llvm/samplefdo-profile-format
Diffstat (limited to 'llvm/lib/Target/NVPTX')
-rw-r--r--llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp3
-rw-r--r--llvm/lib/Target/NVPTX/NVPTX.td10
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp13
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp58
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp503
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.h15
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp6
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.h3
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td20
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td45
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp2
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXSubtarget.h8
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp6
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXUtilities.cpp32
15 files changed, 423 insertions, 302 deletions
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index ee1ca4538554..f9bdc0993533 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -290,7 +290,8 @@ void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum,
O << ".acq_rel";
return;
case NVPTX::Ordering::SequentiallyConsistent:
- O << ".seq_cst";
+ report_fatal_error(
+ "NVPTX AtomicCode Printer does not support \"seq_cst\" ordering.");
return;
case NVPTX::Ordering::Volatile:
O << ".volatile";
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
index 8a445f82e700..31c117a8c0fe 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.td
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
@@ -80,9 +80,9 @@ class FeaturePTX<int version>:
// + Compare within the family by comparing FullSMVersion, given both belongs to
// the same family.
// + Detect 'a' variants by checking FullSMVersion & 1.
-foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
- 60, 61, 62, 70, 72, 75, 80, 86, 87,
- 89, 90, 100, 101, 103, 120, 121] in {
+foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53, 60,
+ 61, 62, 70, 72, 75, 80, 86, 87, 88, 89,
+ 90, 100, 101, 103, 110, 120, 121] in {
// Base SM version (e.g. FullSMVersion for sm_100 is 1000)
def SM#sm : FeatureSM<""#sm, !mul(sm, 10)>;
@@ -127,6 +127,7 @@ def : Proc<"sm_75", [SM75, PTX63]>;
def : Proc<"sm_80", [SM80, PTX70]>;
def : Proc<"sm_86", [SM86, PTX71]>;
def : Proc<"sm_87", [SM87, PTX74]>;
+def : Proc<"sm_88", [SM88, PTX90]>;
def : Proc<"sm_89", [SM89, PTX78]>;
def : Proc<"sm_90", [SM90, PTX78]>;
def : Proc<"sm_90a", [SM90a, PTX80]>;
@@ -139,6 +140,9 @@ def : Proc<"sm_101f", [SM101f, PTX88]>;
def : Proc<"sm_103", [SM103, PTX88]>;
def : Proc<"sm_103a", [SM103a, PTX88]>;
def : Proc<"sm_103f", [SM103f, PTX88]>;
+def : Proc<"sm_110", [SM110, PTX90]>;
+def : Proc<"sm_110a", [SM110a, PTX90]>;
+def : Proc<"sm_110f", [SM110f, PTX90]>;
def : Proc<"sm_120", [SM120, PTX87]>;
def : Proc<"sm_120a", [SM120a, PTX87]>;
def : Proc<"sm_120f", [SM120f, PTX88]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 7391c2d488b5..14ca867023e2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -432,7 +432,7 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
// .maxclusterrank directive requires SM_90 or higher, make sure that we
// filter it out for lower SM versions, as it causes a hard ptxas crash.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
- const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+ const NVPTXSubtarget *STI = &NTM.getSubtarget<NVPTXSubtarget>(F);
if (STI->getSmVersion() >= 90) {
const auto ClusterDim = getClusterDim(F);
@@ -669,7 +669,7 @@ void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
// rest of NVPTX isn't friendly to change subtargets per function and
// so the default TargetMachine will have all of the options.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
- const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
+ const NVPTXSubtarget *STI = NTM.getSubtargetImpl();
SmallString<128> Str1;
raw_svector_ostream OS1(Str1);
@@ -680,8 +680,7 @@ void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
bool NVPTXAsmPrinter::doInitialization(Module &M) {
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
- const NVPTXSubtarget &STI =
- *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+ const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
@@ -716,8 +715,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) {
assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
- const NVPTXSubtarget &STI =
- *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+ const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
// Print out module-level global variables in proper order
for (const GlobalVariable *GV : Globals)
@@ -1178,8 +1176,7 @@ void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
ArrayRef<const GlobalVariable *> GVars = It->second;
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
- const NVPTXSubtarget &STI =
- *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+ const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
for (const GlobalVariable *GV : GVars) {
O << "\t// demoted variable\n\t";
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3300ed9a5a81..c70f48af33cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -170,6 +170,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
}
break;
}
+ case NVPTXISD::ATOMIC_CMP_SWAP_B128:
+ case NVPTXISD::ATOMIC_SWAP_B128:
+ selectAtomicSwap128(N);
+ return;
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
@@ -1097,11 +1101,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
if (PlainLoad && PlainLoad->isIndexed())
return false;
- const EVT LoadedEVT = LD->getMemoryVT();
- if (!LoadedEVT.isSimple())
- return false;
- const MVT LoadedVT = LoadedEVT.getSimpleVT();
-
// Address Space Setting
const auto CodeAddrSpace = getAddrSpace(LD);
if (canLowerToLDG(*LD, *Subtarget, CodeAddrSpace))
@@ -1111,7 +1110,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
SDValue Chain = N->getOperand(0);
const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);
- const unsigned FromTypeWidth = LoadedVT.getSizeInBits();
+ const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits();
// Vector Setting
const unsigned FromType =
@@ -1165,9 +1164,6 @@ static unsigned getStoreVectorNumElts(SDNode *N) {
bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
MemSDNode *LD = cast<MemSDNode>(N);
- const EVT MemEVT = LD->getMemoryVT();
- if (!MemEVT.isSimple())
- return false;
// Address Space Setting
const auto CodeAddrSpace = getAddrSpace(LD);
@@ -1237,10 +1233,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
}
bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
- const EVT LoadedEVT = LD->getMemoryVT();
- if (!LoadedEVT.isSimple())
- return false;
-
SDLoc DL(LD);
unsigned ExtensionType;
@@ -1357,10 +1349,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
if (PlainStore && PlainStore->isIndexed())
return false;
- const EVT StoreVT = ST->getMemoryVT();
- if (!StoreVT.isSimple())
- return false;
-
// Address Space Setting
const auto CodeAddrSpace = getAddrSpace(ST);
@@ -1369,7 +1357,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
// Vector Setting
- const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();
+ const unsigned ToTypeWidth = ST->getMemoryVT().getSizeInBits();
// Create the machine instruction DAG
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
@@ -1406,8 +1394,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
MemSDNode *ST = cast<MemSDNode>(N);
- const EVT StoreVT = ST->getMemoryVT();
- assert(StoreVT.isSimple() && "Store value is not simple");
+ const unsigned TotalWidth = ST->getMemoryVT().getSizeInBits();
// Address Space Setting
const auto CodeAddrSpace = getAddrSpace(ST);
@@ -1420,10 +1407,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
SDValue Chain = ST->getChain();
const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
- // Type Setting: toType + toTypeWidth
- // - for integer type, always use 'u'
- const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
-
const unsigned NumElts = getStoreVectorNumElts(ST);
SmallVector<SDValue, 16> Ops;
@@ -2337,3 +2320,30 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
}
}
}
+
+void NVPTXDAGToDAGISel::selectAtomicSwap128(SDNode *N) {
+ MemSDNode *AN = cast<MemSDNode>(N);
+ SDLoc dl(N);
+
+ const SDValue Chain = N->getOperand(0);
+ const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
+ SmallVector<SDValue, 5> Ops{Base, Offset};
+ Ops.append(N->op_begin() + 2, N->op_end());
+ Ops.append({
+ getI32Imm(getMemOrder(AN), dl),
+ getI32Imm(getAtomicScope(AN), dl),
+ getI32Imm(getAddrSpace(AN), dl),
+ Chain,
+ });
+
+ assert(N->getOpcode() == NVPTXISD::ATOMIC_CMP_SWAP_B128 ||
+ N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128);
+ unsigned Opcode = N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128
+ ? NVPTX::ATOM_EXCH_B128
+ : NVPTX::ATOM_CAS_B128;
+
+ auto *ATOM = CurDAG->getMachineNode(Opcode, dl, N->getVTList(), Ops);
+ CurDAG->setNodeMemRefs(ATOM, AN->getMemOperand());
+
+ ReplaceNode(N, ATOM);
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index e2ad55bc1796..8dcd5362c451 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -90,6 +90,7 @@ private:
bool IsIm2Col = false);
void SelectTcgen05Ld(SDNode *N, bool hasOffset = false);
void SelectTcgen05St(SDNode *N, bool hasOffset = false);
+ void selectAtomicSwap128(SDNode *N);
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb4bb1195f78..d3fb657851fe 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -198,6 +198,12 @@ static bool IsPTXVectorType(MVT VT) {
static std::optional<std::pair<unsigned int, MVT>>
getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
unsigned AddressSpace) {
+ const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
+
+ if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
+ VectorEVT.getSizeInBits() == 256)
+ return {{4, MVT::i64}};
+
if (!VectorEVT.isSimple())
return std::nullopt;
const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -214,8 +220,6 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
// The size of the PTX virtual register that holds a packed type.
unsigned PackRegSize;
- bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
-
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
// legal. We can (and should) split that into 2 stores of <2 x double> here
// but I'm leaving that as a TODO for now.
@@ -539,6 +543,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
case ISD::FMINNUM_IEEE:
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
+ case ISD::FMAXIMUMNUM:
+ case ISD::FMINIMUMNUM:
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
break;
case ISD::FEXP2:
@@ -702,57 +708,66 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// intrinsics.
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
- // Turn FP extload into load/fpextend
- setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
- setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
- // Turn FP truncstore into trunc + store.
- // FIXME: vector types should also be expanded
- setTruncStoreAction(MVT::f32, MVT::f16, Expand);
- setTruncStoreAction(MVT::f64, MVT::f16, Expand);
- setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
- setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
- setTruncStoreAction(MVT::f64, MVT::f32, Expand);
- setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
- setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);
+ // FP extload/truncstore is not legal in PTX. We need to expand all these.
+ for (auto FloatVTs :
+ {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
+ for (MVT ValVT : FloatVTs) {
+ for (MVT MemVT : FloatVTs) {
+ setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Expand);
+ setTruncStoreAction(ValVT, MemVT, Expand);
+ }
+ }
+ }
- // PTX does not support load / store predicate registers
- setOperationAction(ISD::LOAD, MVT::i1, Custom);
- setOperationAction(ISD::STORE, MVT::i1, Custom);
+ // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
+ // how they'll be lowered in ISel anyway, and by doing this a little earlier
+ // we allow for more DAG combine opportunities.
+ for (auto IntVTs :
+ {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
+ for (MVT ValVT : IntVTs)
+ for (MVT MemVT : IntVTs)
+ if (isTypeLegal(ValVT))
+ setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Custom);
+ // PTX does not support load / store predicate registers
+ setOperationAction({ISD::LOAD, ISD::STORE}, MVT::i1, Custom);
for (MVT VT : MVT::integer_valuetypes()) {
- setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
- setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote);
- setLoadExtAction(ISD::EXTLOAD, VT, MVT::i1, Promote);
+ setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MVT::i1,
+ Promote);
setTruncStoreAction(VT, MVT::i1, Expand);
}
+ // Disable generations of extload/truncstore for v2i16/v2i8. The generic
+ // expansion for these nodes when they are unaligned is incorrect if the
+ // type is a vector.
+ //
+ // TODO: Fix the generic expansion for these nodes found in
+ // TargetLowering::expandUnalignedLoad/Store.
+ setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
+ MVT::v2i8, Expand);
+ setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
+
+ // Register custom handling for illegal type loads/stores. We'll try to custom
+ // lower almost all illegal types and logic in the lowering will discard cases
+ // we can't handle.
+ setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
+ for (MVT VT : MVT::fixedlen_vector_valuetypes())
+ if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
+ setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+
+ // Custom legalization for LDU intrinsics.
+ // TODO: The logic to lower these is not very robust and we should rewrite it.
+ // Perhaps LDU should not be represented as an intrinsic at all.
+ setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
+ for (MVT VT : MVT::fixedlen_vector_valuetypes())
+ if (IsPTXVectorType(VT))
+ setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
+
setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
ISD::SETGE, ISD::SETLE},
MVT::i1, Expand);
- // expand extload of vector of integers.
- setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
- MVT::v2i8, Expand);
- setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
-
// This is legal in NVPTX
setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
@@ -767,24 +782,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// DEBUGTRAP can be lowered to PTX brkpt
setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
- // Register custom handling for vector loads/stores
- for (MVT VT : MVT::fixedlen_vector_valuetypes())
- if (IsPTXVectorType(VT))
- setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT,
- Custom);
-
- setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN},
- {MVT::i128, MVT::f128}, Custom);
-
// Support varargs.
setOperationAction(ISD::VASTART, MVT::Other, Custom);
setOperationAction(ISD::VAARG, MVT::Other, Custom);
setOperationAction(ISD::VACOPY, MVT::Other, Expand);
setOperationAction(ISD::VAEND, MVT::Other, Expand);
- // Custom handling for i8 intrinsics
- setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
-
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
{MVT::i16, MVT::i32, MVT::i64}, Legal);
@@ -988,7 +991,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
- for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+ for (const auto &Op :
+ {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setFP16OperationAction(Op, MVT::f16, Legal, Promote);
@@ -1039,7 +1043,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);
setOperationAction(ISD::ATOMIC_LOAD_SUB, {MVT::i32, MVT::i64}, Expand);
- // No FPOW or FREM in PTX.
+
+ // atom.b128 is legal in PTX but since we don't represent i128 as a legal
+ // type, we need to custom lower it.
+ setOperationAction({ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, MVT::i128,
+ Custom);
// Now deduce the information based on the above mentioned
// actions
@@ -1047,7 +1055,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// PTX support for 16-bit CAS is emulated. Only use 32+
setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
- setMaxAtomicSizeInBitsSupported(64);
+ setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
setMaxDivRemBitWidthSupported(64);
// Custom lowering for tcgen05.ld vector operands
@@ -1080,6 +1088,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
case NVPTXISD::FIRST_NUMBER:
break;
+ MAKE_CASE(NVPTXISD::ATOMIC_CMP_SWAP_B128)
+ MAKE_CASE(NVPTXISD::ATOMIC_SWAP_B128)
MAKE_CASE(NVPTXISD::RET_GLUE)
MAKE_CASE(NVPTXISD::DeclareArrayParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
@@ -3088,29 +3098,112 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachinePointerInfo(SV));
}
-static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
- SmallVectorImpl<SDValue> &Results,
- const NVPTXSubtarget &STI);
+/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
+static std::optional<std::pair<SDValue, SDValue>>
+replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
+ LoadSDNode *LD = cast<LoadSDNode>(N);
+ const EVT ResVT = LD->getValueType(0);
+ const EVT MemVT = LD->getMemoryVT();
-SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
- if (Op.getValueType() == MVT::i1)
- return LowerLOADi1(Op, DAG);
+ // If we're doing sign/zero extension as part of the load, avoid lowering to
+ // a LoadV node. TODO: consider relaxing this restriction.
+ if (ResVT != MemVT)
+ return std::nullopt;
- EVT VT = Op.getValueType();
+ const auto NumEltsAndEltVT =
+ getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
+ if (!NumEltsAndEltVT)
+ return std::nullopt;
+ const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
+
+ Align Alignment = LD->getAlign();
+ const auto &TD = DAG.getDataLayout();
+ Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
+ if (Alignment < PrefAlign) {
+ // This load is not sufficiently aligned, so bail out and let this vector
+ // load be scalarized. Note that we may still be able to emit smaller
+ // vector loads. For example, if we are loading a <4 x float> with an
+ // alignment of 8, this check will fail but the legalizer will try again
+ // with 2 x <2 x float>, which will succeed with an alignment of 8.
+ return std::nullopt;
+ }
+
+ // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
+ // Therefore, we must ensure the type is legal. For i1 and i8, we set the
+ // loaded type to i16 and propagate the "real" type as the memory type.
+ const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
+
+ unsigned Opcode;
+ switch (NumElts) {
+ default:
+ return std::nullopt;
+ case 2:
+ Opcode = NVPTXISD::LoadV2;
+ break;
+ case 4:
+ Opcode = NVPTXISD::LoadV4;
+ break;
+ case 8:
+ Opcode = NVPTXISD::LoadV8;
+ break;
+ }
+ auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
+ ListVTs.push_back(MVT::Other);
+ SDVTList LdResVTs = DAG.getVTList(ListVTs);
- if (NVPTX::isPackedVectorTy(VT)) {
- // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
- // handle unaligned loads and have to handle it here.
- LoadSDNode *Load = cast<LoadSDNode>(Op);
- EVT MemVT = Load->getMemoryVT();
- if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
- MemVT, *Load->getMemOperand())) {
- SDValue Ops[2];
- std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
- return DAG.getMergeValues(Ops, SDLoc(Op));
+ SDLoc DL(LD);
+
+ // Copy regular operands
+ SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+ // The select routine does not have access to the LoadSDNode instance, so
+ // pass along the extension information
+ OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+
+ SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
+ LD->getMemOperand());
+
+ SmallVector<SDValue> ScalarRes;
+ if (EltVT.isVector()) {
+ assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
+ assert(NumElts * EltVT.getVectorNumElements() ==
+ ResVT.getVectorNumElements());
+ // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
+ // into individual elements.
+ for (const unsigned I : llvm::seq(NumElts)) {
+ SDValue SubVector = NewLD.getValue(I);
+ DAG.ExtractVectorElements(SubVector, ScalarRes);
+ }
+ } else {
+ for (const unsigned I : llvm::seq(NumElts)) {
+ SDValue Res = NewLD.getValue(I);
+ if (LoadEltVT != EltVT)
+ Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
+ ScalarRes.push_back(Res);
}
}
+ SDValue LoadChain = NewLD.getValue(NumElts);
+
+ const MVT BuildVecVT =
+ MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
+ SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
+ SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
+
+ return {{LoadValue, LoadChain}};
+}
+
+static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
+ SmallVectorImpl<SDValue> &Results,
+ const NVPTXSubtarget &STI) {
+ if (auto Res = replaceLoadVector(N, DAG, STI))
+ Results.append({Res->first, Res->second});
+}
+
+static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
+ const NVPTXSubtarget &STI) {
+ if (auto Res = replaceLoadVector(N, DAG, STI))
+ return DAG.getMergeValues({Res->first, Res->second}, SDLoc(N));
return SDValue();
}
@@ -3118,13 +3211,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
// =>
// v1 = ld i8* addr (-> i16)
// v = trunc i16 to i1
-SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
- SDNode *Node = Op.getNode();
- LoadSDNode *LD = cast<LoadSDNode>(Node);
- SDLoc dl(Node);
+static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
+ SDLoc dl(LD);
assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
- assert(Node->getValueType(0) == MVT::i1 &&
- "Custom lowering for i1 load only");
+ assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
LD->getBasePtr(), LD->getPointerInfo(),
MVT::i8, LD->getAlign(),
@@ -3133,35 +3223,31 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
// The legalizer (the caller) is expecting two values from the legalized
// load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
// in LegalizeDAG.cpp which also uses MergeValues.
- SDValue Ops[] = { result, LD->getChain() };
- return DAG.getMergeValues(Ops, dl);
+ return DAG.getMergeValues({result, LD->getChain()}, dl);
}
-SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
- StoreSDNode *Store = cast<StoreSDNode>(Op);
- EVT VT = Store->getMemoryVT();
-
- if (VT == MVT::i1)
- return LowerSTOREi1(Op, DAG);
+SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
+ LoadSDNode *LD = cast<LoadSDNode>(Op);
- // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
- // handle unaligned stores and have to handle it here.
- if (NVPTX::isPackedVectorTy(VT) &&
- !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
- VT, *Store->getMemOperand()))
- return expandUnalignedStore(Store, DAG);
+ if (Op.getValueType() == MVT::i1)
+ return lowerLOADi1(LD, DAG);
- // v2f16/v2bf16/v2i16 don't need special handling.
- if (NVPTX::isPackedVectorTy(VT) && VT.is32BitVector())
- return SDValue();
+ // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
+ // how they'll be lowered in ISel anyway, and by doing this a little earlier
+ // we allow for more DAG combine opportunities.
+ if (LD->getExtensionType() == ISD::EXTLOAD) {
+ assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
+ "Unexpected fpext-load");
+ return DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Op), Op.getValueType(),
+ LD->getChain(), LD->getBasePtr(), LD->getMemoryVT(),
+ LD->getMemOperand());
+ }
- // Lower store of any other vector type, including v2f32 as we want to break
- // it apart since this is not a widely-supported type.
- return LowerSTOREVector(Op, DAG);
+ llvm_unreachable("Unexpected custom lowering for load");
}
-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
+ const NVPTXSubtarget &STI) {
MemSDNode *N = cast<MemSDNode>(Op.getNode());
SDValue Val = N->getOperand(1);
SDLoc DL(N);
@@ -3253,6 +3339,18 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
return NewSt;
}
+SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
+ StoreSDNode *Store = cast<StoreSDNode>(Op);
+ EVT VT = Store->getMemoryVT();
+
+ if (VT == MVT::i1)
+ return LowerSTOREi1(Op, DAG);
+
+ // Lower store of any other vector type, including v2f32 as we want to break
+ // it apart since this is not a widely-supported type.
+ return lowerSTOREVector(Op, DAG, STI);
+}
+
// st i1 v, addr
// =>
// v1 = zxt v to i16
@@ -4010,14 +4108,8 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldu_global_p: {
- auto &DL = I.getDataLayout();
Info.opc = ISD::INTRINSIC_W_CHAIN;
- if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
- Info.memVT = getValueType(DL, I.getType());
- else if(Intrinsic == Intrinsic::nvvm_ldu_global_p)
- Info.memVT = getPointerTy(DL);
- else
- Info.memVT = getValueType(DL, I.getType());
+ Info.memVT = getValueType(I.getDataLayout(), I.getType());
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.flags = MachineMemOperand::MOLoad;
@@ -5152,11 +5244,34 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
ST->getMemoryVT(), ST->getMemOperand());
}
-static SDValue PerformStoreCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ const NVPTXSubtarget &STI) {
+
+ if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
+ // Here is our chance to custom lower a store with a non-simple type.
+ // Unfortunately, we can't do this in the legalizer because there is no
+ // way to setOperationAction for an non-simple type.
+ StoreSDNode *ST = cast<StoreSDNode>(N);
+ if (!ST->getValue().getValueType().isSimple())
+ return lowerSTOREVector(SDValue(ST, 0), DCI.DAG, STI);
+ }
+
return combinePackingMovIntoStore(N, DCI, 1, 2);
}
+static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ const NVPTXSubtarget &STI) {
+ if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
+ // Here is our chance to custom lower a load with a non-simple type.
+ // Unfortunately, we can't do this in the legalizer because there is no
+ // way to setOperationAction for an non-simple type.
+ if (!N->getValueType(0).isSimple())
+ return lowerLoadVector(N, DCI.DAG, STI);
+ }
+
+ return combineUnpackingMovIntoLoad(N, DCI);
+}
+
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
@@ -5884,7 +5999,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::LOAD:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
- return combineUnpackingMovIntoLoad(N, DCI);
+ return combineLOAD(N, DCI, STI);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case NVPTXISD::PRMT:
@@ -5901,7 +6016,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::STORE:
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
- return PerformStoreCombine(N, DCI);
+ return combineSTORE(N, DCI, STI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
}
@@ -5930,103 +6045,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
}
-/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
-static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
- SmallVectorImpl<SDValue> &Results,
- const NVPTXSubtarget &STI) {
- LoadSDNode *LD = cast<LoadSDNode>(N);
- const EVT ResVT = LD->getValueType(0);
- const EVT MemVT = LD->getMemoryVT();
-
- // If we're doing sign/zero extension as part of the load, avoid lowering to
- // a LoadV node. TODO: consider relaxing this restriction.
- if (ResVT != MemVT)
- return;
-
- const auto NumEltsAndEltVT =
- getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
- if (!NumEltsAndEltVT)
- return;
- const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
-
- Align Alignment = LD->getAlign();
- const auto &TD = DAG.getDataLayout();
- Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
- if (Alignment < PrefAlign) {
- // This load is not sufficiently aligned, so bail out and let this vector
- // load be scalarized. Note that we may still be able to emit smaller
- // vector loads. For example, if we are loading a <4 x float> with an
- // alignment of 8, this check will fail but the legalizer will try again
- // with 2 x <2 x float>, which will succeed with an alignment of 8.
- return;
- }
-
- // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
- // Therefore, we must ensure the type is legal. For i1 and i8, we set the
- // loaded type to i16 and propagate the "real" type as the memory type.
- const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
-
- unsigned Opcode;
- switch (NumElts) {
- default:
- return;
- case 2:
- Opcode = NVPTXISD::LoadV2;
- break;
- case 4:
- Opcode = NVPTXISD::LoadV4;
- break;
- case 8:
- Opcode = NVPTXISD::LoadV8;
- break;
- }
- auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
- ListVTs.push_back(MVT::Other);
- SDVTList LdResVTs = DAG.getVTList(ListVTs);
-
- SDLoc DL(LD);
-
- // Copy regular operands
- SmallVector<SDValue, 8> OtherOps(LD->ops());
-
- // The select routine does not have access to the LoadSDNode instance, so
- // pass along the extension information
- OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
-
- SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
- LD->getMemoryVT(),
- LD->getMemOperand());
-
- SmallVector<SDValue> ScalarRes;
- if (EltVT.isVector()) {
- assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
- assert(NumElts * EltVT.getVectorNumElements() ==
- ResVT.getVectorNumElements());
- // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
- // into individual elements.
- for (const unsigned I : llvm::seq(NumElts)) {
- SDValue SubVector = NewLD.getValue(I);
- DAG.ExtractVectorElements(SubVector, ScalarRes);
- }
- } else {
- for (const unsigned I : llvm::seq(NumElts)) {
- SDValue Res = NewLD.getValue(I);
- if (LoadEltVT != EltVT)
- Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
- ScalarRes.push_back(Res);
- }
- }
-
- SDValue LoadChain = NewLD.getValue(NumElts);
-
- const MVT BuildVecVT =
- MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
- SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
- SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
-
- Results.append({LoadValue, LoadChain});
-}
-
// Lower vector return type of tcgen05.ld intrinsics
static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &Results,
@@ -6262,6 +6280,49 @@ static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
Results.push_back(Res);
}
+static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
+ const NVPTXSubtarget &STI,
+ SmallVectorImpl<SDValue> &Results) {
+ assert(N->getValueType(0) == MVT::i128 &&
+ "Custom lowering for atomic128 only supports i128");
+
+ AtomicSDNode *AN = cast<AtomicSDNode>(N);
+ SDLoc dl(N);
+
+ if (!STI.hasAtomSwap128()) {
+ DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
+ DAG.getMachineFunction().getFunction(),
+ "Support for b128 atomics introduced in PTX ISA version 8.3 and "
+ "requires target sm_90.",
+ dl.getDebugLoc()));
+
+ Results.push_back(DAG.getUNDEF(MVT::i128));
+ Results.push_back(AN->getOperand(0)); // Chain
+ return;
+ }
+
+ SmallVector<SDValue, 6> Ops;
+ Ops.push_back(AN->getOperand(0)); // Chain
+ Ops.push_back(AN->getOperand(1)); // Ptr
+ for (const auto &Op : AN->ops().drop_front(2)) {
+ // Low part
+ Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
+ DAG.getIntPtrConstant(0, dl)));
+ // High part
+ Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
+ DAG.getIntPtrConstant(1, dl)));
+ }
+ unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
+ ? NVPTXISD::ATOMIC_SWAP_B128
+ : NVPTXISD::ATOMIC_CMP_SWAP_B128;
+ SDVTList Tys = DAG.getVTList(MVT::i64, MVT::i64, MVT::Other);
+ SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, Tys, Ops, MVT::i128,
+ AN->getMemOperand());
+ Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i128,
+ {Result.getValue(0), Result.getValue(1)}));
+ Results.push_back(Result.getValue(2));
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
@@ -6282,6 +6343,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case NVPTXISD::ProxyReg:
replaceProxyReg(N, DAG, *this, Results);
return;
+ case ISD::ATOMIC_CMP_SWAP:
+ case ISD::ATOMIC_SWAP:
+ replaceAtomicSwap128(N, DAG, STI, Results);
+ return;
}
}
@@ -6306,16 +6371,19 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
}
assert(Ty->isIntegerTy() && "Ty should be integer at this point");
- auto ITy = cast<llvm::IntegerType>(Ty);
+ const unsigned BitWidth = cast<IntegerType>(Ty)->getBitWidth();
switch (AI->getOperation()) {
default:
return AtomicExpansionKind::CmpXChg;
+ case AtomicRMWInst::BinOp::Xchg:
+ if (BitWidth == 128)
+ return AtomicExpansionKind::None;
+ LLVM_FALLTHROUGH;
case AtomicRMWInst::BinOp::And:
case AtomicRMWInst::BinOp::Or:
case AtomicRMWInst::BinOp::Xor:
- case AtomicRMWInst::BinOp::Xchg:
- switch (ITy->getBitWidth()) {
+ switch (BitWidth) {
case 8:
case 16:
return AtomicExpansionKind::CmpXChg;
@@ -6325,6 +6393,8 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
if (STI.hasAtomBitwise64())
return AtomicExpansionKind::None;
return AtomicExpansionKind::CmpXChg;
+ case 128:
+ return AtomicExpansionKind::CmpXChg;
default:
llvm_unreachable("unsupported width encountered");
}
@@ -6334,7 +6404,7 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
case AtomicRMWInst::BinOp::Min:
case AtomicRMWInst::BinOp::UMax:
case AtomicRMWInst::BinOp::UMin:
- switch (ITy->getBitWidth()) {
+ switch (BitWidth) {
case 8:
case 16:
return AtomicExpansionKind::CmpXChg;
@@ -6344,17 +6414,20 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
if (STI.hasAtomMinMax64())
return AtomicExpansionKind::None;
return AtomicExpansionKind::CmpXChg;
+ case 128:
+ return AtomicExpansionKind::CmpXChg;
default:
llvm_unreachable("unsupported width encountered");
}
case AtomicRMWInst::BinOp::UIncWrap:
case AtomicRMWInst::BinOp::UDecWrap:
- switch (ITy->getBitWidth()) {
+ switch (BitWidth) {
case 32:
return AtomicExpansionKind::None;
case 8:
case 16:
case 64:
+ case 128:
return AtomicExpansionKind::CmpXChg;
default:
llvm_unreachable("unsupported width encountered");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 27f099e22097..03b3edc902e5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -81,7 +81,17 @@ enum NodeType : unsigned {
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
FIRST_MEMORY_OPCODE,
- LoadV2 = FIRST_MEMORY_OPCODE,
+
+ /// These nodes are used to lower atomic instructions with i128 type. They are
+ /// similar to the generic nodes, but the input and output values are split
+ /// into two 64-bit values.
+ /// ValLo, ValHi, OUTCHAIN = ATOMIC_CMP_SWAP_B128(INCHAIN, ptr, cmpLo, cmpHi,
+ /// swapLo, swapHi)
+ /// ValLo, ValHi, OUTCHAIN = ATOMIC_SWAP_B128(INCHAIN, ptr, amtLo, amtHi)
+ ATOMIC_CMP_SWAP_B128 = FIRST_MEMORY_OPCODE,
+ ATOMIC_SWAP_B128,
+
+ LoadV2,
LoadV4,
LoadV8,
LDUV2, // LDU.v2
@@ -309,11 +319,8 @@ private:
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
- SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
-
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
- SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index 34fe467c9456..6840c7ae8faf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -12,6 +12,7 @@
#include "NVPTXInstrInfo.h"
#include "NVPTX.h"
+#include "NVPTXSubtarget.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -24,7 +25,8 @@ using namespace llvm;
// Pin the vtable to this file.
void NVPTXInstrInfo::anchor() {}
-NVPTXInstrInfo::NVPTXInstrInfo() : RegInfo() {}
+NVPTXInstrInfo::NVPTXInstrInfo(const NVPTXSubtarget &STI)
+ : NVPTXGenInstrInfo(STI), RegInfo() {}
void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
MachineBasicBlock::iterator I,
@@ -190,4 +192,4 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB);
BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
return 2;
-} \ No newline at end of file
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
index 4e9dc9d3b468..23889531431e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
@@ -21,12 +21,13 @@
#include "NVPTXGenInstrInfo.inc"
namespace llvm {
+class NVPTXSubtarget;
class NVPTXInstrInfo : public NVPTXGenInstrInfo {
const NVPTXRegisterInfo RegInfo;
virtual void anchor();
public:
- explicit NVPTXInstrInfo();
+ explicit NVPTXInstrInfo(const NVPTXSubtarget &STI);
const NVPTXRegisterInfo &getRegisterInfo() const { return RegInfo; }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 7b135098bd4c..4e38e026e6bd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -104,6 +104,7 @@ def hasAtomAddF64 : Predicate<"Subtarget->hasAtomAddF64()">;
def hasAtomScope : Predicate<"Subtarget->hasAtomScope()">;
def hasAtomBitwise64 : Predicate<"Subtarget->hasAtomBitwise64()">;
def hasAtomMinMax64 : Predicate<"Subtarget->hasAtomMinMax64()">;
+def hasAtomSwap128 : Predicate<"Subtarget->hasAtomSwap128()">;
def hasClusters : Predicate<"Subtarget->hasClusters()">;
def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">;
def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">;
@@ -294,7 +295,7 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
//
// Also defines ftz (flush subnormal inputs and results to sign-preserving
// zero) variants for fp32 functions.
-multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
+multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDPatternOperator OpNode> {
defvar nan_str = !if(NaN, ".NaN", "");
if !not(NaN) then {
def _f64_rr :
@@ -898,10 +899,8 @@ let Predicates = [hasOptEnabled] in {
defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>;
defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>;
- defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>;
- defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>;
- defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>;
- defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>;
+ // Generating mad.wide causes a regression:
+ // https://github.com/llvm/llvm-project/pull/150477#issuecomment-3191367837
}
//-----------------------------------
@@ -912,8 +911,15 @@ defm FADD : F3_fma_component<"add", fadd>;
defm FSUB : F3_fma_component<"sub", fsub>;
defm FMUL : F3_fma_component<"mul", fmul>;
-defm MIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum>;
-defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
+def fminnum_or_fminimumnum : PatFrags<(ops node:$a, node:$b),
+ [(fminnum node:$a, node:$b),
+ (fminimumnum node:$a, node:$b)]>;
+def fmaxnum_or_fmaximumnum : PatFrags<(ops node:$a, node:$b),
+ [(fmaxnum node:$a, node:$b),
+ (fmaximumnum node:$a, node:$b)]>;
+
+defm MIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum_or_fminimumnum>;
+defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum_or_fmaximumnum>;
defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 4ab30a5b5f5e..c544911bdf1e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1990,19 +1990,23 @@ multiclass F_ATOMIC_3<RegTyInfo t, string op_str, SDPatternOperator op, SDNode a
let mayLoad = 1, mayStore = 1, hasSideEffects = 1 in {
def _rr : BasicFlagsNVPTXInst<(outs t.RC:$dst),
- (ins ADDR:$addr, t.RC:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
+ (ins ADDR:$addr, t.RC:$b, t.RC:$c),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;
def _ir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
- (ins ADDR:$addr, t.Imm:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
+ (ins ADDR:$addr, t.Imm:$b, t.RC:$c),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;
def _ri : BasicFlagsNVPTXInst<(outs t.RC:$dst),
- (ins ADDR:$addr, t.RC:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
+ (ins ADDR:$addr, t.RC:$b, t.Imm:$c),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;
def _ii : BasicFlagsNVPTXInst<(outs t.RC:$dst),
- (ins ADDR:$addr, t.Imm:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
+ (ins ADDR:$addr, t.Imm:$b, t.Imm:$c),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;
}
@@ -2200,6 +2204,37 @@ defm INT_PTX_SATOM_MIN : ATOM2_minmax_impl<"min">;
defm INT_PTX_SATOM_OR : ATOM2_bitwise_impl<"or">;
defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">;
+// atom.*.b128
+
+let mayLoad = true, mayStore = true, hasSideEffects = true,
+ Predicates = [hasAtomSwap128] in {
+ def ATOM_CAS_B128 :
+ NVPTXInst<
+ (outs B64:$dst0, B64:$dst1),
+ (ins ADDR:$addr, B64:$cmp0, B64:$cmp1, B64:$swap0, B64:$swap1,
+ AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
+ "{{\n\t"
+ ".reg .b128 cmp, swap, dst;\n\t"
+ "mov.b128 cmp, {$cmp0, $cmp1};\n\t"
+ "mov.b128 swap, {$swap0, $swap1};\n\t"
+ "atom${sem:sem}${scope:scope}${addsp:addsp}.cas.b128 dst, [$addr], cmp, swap;\n\t"
+ "mov.b128 {$dst0, $dst1}, dst;\n\t"
+ "}}">;
+
+ def ATOM_EXCH_B128 :
+ NVPTXInst<
+ (outs B64:$dst0, B64:$dst1),
+ (ins ADDR:$addr, B64:$amt0, B64:$amt1,
+ AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
+ "{{\n\t"
+ ".reg .b128 amt, dst;\n\t"
+ "mov.b128 amt, {$amt0, $amt1};\n\t"
+ "atom${sem:sem}${scope:scope}${addsp:addsp}.exch.b128 dst, [$addr], amt;\n\t"
+ "mov.b128 {$dst0, $dst1}, dst;\n\t"
+ "}}">;
+}
+
+
//-----------------------------------
// Support for ldu on sm_20 or later
//-----------------------------------
@@ -4358,10 +4393,12 @@ let hasSideEffects = 1 in {
def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
+ def SREG_GLOBALTIMER_LO : PTX_READ_SREG_R32<"globaltimer_lo", int_nvvm_read_ptx_sreg_globaltimer_lo>;
}
def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>;
def: Pat <(i64 (readsteadycounter)), (SREG_GLOBALTIMER)>;
+def: Pat <(i32 (readsteadycounter)), (SREG_GLOBALTIMER_LO)>;
def INT_PTX_SREG_PM0 : PTX_READ_SREG_R32<"pm0", int_nvvm_read_ptx_sreg_pm0>;
def INT_PTX_SREG_PM1 : PTX_READ_SREG_R32<"pm1", int_nvvm_read_ptx_sreg_pm1>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
index a84ceaba991c..c5489670bd24 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -62,7 +62,7 @@ NVPTXSubtarget::NVPTXSubtarget(const Triple &TT, const std::string &CPU,
const NVPTXTargetMachine &TM)
: NVPTXGenSubtargetInfo(TT, CPU, /*TuneCPU*/ CPU, FS), PTXVersion(0),
FullSmVersion(200), SmVersion(getSmVersion()),
- TLInfo(TM, initializeSubtargetDependencies(CPU, FS)) {
+ InstrInfo(initializeSubtargetDependencies(CPU, FS)), TLInfo(TM, *this) {
TSInfo = std::make_unique<NVPTXSelectionDAGInfo>();
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index acf025b70ce3..0a77a633cb25 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -82,6 +82,7 @@ public:
bool hasAtomBitwise64() const { return SmVersion >= 32; }
bool hasAtomMinMax64() const { return SmVersion >= 32; }
bool hasAtomCas16() const { return SmVersion >= 70 && PTXVersion >= 63; }
+ bool hasAtomSwap128() const { return SmVersion >= 90 && PTXVersion >= 83; }
bool hasClusters() const { return SmVersion >= 90 && PTXVersion >= 78; }
bool hasLDG() const { return SmVersion >= 32; }
bool hasHWROT32() const { return SmVersion >= 32; }
@@ -105,6 +106,7 @@ public:
// Tcgen05 instructions in Blackwell family
bool hasTcgen05Instructions() const {
bool HasTcgen05 = false;
+ unsigned MinPTXVersion = 86;
switch (FullSmVersion) {
default:
break;
@@ -112,9 +114,13 @@ public:
case 1013: // sm_101a
HasTcgen05 = true;
break;
+ case 1033: // sm_103a
+ HasTcgen05 = true;
+ MinPTXVersion = 88;
+ break;
}
- return HasTcgen05 && PTXVersion >= 86;
+ return HasTcgen05 && PTXVersion >= MinPTXVersion;
}
// f32x2 instructions in Blackwell family
bool hasF32x2Instructions() const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 0603994606d7..833f014a4c87 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -126,12 +126,12 @@ static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) {
// (addrspace:3).
if (!is64Bit)
Ret += "-p:32:32-p6:32:32-p7:32:32";
- else if (UseShortPointers) {
+ else if (UseShortPointers)
Ret += "-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32";
- } else
+ else
Ret += "-p6:32:32";
- Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
+ Ret += "-i64:64-i128:128-i256:256-v16:16-v32:32-n16:32:64";
return Ret;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 274b04fdd30b..8e97b422218f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -55,15 +55,6 @@ void clearAnnotationCache(const Module *Mod) {
AC.Cache.erase(Mod);
}
-static void readIntVecFromMDNode(const MDNode *MetadataNode,
- std::vector<unsigned> &Vec) {
- for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
- ConstantInt *Val =
- mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
- Vec.push_back(Val->getZExtValue());
- }
-}
-
static void cacheAnnotationFromMD(const MDNode *MetadataNode,
key_val_pair_t &retval) {
auto &AC = getAnnotationCache();
@@ -83,19 +74,8 @@ static void cacheAnnotationFromMD(const MDNode *MetadataNode,
if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
MetadataNode->getOperand(i + 1))) {
retval[Key].push_back(Val->getZExtValue());
- } else if (MDNode *VecMd =
- dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
- // note: only "grid_constant" annotations support vector MDNodes.
- // assert: there can only exist one unique key value pair of
- // the form (string key, MDNode node). Operands of such a node
- // shall always be unsigned ints.
- auto [It, Inserted] = retval.try_emplace(Key);
- if (Inserted) {
- readIntVecFromMDNode(VecMd, It->second);
- continue;
- }
} else {
- llvm_unreachable("Value operand not a constant int or an mdnode");
+ llvm_unreachable("Value operand not a constant int");
}
}
}
@@ -179,16 +159,13 @@ static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
}
static bool argHasNVVMAnnotation(const Value &Val,
- const std::string &Annotation,
- const bool StartArgIndexAtOne = false) {
+ const std::string &Annotation) {
if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
const Function *Func = Arg->getParent();
std::vector<unsigned> Annot;
if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
- const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
- if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
+ if (is_contained(Annot, Arg->getArgNo()))
return true;
- }
}
}
return false;
@@ -250,8 +227,7 @@ bool isParamGridConstant(const Argument &Arg) {
}
// "grid_constant" counts argument indices starting from 1
- if (argHasNVVMAnnotation(Arg, "grid_constant",
- /*StartArgIndexAtOne*/ true))
+ if (Arg.hasAttribute("nvvm.grid_constant"))
return true;
return false;