summary refs log tree commit diff
path: root/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp')
-rw-r--r-- mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp 448
1 files changed, 269 insertions, 179 deletions
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3a780abd3f1..6d45a51ab026 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -387,6 +387,8 @@ private:
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ bool hasAnchorLayout(xegpu::DistributeLayoutAttr anchorLayout);
+
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
@@ -475,49 +477,72 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}
+/// Returns true when `anchorLayout` is present and carries the data that the
+/// configured `layoutKind` checks for: a non-empty effective inst-data for
+/// LayoutKind::InstData, or non-empty effective lane-layout and lane-data for
+/// LayoutKind::Lane. Every other layout kind yields false.
+bool LayoutInfoPropagation::hasAnchorLayout(
+ xegpu::DistributeLayoutAttr anchorLayout) {
+ // No attribute attached: nothing to anchor on.
+ if (!anchorLayout)
+ return false;
+
+ if (layoutKind == LayoutKind::InstData)
+ return !anchorLayout.getEffectiveInstDataAsInt().empty();
+
+ // Both lane fields must be populated; the lane-data getter is only queried
+ // when the lane-layout one is non-empty.
+ if (layoutKind == LayoutKind::Lane)
+ return !anchorLayout.getEffectiveLaneLayoutAsInt().empty() &&
+ !anchorLayout.getEffectiveLaneDataAsInt().empty();
+
+ return false;
+}
+
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Here we assign the default layout to the tensor descriptor operand of
- // prefetch.
- auto tdescTy = prefetch.getTensorDescType();
-
- auto uArch = getUArch(getChipStr(prefetch).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
-
- auto blockWHC =
- uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
- if (!blockWHC)
- prefetch.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = xegpu::getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
- bCount);
- if (instWidth == -1)
- prefetch.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (tdescTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = xegpu::getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+
+ // Use the anchor layout already attached to the op when it is usable for
+ // the current layout kind; otherwise derive a default one and record it on
+ // the op.
+ LayoutInfo prefetchLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = prefetch.getAnchorLayoutAttr();
+ if (hasAnchorLayout(anchorLayout)) {
+ prefetchLayout = LayoutInfo(anchorLayout);
+ } else {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+
+ auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+ auto blockWHC =
+ uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+ // Without block parameters no instruction shape can be derived, and
+ // calling value() on the empty result below would be invalid, so stop
+ // here without propagating anything.
+ if (!blockWHC) {
+ prefetch.emitWarning("No known block params found for the element type.");
+ return;
+ }
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ // Innermost dimension: constrained by block width and block count.
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
+ bCount);
+ if (instWidth == -1)
prefetch.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
- }
- LayoutInfo prefetchLayout;
- if (layoutKind == LayoutKind::InstData)
- prefetchLayout =
- LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
- else
- prefetchLayout = getDefaultSIMTLayoutInfo(
- tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+ if (tdescTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ // Second-innermost dimension: constrained by block height.
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ prefetchLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+ else
+ prefetchLayout = getDefaultSIMTLayoutInfo(
+ tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+ // Record the derived layout on the op as its anchor layout.
+ prefetch.setAnchorLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
+ }
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
@@ -617,69 +642,96 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- VectorType aTy = dpas.getLhsType();
- VectorType bTy = dpas.getRhsType();
-
- auto uArch = getUArch(getChipStr(dpas).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
- xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
-
- const unsigned dataALen = aTy.getShape().front();
- auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
- const int maxALen =
- xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
- if (maxALen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
-
- const unsigned dataBLen = bTy.getShape().back();
- auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
- const int maxBLen =
- xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
- if (maxBLen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataA = {maxALen, subgroupSize};
- SmallVector<int> instDataB = {subgroupSize, maxBLen};
LayoutInfo dpasALayout;
LayoutInfo dpasBLayout;
LayoutInfo dpasCLayout;
- if (layoutKind == LayoutKind::InstData) {
- dpasALayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
- dpasBLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ xegpu::DistributeLayoutAttr anchorLayoutC = dpas.getAnchorLayoutCdAttr(); // The C/D anchor alone decides which branch below is taken.
+ if (hasAnchorLayout(anchorLayoutC)) { // Anchored: reuse the layouts attached to the op for A, B and C.
+ xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getAnchorLayoutAAttr();
+ xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getAnchorLayoutBAttr();
+ assert(hasAnchorLayout(anchorLayoutA) && // A/B anchors are only checked by assert, i.e. not in builds without assertions.
+ "Expected anchor layout for DPAS A operand.");
+ assert(hasAnchorLayout(anchorLayoutB) &&
+ "Expected anchor layout for DPAS B operand.");
+ dpasALayout = LayoutInfo(anchorLayoutA);
+ dpasBLayout = LayoutInfo(anchorLayoutB);
+ dpasCLayout = LayoutInfo(anchorLayoutC);
+
} else {
- dpasALayout = getSIMTLayoutInfoForDPASOperand(
- aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
- dpasBLayout = getSIMTLayoutInfoForDPASOperand(
- bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
- }
- propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
- propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
- if (operands.size() > 2) {
- VectorType cTy = dpas.getAccType();
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxCLen =
- xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1)
+ VectorType aTy = dpas.getLhsType(); // Not anchored: derive default layouts from the operand types and uArch.
+ VectorType bTy = dpas.getRhsType();
+
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ const unsigned dataALen = aTy.getShape().front(); // First dimension of A, matched against the supported M sizes.
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+ if (maxALen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataC = {maxALen, maxCLen};
- if (layoutKind == LayoutKind::InstData)
- dpasCLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
- else
- dpasCLayout = getSIMTLayoutInfoForDPASOperand(
- cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ const unsigned dataBLen = bTy.getShape().back(); // Last dimension of B, matched against the supported N sizes.
+ auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+
+ const int maxBLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+ if (maxBLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataA = {maxALen, subgroupSize}; // NOTE(review): a -1 from getLargestDivisor is only warned about and still ends up in these shapes.
+ SmallVector<int> instDataB = {subgroupSize, maxBLen};
+
+ if (layoutKind == LayoutKind::InstData) {
+ dpasALayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+ dpasBLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ } else {
+ dpasALayout = getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+ dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
+ if (operands.size() > 2) { // A third operand (the accumulator) is present.
+ VectorType cTy = dpas.getAccType();
+ if (layoutKind == LayoutKind::InstData) {
+ const unsigned dataCLen = bTy.getShape().back(); // C's second inst-data entry is derived from B's last dimension.
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxCLen = xegpu::getLargestDivisor(
+ dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataC = {maxALen, maxCLen};
+ dpasCLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+ } else
+ dpasCLayout = getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); // NOTE(review): C uses the B packed-format bit size — confirm this is intended.
+
+ dpas.setAnchorLayoutCdAttr( // Record the derived C layout on the op.
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasCLayout.get()));
+ }
+ dpas.setAnchorLayoutAAttr( // A and B anchors are recorded even when there is no accumulator operand.
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
+ dpas.setAnchorLayoutBAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+ }
+
+ propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); // Operand 0 receives the A layout, operand 1 the B layout.
+ propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
+ if (operands.size() > 2) { // Operand 2, when present, receives the C layout.
propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
}
}
@@ -689,43 +741,51 @@ void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- auto uArch = getUArch(getChipStr(store).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
- VectorType dataTy = store.getValueType();
- auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
- store.getValueType().getElementType());
- if (!blockWHC)
- store.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = xegpu::getLargestDivisor(
- static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
- bCount);
- if (instWidth == -1)
- store.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (dataTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = xegpu::getLargestDivisor(
- static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+ // Use the anchor layout already attached to the op when it is usable for
+ // the current layout kind; otherwise derive a default one from the stored
+ // value type and record it on the op.
+ LayoutInfo storeLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = store.getAnchorLayoutAttr();
+ if (hasAnchorLayout(anchorLayout)) {
+ storeLayout = LayoutInfo(anchorLayout);
+ } else {
+ auto uArch = getUArch(getChipStr(store).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
+ VectorType dataTy = store.getValueType();
+ auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
+ store.getValueType().getElementType());
+ // Without block parameters no instruction shape can be derived, and
+ // calling value() on the empty result below would be invalid, so stop
+ // here without propagating anything.
+ if (!blockWHC) {
+ store.emitWarning("No known block params found for the element type.");
+ return;
+ }
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ // Innermost dimension: constrained by block width and block count.
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
+ bCount);
+ if (instWidth == -1)
store.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
- }
+ if (dataTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ // Second-innermost dimension: constrained by block height.
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
- LayoutInfo storeLayout;
- if (layoutKind == LayoutKind::InstData)
- storeLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
- else
- storeLayout =
- getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
- uArchInstruction->getPackedFormatBitSize());
+ if (layoutKind == LayoutKind::InstData)
+ storeLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+ else
+ storeLayout =
+ getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
+ // Record the derived layout on the op as its anchor layout.
+ store.setAnchorLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
+ }
+ // Propagate the layout to every operand of the store.
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -736,21 +796,31 @@ void LayoutInfoPropagation::visitStoreNdOp(
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo valueLayout = results[0]->getValue();
- // Need the layout of the value to propagate to the tensor descriptor.
- if (!valueLayout.isAssigned())
- return;
- LayoutInfo tensorDescLayout = valueLayout;
- // LoadNdOp has the transpose effect. However, at the stage of this analysis
- // this effect is not expected and should be abstracted away. Emit a
- // warning.
- if (auto transpose = load.getTranspose()) {
- load.emitWarning("Transpose effect is not expected for LoadNdOp at "
- "LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.transpose(transpose.value());
+
+ LayoutInfo loadLayout; // Layout that is propagated to the tensor descriptor operand.
+ xegpu::DistributeLayoutAttr anchorLayout = load.getAnchorLayoutAttr();
+ if (hasAnchorLayout(anchorLayout)) { // A usable anchor on the op takes precedence over the result's layout.
+ loadLayout = LayoutInfo(anchorLayout);
+ } else {
+
+ LayoutInfo valueLayout = results[0]->getValue();
+ // The result's layout is required before anything can be propagated to
+ // the tensor descriptor; without it there is nothing to do yet.
+ if (!valueLayout.isAssigned())
+ return;
+ loadLayout = valueLayout;
+ // LoadNdOp has the transpose effect. However, at the stage of this analysis
+ // this effect is not expected and should be abstracted away. Emit a
+ // warning, and still propagate the transposed layout.
+ if (auto transpose = load.getTranspose()) {
+ load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+ "LayoutInfoPropagation stage.");
+ loadLayout = valueLayout.transpose(transpose.value());
+ }
+ load.setAnchorLayoutAttr( // Record the (possibly transposed) layout on the op; presumably later visits then take the anchor branch — confirm.
+ dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
// Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
/// For vector::TransposeOp, the layout of the result is transposed and
@@ -840,37 +910,49 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // The layout is strictly determined by the payload type.
- auto payloadTy = dyn_cast<VectorType>(load.getValueType());
- if (!payloadTy) {
- load.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- auto uArch = getUArch(getChipStr(load).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto srcTdescTy =
- dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
- if (srcTdescTy.getChunkSizeAsInt() > 1)
+
+ // Use the anchor layout already attached to the op when it is usable for
+ // the current layout kind; otherwise derive default payload and mask
+ // layouts and record the payload layout on the op.
+ LayoutInfo loadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getAnchorLayoutAttr();
+ if (hasAnchorLayout(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ // NOTE(review): the mask reuses the payload anchor here, whereas the
+ // default path below gives it a separate 1D layout — confirm this is
+ // intended for chunked (2D) payloads.
+ maskLayout = loadLayout;
+ } else {
+
+ // The layout is strictly determined by the payload type.
+ auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+ if (!payloadTy) {
+ load.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+ auto uArch = getUArch(getChipStr(load).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData{subgroupSize};
+ // Prefer the chunk size given on the op itself.
+ if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
instData.push_back(chunkSize);
- }
- LayoutInfo layout;
- if (layoutKind == LayoutKind::InstData)
- layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
- else
- layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
- uArch->getGeneralPackedFormatBitSize(),
- /*scattered*/ true);
-
- // Mask operand should have 1D default layout.
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+ else if (auto srcTdescTy =
+ dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
+ // Fall back to the chunk size carried by the tensor descriptor type.
+ // The op-level `chunkSize` is known to be <= 1 in this branch, so the
+ // descriptor's own value is the one that must be recorded.
+ if (srcTdescTy.getChunkSizeAsInt() > 1)
+ instData.push_back(srcTdescTy.getChunkSizeAsInt());
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ loadLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+ else
+ loadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered*/ true);
+
+ // Mask operand should have 1D default layout.
+ maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+ // Record the derived payload layout on the op as its anchor layout.
+ load.setAnchorLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+ }
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
- propagateIfChanged(operands[0], operands[0]->meet(layout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets())
@@ -898,21 +980,26 @@ void LayoutInfoPropagation::visitCreateDescOp(
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Currently, for 2D StoreScatterOp we expect that the height dimension of
- // the tensor descriptor is equal to the subgroup size. This is ensured by
- // the op verifier.
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
- if (!payloadTy) {
- storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- LayoutInfo payloadLayout;
- auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- if (auto layout = storeScatter.getLayoutAttr()) {
- payloadLayout = LayoutInfo(layout);
+ LayoutInfo payloadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getAnchorLayoutAttr();
+ if (hasAnchorLayout(anchorLayout)) {
+ payloadLayout = LayoutInfo(anchorLayout);
+ maskLayout = payloadLayout;
} else {
+ // Currently, for 2D StoreScatterOp we expect that the height dimension of
+ // the tensor descriptor is equal to the subgroup size. This is ensured by
+ // the op verifier.
+ auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+ if (!payloadTy) {
+ storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+
if (layoutKind == LayoutKind::InstData) {
SmallVector<int> instData{subgroupSize};
if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
@@ -936,10 +1023,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
/*scattered=*/true);
}
- }
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+ maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+
+ storeScatter.setAnchorLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
+ }
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout