//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines transform dialect operations used for testing // TilingInterface // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "test-tiling-interface" #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.h.inc" using namespace mlir; using namespace mlir::transform; //===----------------------------------------------------------------------===// // TestFuseAndYieldOp //===----------------------------------------------------------------------===// static llvm::SmallDenseSet collectTiledAndFusedOps(Operation *op) { SmallVector worklist; llvm::SmallDenseSet producers; worklist.push_back(op); producers.insert(op); while (!worklist.empty()) { Operation *current = worklist.pop_back_val(); for (OpOperand &operand : current->getOpOperands()) { Operation *producer = operand.get().getDefiningOp(); if (!producer || !isa(producer) || producers.contains(producer)) continue; worklist.push_back(producer); producers.insert(producer); } } return producers; } /// Apply a tile and fuse transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, scf::SCFTilingOptions tilingOptions, TransformResults &transformResults) { SmallVector tiledOps; SmallVector> loopOps(numLoops); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); DominanceInfo dominanceInfo(tilingInterfaceOp); llvm::SmallDenseSet tiledAndFusedOps = collectTiledAndFusedOps(tilingInterfaceOp); llvm::DenseSet yieldReplacementsFor; for (auto op : tiledAndFusedOps) { if (llvm::any_of(op->getUsers(), [&](Operation *user) { return dominanceInfo.properlyDominates(tilingInterfaceOp, user); })) { yieldReplacementsFor.insert(op); } } scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.setTilingOptions(tilingOptions); scf::SCFTileAndFuseOptions::ControlFnTy controlFn = [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, bool isDestinationOperand) -> std::optional { Operation *owner = originalProducer.getOwner(); bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); return scf::SCFTileAndFuseOptions::ControlFnResult{ yieldProducerReplacement}; }; tileAndFuseOptions.setFusionControlFn(controlFn); rewriter.setInsertionPoint(target); FailureOr tiledResults = scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, tileAndFuseOptions); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { for (OpResult res : toReplace->getResults()) if (auto replacement = tiledResults->replacements.lookup(res)) { Operation *replacementOp = replacement.getDefiningOp(); rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { Operation *user = use.getOwner(); return dominanceInfo.properlyDominates(replacementOp, user) && user->getParentOp() == replacementOp->getParentOp(); }); } if (toReplace->use_empty()) { rewriter.eraseOp(toReplace); } } // Report back the relevant handles to the transform op. tiledOps.push_back(tiledResults->tiledAndFusedOps.front()); assert(tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed"); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].push_back(tiledResults->loops[i]); } transformResults.set(transformOp->getOpResult(0), tiledOps); for (unsigned int i = 0; i < numLoops; ++i) transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); return success(); } DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector tileInterchange = extractFromIntegerArrayAttr(getTileInterchange()); SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange); if (getUseForall()) { tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); } LogicalResult result = applyTileAndFuseToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), tilingOptions, transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TestFuseConsumerOp //===----------------------------------------------------------------------===// /// Fuse the consumer and store both the original consumer operation as well as /// the fused consumer operation. static LogicalResult applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, Operation *consumer, MutableArrayRef loops, TransformResults &transformResults) { SmallVector fusedConsumerOps; rewriter.setInsertionPoint(consumer); FailureOr fuseConsumerResults = scf::tileAndFuseConsumer(rewriter, consumer, loops); if (failed(fuseConsumerResults)) return consumer->emitOpError("failed to fuse consumer of slice"); // Report back the relevant handles to the transform op. for (OpOperand *tiledAndFusedConsumerOperand : fuseConsumerResults->tiledAndFusedConsumerOperands) { fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); } transformResults.set(transformOp->getOpResult(0), fusedConsumerOps); for (auto [index, loop] : llvm::enumerate(loops)) { transformResults.set(transformOp->getOpResult(index + 1), {loop}); } return success(); } DiagnosedSilenceableFailure transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { Operation *consumer = *state.getPayloadOps(getConsumer()).begin(); SmallVector loops; // Since the matcher works inside-out, we need to iterate the loops in // reverse. for (auto loop : llvm::reverse(getLoops())) { auto loopLikeOp = dyn_cast(*state.getPayloadOps(loop).begin()); if (!loopLikeOp) { return DiagnosedSilenceableFailure::definiteFailure(); } loops.push_back(loopLikeOp); } LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer, loops, transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestFuseConsumerOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getConsumerMutable(), effects); consumesHandle(getLoopsMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TestFuseConsumerUsingSliceOp //===----------------------------------------------------------------------===// /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. static LogicalResult applyFuseConsumerUsingSlices( RewriterBase &rewriter, Operation *transformOp, ArrayRef slices, MutableArrayRef loops, uint32_t numConsumerToFuse, TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; rewriter.setInsertionPoint(slices.front()); while (numConsumerToFuse--) { FailureOr fuseConsumerResults = scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops); if (failed(fuseConsumerResults)) return slices.front()->emitOpError("failed to fuse consumer of slice"); // Report back the relevant handles to the transform op. for (OpOperand *origConsumerOperand : fuseConsumerResults->origConsumerOperands) { originalConsumerOps.push_back(origConsumerOperand->getOwner()); } for (OpOperand *tiledAndFusedConsumerOperand : fuseConsumerResults->tiledAndFusedConsumerOperands) { fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); } } transformResults.set(transformOp->getOpResult(0), originalConsumerOps); transformResults.set(transformOp->getOpResult(1), fusedConsumerOps); return success(); } DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply( TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector slices; for (auto op : getTargets()) { auto sliceOp = *state.getPayloadOps(op).begin(); slices.push_back(sliceOp); } SmallVector loops; for (auto op : llvm::reverse(getLoops())) { auto loopLikeOp = dyn_cast(*state.getPayloadOps(op).begin()); if (!loopLikeOp) { return DiagnosedSilenceableFailure::definiteFailure(); } loops.push_back(loopLikeOp); } LogicalResult result = applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops, getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestFuseConsumerUsingSliceOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetsMutable(), effects); consumesHandle(getLoopsMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TestTileUsingForallOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTileToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, ArrayRef tileSizes, ArrayRef interchange, std::optional mapping, TransformResults &transformResults) { SmallVector tiledOps; SmallVector loopOps; for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); if (mapping) { tilingOptions.setMapping(mapping.value().getValue()); } tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); rewriter.setInsertionPoint(target); FailureOr tiledResults = scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements); // Report back the relevant handles to the transform op. tiledOps.push_back(tiledResults->tiledOps.front()); for (Operation *loop : tiledResults->loops) loopOps.push_back(loop); } transformResults.set(transformOp->getOpResult(0), tiledOps); for (auto [index, loop] : llvm::enumerate(loopOps)) transformResults.set(transformOp->getOpResult(index + 1), {loop}); return success(); } DiagnosedSilenceableFailure transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector interchange = extractFromIntegerArrayAttr(getInterchange()); SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); LogicalResult result = applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizesOfr, interchange, getMapping(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestTileUsingForallOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TestFuseUsingForallOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTilingToAll( RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, TransformResults &transformResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(1); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { for (OpResult res : toReplace->getResults()) if (auto replacement = tiledResults->replacements.lookup(res)) rewriter.replaceAllUsesWith(res, replacement); if (toReplace->use_empty()) rewriter.eraseOp(toReplace); } // Report back the relevant handles to the transform op. tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); assert(tiledResults->loops.size() == 1 && cast(tiledResults->loops[0]).getRank() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed"); loopOps[0] = {tiledResults->loops[0]}; } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); if (!loopOps.empty()) transformResults.set(transformOp->getOpResult(1), loopOps[0]); return success(); } DiagnosedSilenceableFailure transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector tileInterchange = extractFromIntegerArrayAttr(getInterchange()); scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getRootOp()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, tileAndFuseOptions); }); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestFuseUsingForallOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getRootOpMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TestTileAndFuseOuterParallelPartialReduction //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TestTileAndFuseOuterParallelPartialReductionOp::apply( TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { auto target = dyn_cast(*state.getPayloadOps(getRootOp()).begin()); if (!target) { emitOpError("expected root operation to implement `TilingInterface`"); return DiagnosedSilenceableFailure::definiteFailure(); } SmallVector reductionDims = extractFromIntegerArrayAttr(getReductionDims()); if (reductionDims.empty()) { for (auto [index, iterator] : llvm::enumerate(target.getLoopIteratorTypes())) if (iterator == utils::IteratorType::reduction) reductionDims.push_back(index); } if (reductionDims.empty()) { emitOpError( "no reduction dimension specified or found in the target operation"); return DiagnosedSilenceableFailure::definiteFailure(); } SmallVector reductionTileSizes = extractFromIntegerArrayAttr(getTileSizes()); if (reductionTileSizes.size() != reductionDims.size()) { emitOpError( "missing tile sizes for reduction dimensions that are to be tiled"); return DiagnosedSilenceableFailure::definiteFailure(); } // Adjust tile sizes so that it corresponds to the reduction iterator types. SmallVector tileSizes; int reductionTileSizeNum = 0; OpFoldResult zero = rewriter.getIndexAttr(0); for (auto iterator : target.getLoopIteratorTypes()) { if (iterator == utils::IteratorType::parallel) { tileSizes.push_back(zero); continue; } tileSizes.push_back( rewriter.getIndexAttr(reductionTileSizes[reductionTileSizeNum++])); } scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes) .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp) .setReductionTilingStrategy( ReductionTilingStrategy::PartialReductionOuterParallel) .setReductionDims(reductionDims); if (auto mapping = getMapping()) { tilingOptions.setMapping(getMapping().value()); } LogicalResult result = applyTileAndFuseToAll( rewriter, getOperation(), state.getPayloadOps(getRootOp()), /*numLoops =*/1, tilingOptions, transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TestTileAndFuseOuterParallelPartialReduction //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( TransformRewriter &transformRewriter, TransformResults &transformResults, TransformState &state) { auto target = dyn_cast(*state.getPayloadOps(getRootOp()).begin()); if (!target) { emitOpError("expected root operation to implement `TilingInterface`"); return DiagnosedSilenceableFailure::definiteFailure(); } OpFoldResult oneOfr = transformRewriter.getIndexAttr(1); scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn = [&](RewriterBase &rewriter, Location loc, ArrayRef loopRanges, ArrayRef givenTileSizes, ValueRange outerDestinationTensors) -> FailureOr { // Check that the strides are all 1 (to make it easier in the test). if (llvm::any_of(loopRanges, [](Range r) { return !isConstantIntValue(r.stride, 1); })) { return emitOpError("unable to handle loop ranges with strides != 1"); } // Check number of tile sizes is equal to loop dimensions. if (loopRanges.size() != givenTileSizes.size()) { return emitOpError("expected number of tile sizes to be same as the " "number of loops in the operation"); } // For testing disallow any of the tile sizes being 0. if (llvm::any_of(givenTileSizes, isZeroInteger)) { return emitOpError("unhandled case of zero tile size"); } // For testing, only handle tensor tiling. if (outerDestinationTensors.empty()) { return emitOpError("expected destination tensors"); } // Compute the number of iterations for each of the loops. AffineExpr s0, s1, s2; bindSymbols(rewriter.getContext(), s0, s1, s2); AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize SmallVector allNumIters; allNumIters.reserve(loopRanges.size()); for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, givenTileSizes)) { OpFoldResult numIters = affine::makeComposedFoldedAffineApply( rewriter, loc, numItersExpr, {loopRange.offset, loopRange.size, tileSize}); allNumIters.push_back(numIters); } if (allNumIters.empty()) { return emitOpError("invalid empty tile sizes and loop ranges"); } AffineExpr mulExpr = s0 * s1; OpFoldResult cumulative = oneOfr; for (auto numIters : allNumIters) { cumulative = affine::makeComposedFoldedAffineApply( rewriter, loc, mulExpr, {cumulative, numIters}); } Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0); Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1); Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative); SmallVector offsets; SmallVector sizes; SmallVector innerDestinationTensors; offsets.reserve(loopRanges.size()); sizes.reserve(loopRanges.size()); AffineExpr d0; bindDims(rewriter.getContext(), d0); AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, rewriter.getContext()); // min(ub - offset, tileSize) auto forOp = scf::ForOp::create( rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors, [&](OpBuilder &b, Location bodyLoc, Value linearizedIv, ValueRange destinations) { auto delinearizeOp = affine::AffineDelinearizeIndexOp::create( b, bodyLoc, linearizedIv, allNumIters); for (auto [normalizedIv, range, tileSize] : llvm::zip_equal( delinearizeOp.getResults(), loopRanges, givenTileSizes)) { OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv); OpFoldResult offset = affine::makeComposedFoldedAffineApply( b, bodyLoc, offsetExpr, {normalizedIvOfr, range.offset, tileSize}); offsets.push_back(offset); OpFoldResult size = affine::makeComposedFoldedAffineMin( b, bodyLoc, minMap, {offset, range.size, tileSize}); sizes.push_back(size); } innerDestinationTensors = llvm::to_vector(destinations); }); rewriter.setInsertionPointToEnd(forOp.getBody()); return scf::SCFTilingOptions::CustomLoopHeaderInfo{ {cast(forOp.getOperation())}, offsets, sizes, innerDestinationTensors}; }; scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn = [&](RewriterBase &rewriter, Location loc, ArrayRef loops, ValueRange tiledResults, ArrayRef> resultOffsets, ArrayRef> resultSizes, ValueRange destinationTensors) -> LogicalResult { SmallVector yieldValues; yieldValues.reserve(destinationTensors.size()); for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal( tiledResults, resultOffsets, resultSizes, destinationTensors)) { SmallVector strides(offsets.size(), oneOfr); Value insertedVal = tensor::InsertSliceOp::create( rewriter, loc, tiledResult, destination, offsets, sizes, strides); yieldValues.push_back(insertedVal); } scf::YieldOp::create(rewriter, loc, yieldValues); return success(); }; scf::SCFTilingOptions tilingOptions; SmallVector staticTileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector tileSizes = getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes); tilingOptions.setTileSizes(tileSizes) .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp) .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn); OpBuilder::InsertionGuard g(transformRewriter); transformRewriter.setInsertionPoint(target); FailureOr tiledResults = scf::tileUsingSCF(transformRewriter, target, tilingOptions); if (failed(tiledResults)) { return DiagnosedSilenceableFailure::definiteFailure(); } transformRewriter.replaceOp(target, tiledResults->replacements); transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps); transformResults.set(getOperation()->getResult(1), tiledResults->loops); return DiagnosedSilenceableFailure::success(); } #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" namespace { class TestTilingInterfaceDialectExtension : public transform::TransformDialectExtension< TestTilingInterfaceDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTilingInterfaceDialectExtension) using Base::Base; void init() { declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "TestTilingInterfaceTransformOps.cpp.inc" >(); } }; } // namespace namespace test { void registerTestTilingInterfaceTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); } } // namespace test