summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp')
-rw-r--r--mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp123
1 files changed, 107 insertions, 16 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 69017efb9a0e..f5e30a278f06 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/ErrorHandling.h"
@@ -54,6 +55,35 @@ static Value valueByDim(KernelDim3 dims, Dimension dim) {
static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }
+static std::optional<uint64_t>
+getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
+ DenseI32ArrayAttr bounds;
+ switch (dims) {
+ case LaunchDims::Block:
+ bounds = func.getKnownBlockSizeAttr();
+ break;
+ case LaunchDims::Grid:
+ bounds = func.getKnownGridSizeAttr();
+ break;
+ }
+ if (!bounds)
+ return std::nullopt;
+ if (bounds.size() < static_cast<uint32_t>(dim))
+ return std::nullopt;
+ return zext(bounds[static_cast<uint32_t>(dim)]);
+}
+
+static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
+ StringRef attrName,
+ Dimension dim) {
+ auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
+ if (!bounds)
+ return std::nullopt;
+ if (bounds.size() < static_cast<uint32_t>(dim))
+ return std::nullopt;
+ return zext(bounds[static_cast<uint32_t>(dim)]);
+}
+
template <typename Op>
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
Dimension dim = op.getDimension();
@@ -73,25 +103,57 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
return value.getZExtValue();
}
- if (auto func = op->template getParentOfType<GPUFuncOp>()) {
+ if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
+ auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
+ if (inherentAttr)
+ return inherentAttr;
+ }
+ if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
+ StringRef attrName;
switch (type) {
case LaunchDims::Block:
- return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
+ attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
+ break;
case LaunchDims::Grid:
- return llvm::transformOptional(func.getKnownGridSize(dim), zext);
+ attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
+ break;
}
+ auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
+ if (discardableAttr)
+ return discardableAttr;
}
return std::nullopt;
}
void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+ uint64_t max = kMaxDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
+}
+
+void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+ SetIntRangeFn setResultRange) {
+ uint64_t max = kMaxClusterDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
}
void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
+ uint64_t max = kMaxDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(0, max - 1ULL));
+}
+
+void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+ SetIntRangeFn setResultRange) {
uint64_t max = kMaxClusterDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
@@ -100,14 +162,21 @@ void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
std::optional<uint64_t> knownVal =
getKnownLaunchDim(*this, LaunchDims::Block);
if (knownVal)
- setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
- else
- setResultRange(getResult(), getIndexRange(1, kMaxDim));
+ return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
+ ;
+ uint64_t max = kMaxDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
}
void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
+ uint64_t max = kMaxDim;
+ if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
+ max = fromContext.value();
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
@@ -115,29 +184,45 @@ void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
if (knownVal)
- setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
- else
- setResultRange(getResult(), getIndexRange(1, kMaxDim));
+ return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
+ uint64_t max = kMaxDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
}
void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
+ uint64_t max = kMaxDim;
+ if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
+ max = fromContext.value();
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
+ uint64_t max = kMaxSubgroupSize;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
+ uint64_t max = kMaxDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
+ if (auto specified = getUpperBound())
+ return setResultRange(getResult(),
+ getIndexRange(0, specified->getZExtValue() - 1ULL));
+
uint64_t blockDimMax =
getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
uint64_t gridDimMax =
@@ -148,12 +233,18 @@ void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(1, kMaxDim));
+ uint64_t max = kMaxDim;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
}
void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
+ uint64_t max = kMaxSubgroupSize;
+ if (auto specified = getUpperBound())
+ max = specified->getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
}
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,