diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp')
| -rw-r--r-- | mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp | 123 |
1 files changed, 107 insertions, 16 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp index 69017efb9a0e..f5e30a278f06 100644 --- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/Support/ErrorHandling.h" @@ -54,6 +55,35 @@ static Value valueByDim(KernelDim3 dims, Dimension dim) { static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); } +static std::optional<uint64_t> +getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) { + DenseI32ArrayAttr bounds; + switch (dims) { + case LaunchDims::Block: + bounds = func.getKnownBlockSizeAttr(); + break; + case LaunchDims::Grid: + bounds = func.getKnownGridSizeAttr(); + break; + } + if (!bounds) + return std::nullopt; + if (bounds.size() < static_cast<uint32_t>(dim)) + return std::nullopt; + return zext(bounds[static_cast<uint32_t>(dim)]); +} + +static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func, + StringRef attrName, + Dimension dim) { + auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName); + if (!bounds) + return std::nullopt; + if (bounds.size() < static_cast<uint32_t>(dim)) + return std::nullopt; + return zext(bounds[static_cast<uint32_t>(dim)]); +} + template <typename Op> static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) { Dimension dim = op.getDimension(); @@ -73,25 +103,57 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) { return value.getZExtValue(); } - if (auto func = op->template getParentOfType<GPUFuncOp>()) { + if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) { + auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim); + if (inherentAttr) + return inherentAttr; + } + if (auto func = op->template getParentOfType<FunctionOpInterface>()) { + StringRef attrName; switch (type) { case LaunchDims::Block: - return llvm::transformOptional(func.getKnownBlockSize(dim), zext); + attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr(); + break; case LaunchDims::Grid: - return llvm::transformOptional(func.getKnownGridSize(dim), zext); + attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr(); + break; } + auto discardableAttr = getKnownLaunchAttr(func, attrName, dim); + if (discardableAttr) + return discardableAttr; } return std::nullopt; } void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - setResultRange(getResult(), getIndexRange(1, kMaxClusterDim)); + uint64_t max = kMaxDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(1, max)); +} + +void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>, + SetIntRangeFn setResultRange) { + uint64_t max = kMaxClusterDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(1, max)); } void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { + uint64_t max = kMaxDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(0, max - 1ULL)); +} + +void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, + SetIntRangeFn setResultRange) { uint64_t max = kMaxClusterDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); setResultRange(getResult(), getIndexRange(0, max - 1ULL)); } @@ -100,14 +162,21 @@ void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Block); if (knownVal) - setResultRange(getResult(), getIndexRange(*knownVal, *knownVal)); - else - setResultRange(getResult(), getIndexRange(1, kMaxDim)); + return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal)); + ; + uint64_t max = kMaxDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(1, max)); } void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim); + uint64_t max = kMaxDim; + if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid)) + max = fromContext.value(); + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); setResultRange(getResult(), getIndexRange(0, max - 1ULL)); } @@ -115,29 +184,45 @@ void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid); if (knownVal) - setResultRange(getResult(), getIndexRange(*knownVal, *knownVal)); - else - setResultRange(getResult(), getIndexRange(1, kMaxDim)); + return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal)); + uint64_t max = kMaxDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(1, max)); } void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim); + uint64_t max = kMaxDim; + if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block)) + max = fromContext.value(); + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); setResultRange(getResult(), getIndexRange(0, max - 1ULL)); } void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL)); + uint64_t max = kMaxSubgroupSize; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(0, max - 1ULL)); } void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL)); + uint64_t max = kMaxDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(0, max - 1ULL)); } void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { + if (auto specified = getUpperBound()) + return setResultRange(getResult(), + getIndexRange(0, specified->getZExtValue() - 1ULL)); + uint64_t blockDimMax = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim); uint64_t gridDimMax = @@ -148,12 +233,18 @@ void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - setResultRange(getResult(), getIndexRange(1, kMaxDim)); + uint64_t max = kMaxDim; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(1, max)); } void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>, SetIntRangeFn setResultRange) { - setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize)); + uint64_t max = kMaxSubgroupSize; + if (auto specified = getUpperBound()) + max = specified->getZExtValue(); + setResultRange(getResult(), getIndexRange(1, max)); } void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
