summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
diff options
context:
space:
mode:
authorArthur Eubanks <aeubanks@google.com>2024-06-27 16:32:27 -0700
committershawbyoung <shawbyoung@gmail.com>2024-06-27 16:32:27 -0700
commitf5c7df12cacdb84552b36a7ac598a8db41acc680 (patch)
tree3b33e941b9bfb88c40c64fd18ee32a633423cbed /mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
parent608880c3a7a59c86db82728067e553a8d4665a45 (diff)
parent804415825b97e974c96a92580bcbeaf4c7ff0a04 (diff)
[𝘀𝗽𝗿] changes introduced through rebaseusers/shawbyoung/spr/main.boltnfc-refactoring-callgraph
Created using spr 1.3.4 [skip ci]
Diffstat (limited to 'mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp457
1 files changed, 344 insertions, 113 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3..4eb334f8bbbf 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -369,12 +369,11 @@ void transform::BufferizeToAllocationOp::getEffects(
if (getBufferizeDestinationOnly()) {
// The destination is replaced with a newly allocated buffer, but the op
// itself remains in place.
- onlyReadsHandle(getTarget(), effects);
+ onlyReadsHandle(getTargetMutable(), effects);
} else {
- consumesHandle(getTarget(), effects);
+ consumesHandle(getTargetMutable(), effects);
}
- producesHandle(getAllocatedBuffer(), effects);
- producesHandle(getNewOps(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
@@ -463,7 +462,7 @@ DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- onlyReadsHandle(getTarget(), effects);
+ onlyReadsHandle(getTargetMutable(), effects);
modifiesPayload(effects);
}
@@ -1040,9 +1039,9 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
void transform::FuseIntoContainingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getProducerOp(), effects);
- onlyReadsHandle(getContainingOp(), effects);
- producesHandle(getResults(), effects);
+ consumesHandle(getProducerOpMutable(), effects);
+ onlyReadsHandle(getContainingOpMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
@@ -1391,8 +1390,8 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
void transform::MultiTileSizesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- onlyReadsHandle(getTarget(), effects);
- producesHandle(getResults(), effects);
+ onlyReadsHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
if (isa<TransformParamTypeInterface>(getLowSize().getType()))
onlyReadsPayload(effects);
else
@@ -1478,9 +1477,9 @@ transform::PackOp::apply(transform::TransformRewriter &rewriter,
void transform::PackOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::consumesHandle(getTarget(), effects);
- transform::onlyReadsHandle(getPackedSizes(), effects);
- transform::producesHandle(getPackedOp(), effects);
+ transform::consumesHandle(getTargetMutable(), effects);
+ transform::onlyReadsHandle(getPackedSizesMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}
@@ -1549,9 +1548,9 @@ SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
void transform::PackGreedilyOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::consumesHandle(getTarget(), effects);
- transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
- transform::producesHandle(getPackedOp(), effects);
+ transform::consumesHandle(getTargetMutable(), effects);
+ transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}
@@ -1761,11 +1760,9 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
void PadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTarget(), effects);
- onlyReadsHandle(getPadToMultipleOf(), effects);
- producesHandle(getPadded(), effects);
- producesHandle(getPad(), effects);
- producesHandle(getCopy(), effects);
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getPadToMultipleOfMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
@@ -1992,9 +1989,9 @@ LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
void transform::HoistPadBuildPackingLoopNestOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getTarget(), effects);
- transform::onlyReadsHandle(getLoop(), effects);
- transform::producesHandle(getPackingLoop(), effects);
+ transform::onlyReadsHandle(getTargetMutable(), effects);
+ transform::onlyReadsHandle(getLoopMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}
@@ -2135,8 +2132,8 @@ transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
void transform::ReplaceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTarget(), effects);
- producesHandle(getReplacement(), effects);
+ consumesHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
@@ -2269,13 +2266,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Collect the dynamic split points if provided.
SmallVector<Operation *> payload =
llvm::to_vector(state.getPayloadOps(getTarget()));
- SmallVector<OpFoldResult> splitPoints;
- splitPoints.reserve(payload.size());
- if (getDynamicSplitPoint()) {
+
+ bool isMultiwaySplit = getMultiway();
+
+ if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+ return mlir::emitSilenceableFailure(getLoc())
+ << "requires exactly one target when "
+ "multiway split is enabled (got "
+ << llvm::range_size(payload) << ")";
+ }
+
+ SmallVector<OpFoldResult> chunkSizes;
+
+ if (!isMultiwaySplit)
+ chunkSizes.reserve(payload.size());
+
+ if (getDynamicChunkSizes()) {
auto diag = DiagnosedSilenceableFailure::success();
- if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
- splitPoints = llvm::to_vector(llvm::map_range(
- state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+ if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
+ chunkSizes = llvm::to_vector(llvm::map_range(
+ state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
if (op->getNumResults() != 1 ||
!op->getResult(0).getType().isIndex()) {
diag = emitSilenceableError()
@@ -2286,103 +2296,174 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return OpFoldResult(op->getResult(0));
}));
} else {
- splitPoints = llvm::to_vector(
- llvm::map_range(state.getParams(getDynamicSplitPoint()),
+ chunkSizes = llvm::to_vector(
+ llvm::map_range(state.getParams(getDynamicChunkSizes()),
[](Attribute attr) { return OpFoldResult(attr); }));
}
if (diag.isSilenceableFailure())
return diag;
- if (splitPoints.size() != payload.size()) {
+ // For multiway split, a single payload is expected to have multiple
+ // split points.
+ if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
return emitDefiniteFailure()
<< "expected the dynamic split point handle to point to as "
"many operations ("
- << splitPoints.size() << ") as the target handle ("
+ << chunkSizes.size() << ") as the target handle ("
<< payload.size() << ")";
}
} else {
- splitPoints.resize(payload.size(),
- rewriter.getIndexAttr(getStaticSplitPoint()));
+ chunkSizes.resize(payload.size(),
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
- // Split each target operation.
- SmallVector<Operation *> first, second;
- Operation *noSecondPart = nullptr;
- for (const auto &pair : llvm::zip(payload, splitPoints)) {
- Operation *target = std::get<0>(pair);
- auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto checkStructuredOpAndDimensions =
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
if (!linalgOp) {
auto diag = emitSilenceableError() << "only applies to structured ops";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
<< " does not exist in target op";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
+ return DiagnosedSilenceableFailure::success();
+ };
- rewriter.setInsertionPoint(linalgOp);
- std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
- rewriter, cast<TilingInterface>(linalgOp.getOperation()),
- getDimension(), std::get<1>(pair));
-
- // Propagate errors.
- if (!first.back() && !second.back()) {
+ auto checkFailureInSplitting =
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
+ if (hasFailed) {
auto diag = emitDefiniteFailure() << "internal failure in splitting";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
+ return DiagnosedSilenceableFailure::success();
+ };
+
+ if (isMultiwaySplit) {
+
+ // Split a single target operation at multiple points.
+ SmallVector<Operation *> opList;
+ TilingInterface head, tail;
+ Operation *target = payload.front();
- // Do not add null second parts.
- if (!second.back()) {
- noSecondPart = target;
- second.pop_back();
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+
+ // Check that the target is a valid LinalgOp with correct dimensions.
+ DiagnosedSilenceableFailure diag =
+ checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+ if (diag.isSilenceableFailure())
+ return diag;
+
+ for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+
+ if (idx > 0)
+ target = tail.getOperation();
+
+ if (!target)
+ break;
+
+ linalgOp = cast<LinalgOp>(target);
+ Location loc = target->getLoc();
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::tie(head, tail) = linalg::splitOp(
+ rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+ getDimension(), chunkSize);
+
+ // Propagate errors.
+ DiagnosedSilenceableFailure diag =
+ checkFailureInSplitting(!head && !tail, loc);
+ if (diag.isDefiniteFailure())
+ return diag;
+
+ opList.push_back(head.getOperation());
}
- }
- if (second.size() != first.size() && !second.empty()) {
- auto diag = emitSilenceableError()
- << "splitting does not produce the second part for a subset "
- "of targets";
- diag.attachNote() << "expected splitting to produce the second part of all "
- "or none of the targets";
- diag.attachNote(noSecondPart->getLoc())
- << "first target with no second part";
- return diag;
- }
+ // Append any leftover parts to the end of the result list.
+ if (tail)
+ opList.push_back(tail.getOperation());
+ results.set(cast<OpResult>(getFirst()), opList);
+ results.set(cast<OpResult>(getSecond()), {});
- results.set(cast<OpResult>(getFirst()), first);
- results.set(cast<OpResult>(getSecond()), second);
+ } else {
+ // Split each target operation.
+ SmallVector<Operation *> first, second;
+ Operation *noSecondPart = nullptr;
+ for (const auto &pair : llvm::zip(payload, chunkSizes)) {
+ Operation *target = std::get<0>(pair);
+ Location loc = target->getLoc();
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+ DiagnosedSilenceableFailure diag =
+ checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+ if (diag.isSilenceableFailure())
+ return diag;
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+ rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+ getDimension(), std::get<1>(pair));
+
+ // Propagate errors.
+ DiagnosedSilenceableFailure diagSplit =
+ checkFailureInSplitting(!first.back() && !second.back(), loc);
+ if (diagSplit.isDefiniteFailure())
+ return diag;
+
+ // Do not add null second parts.
+ if (!second.back()) {
+ noSecondPart = target;
+ second.pop_back();
+ }
+ }
+
+ if (second.size() != first.size() && !second.empty()) {
+ auto diag = emitSilenceableError()
+ << "splitting does not produce the second part for a subset "
+ "of targets";
+ diag.attachNote()
+ << "expected splitting to produce the second part of all "
+ "or none of the targets";
+ diag.attachNote(noSecondPart->getLoc())
+ << "first target with no second part";
+ return diag;
+ }
+
+ results.set(cast<OpResult>(getFirst()), first);
+ results.set(cast<OpResult>(getSecond()), second);
+ }
return DiagnosedSilenceableFailure::success();
}
void SplitOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTarget(), effects);
- if (getDynamicSplitPoint())
- onlyReadsHandle(getDynamicSplitPoint(), effects);
- producesHandle(getResults(), effects);
+ consumesHandle(getTargetMutable(), effects);
+ if (getDynamicChunkSizes())
+ onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
- IntegerAttr staticSplitPoint;
+ OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
+ IntegerAttr staticChunkSizes;
if (parser.parseOperand(target) || parser.parseKeyword("after"))
return failure();
OptionalParseResult dynamicPointParseResult =
- parser.parseOptionalOperand(dynamicSplitPoint);
+ parser.parseOptionalOperand(dynamicChunkSizes);
if (!dynamicPointParseResult.has_value()) {
- int64_t staticSplitPointValue;
- if (failed(parser.parseInteger(staticSplitPointValue)))
+ int64_t staticChunkSizesValue;
+ if (failed(parser.parseInteger(staticChunkSizesValue)))
return failure();
- staticSplitPoint =
- parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+ staticChunkSizes =
+ parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
}
Type targetType;
@@ -2392,43 +2473,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
if (dynamicPointParseResult.has_value()) {
- Type splitPointType;
+ Type ChunkSizesType;
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
- parser.parseType(splitPointType) ||
- parser.resolveOperand(dynamicSplitPoint, splitPointType,
+ parser.parseType(ChunkSizesType) ||
+ parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
result.operands)) {
return failure();
}
- staticSplitPoint =
+ staticChunkSizes =
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
}
result.addAttribute(
- SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
- staticSplitPoint);
+ SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
+ staticChunkSizes);
result.addTypes({targetType, targetType});
return success();
}
void SplitOp::print(OpAsmPrinter &printer) {
printer << " " << getTarget() << " after ";
- int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
- if (staticSplitSize != ShapedType::kDynamic)
- printer << staticSplitSize;
+ int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
+ if (staticChunkSize != ShapedType::kDynamic)
+ printer << staticChunkSize;
else
- printer << getDynamicSplitPoint();
+ printer << getDynamicChunkSizes();
printer << " ";
printer.printOptionalAttrDict(getOperation()->getAttrs(),
- {getStaticSplitPointAttrName()});
+ {getStaticChunkSizesAttrName()});
printer << " : " << getTarget().getType();
- if (staticSplitSize == ShapedType::kDynamic)
- printer << ", " << getDynamicSplitPoint().getType();
+ if (staticChunkSize == ShapedType::kDynamic)
+ printer << ", " << getDynamicChunkSizes().getType();
}
LogicalResult SplitOp::verify() {
- if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
- (getDynamicSplitPoint() == nullptr)) {
+ if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
+ (getDynamicChunkSizes() == nullptr)) {
return emitOpError() << "expects either a dynamic or a static split "
"point to be provided";
}
@@ -2525,8 +2606,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
return emitDefaultSilenceableFailure(target);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
- results.push_back(result->parallelTiledOp);
- results.push_back(result->mergeOp);
+ for (auto parallelTiledOp : result->parallelTiledOps)
+ results.push_back(parallelTiledOp);
+ for (auto mergeOp : result->mergeOps)
+ results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
}
@@ -2577,13 +2660,162 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
}
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
- results.push_back(result->parallelTiledOp);
- results.push_back(result->mergeOp);
+ for (auto parallelTiledOp : result->parallelTiledOps)
+ results.push_back(parallelTiledOp);
+ for (auto mergeOp : result->mergeOps)
+ results.push_back(mergeOp);
results.push_back(result->loops);
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+
+ SmallVector<Operation *> targetOps =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+
+ if (!llvm::hasSingleElement(targetOps)) {
+ return mlir::emitSilenceableFailure(getLoc())
+ << "requires exactly one target (got " << llvm::range_size(targetOps)
+ << ")";
+ }
+
+ Operation *target = *targetOps.begin();
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto tileableOp = dyn_cast<TilingInterface>(target);
+
+ if (!linalgOp)
+ return emitDefiniteFailure() << "expected Linalg Op";
+
+ OpBuilder builder(linalgOp.getContext());
+
+ if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
+ if (linalgOp.hasDynamicShape()) {
+ auto diag = emitSilenceableError()
+ << "cannot compute parametric tile sizes for dynamically "
+ "shaped payload op";
+ diag.attachNote(linalgOp->getLoc()) << "payload op";
+ return diag;
+ }
+
+ FailureOr<StaticContinuousTileSizeSpecification> spec =
+ computeStaticContinuousTileSizes(linalgOp, getDimension(),
+ getTargetSize());
+ if (failed(spec)) {
+ return emitSilenceableError()
+ << "failed to compute multi-size tiling sizes";
+ }
+
+ SmallVector<int64_t> chunkSizes;
+
+ for (auto &&[tileSize, tripCount] :
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+ chunkSizes.push_back(tileSize * tripCount);
+
+ auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+ return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
+ return builder.getI64IntegerAttr(value);
+ });
+ };
+ transformResults.setParams(cast<OpResult>(getTileSizes()),
+ getI64AttrsFromI64(spec->tileSizes));
+ transformResults.setParams(cast<OpResult>(getChunkSizes()),
+ getI64AttrsFromI64(chunkSizes));
+
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ builder.setInsertionPoint(linalgOp);
+
+ OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+ unsigned dimension = getDimension();
+
+ FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
+ builder, tileableOp, dimension, targetSize, true);
+ if (failed(spec)) {
+ return emitSilenceableError() << "could not generate tile size computation";
+ }
+
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+ return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
+ ofrs);
+ };
+
+ SmallVector<Value> chunkSizes;
+ Value splitPoint;
+ for (auto &&[tileSize, tripCount] :
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
+ splitPoint = apply(s0 * s1, {tileSize, tripCount});
+ chunkSizes.push_back(splitPoint);
+ }
+
+ auto getDefiningOps = [&](ArrayRef<Value> values) {
+ return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+ return value.getDefiningOp();
+ });
+ };
+
+ transformResults.set(cast<OpResult>(getTileSizes()),
+ getDefiningOps(spec->tileSizes));
+ transformResults.set(cast<OpResult>(getChunkSizes()),
+ getDefiningOps(chunkSizes));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+
+ if (getTileSizes().getType() != getChunkSizes().getType()) {
+ return emitOpError() << "expects all results type to be the same";
+ }
+
+ return success();
+}
+
+void transform::ContinuousTileSizesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
+ onlyReadsPayload(effects);
+ else
+ modifiesPayload(effects);
+ onlyReadsHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+}
+
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+ Type targetType, Type tile_sizes,
+ Type) {
+ printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+}
+
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+ Type &targetType,
+ Type &tileSizesType,
+ Type &chunkSizesType) {
+ FunctionType funcType;
+ llvm::SMLoc typeLoc = parser.getCurrentLocation();
+ if (failed(parser.parseType<FunctionType>(funcType)))
+ return failure();
+
+ if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+ parser.emitError(typeLoc) << "expects a trailing functional type with one "
+ "argument and one result";
+ }
+ targetType = funcType.getInput(0);
+ tileSizesType = chunkSizesType = funcType.getResult(0);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//
@@ -2827,10 +3059,9 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
void transform::TileUsingForOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTarget(), effects);
- onlyReadsHandle(getDynamicSizes(), effects);
- producesHandle(getTiledLinalgOp(), effects);
- producesHandle(getLoops(), effects);
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicSizesMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
@@ -2995,12 +3226,12 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
void transform::TileUsingForallOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTarget(), effects);
- onlyReadsHandle(getTileSizes(), effects);
- onlyReadsHandle(getNumThreads(), effects);
- onlyReadsHandle(getPackedNumThreads(), effects);
- onlyReadsHandle(getPackedTileSizes(), effects);
- producesHandle(getResults(), effects);
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getTileSizesMutable(), effects);
+ onlyReadsHandle(getNumThreadsMutable(), effects);
+ onlyReadsHandle(getPackedNumThreadsMutable(), effects);
+ onlyReadsHandle(getPackedTileSizesMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
@@ -3178,8 +3409,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
void transform::VectorizeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getTarget(), effects);
- onlyReadsHandle(getVectorSizes(), effects);
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getVectorSizesMutable(), effects);
modifiesPayload(effects);
}