//===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/Builders.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace tensor; //===----------------------------------------------------------------------===// // FindPayloadReplacementOpInterface implementations //===----------------------------------------------------------------------===// namespace { struct ExtractSliceOpReplacementInterface : public transform::FindPayloadReplacementOpInterface::ExternalModel< ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> { SmallVector getNextOperands(Operation *op) const { auto extractSliceOp = cast(op); if (!isCastLikeExtractSliceOp(extractSliceOp)) return {}; return {extractSliceOp.getSource()}; } }; struct InsertSliceOpReplacementInterface : public transform::FindPayloadReplacementOpInterface::ExternalModel< InsertSliceOpReplacementInterface, tensor::InsertSliceOp> { SmallVector getNextOperands(Operation *op) const { auto insertSliceOp = cast(op); if (!isCastLikeInsertSliceOp(insertSliceOp)) return {}; return {insertSliceOp.getSource()}; } }; struct ReshapeOpReplacementInterface : public transform::FindPayloadReplacementOpInterface::ExternalModel< ReshapeOpReplacementInterface, tensor::ReshapeOp> { SmallVector getNextOperands(Operation *op) const { auto reshapeOp = cast(op); return {reshapeOp.getSource()}; } }; template struct ReassociativeReshapeOpReplacementInterface : public transform::FindPayloadReplacementOpInterface::ExternalModel< ReassociativeReshapeOpReplacementInterface, ConcreteOp> { SmallVector getNextOperands(Operation *op) const { auto reshapeOp = cast(op); return {reshapeOp.getSrc()}; } }; } // namespace void tensor::registerFindPayloadReplacementOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { CollapseShapeOp::attachInterface< ReassociativeReshapeOpReplacementInterface>(*ctx); ExpandShapeOp::attachInterface< ReassociativeReshapeOpReplacementInterface>(*ctx); ExtractSliceOp::attachInterface(*ctx); InsertSliceOp::attachInterface(*ctx); ReshapeOp::attachInterface(*ctx); }); } //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns( RewritePatternSet &patterns) { tensor::populateDecomposeTensorConcatPatterns(patterns); } void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp:: populatePatterns(RewritePatternSet &patterns) { tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); } void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns( RewritePatternSet &patterns) { tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly()); } void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns( RewritePatternSet &patterns) { tensor::populateFoldTensorSubsetOpPatterns(patterns); } void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp:: populatePatterns(RewritePatternSet &patterns) { tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); } void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp:: populatePatterns(RewritePatternSet &patterns) { tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); } void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns( RewritePatternSet &patterns) { tensor::populateReassociativeReshapeFoldingPatterns(patterns); } void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns( RewritePatternSet &patterns) { tensor::populateBubbleUpExtractSliceOpPatterns(patterns); } void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( RewritePatternSet &patterns) { ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); return producer && producer->hasOneUse(); }; ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) { return true; }; // Add folding with reshape by expansion patterns. if (getAggressive()) tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn); else tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); } //===----------------------------------------------------------------------===// // TypeConversionCastTensorShapeOp //===----------------------------------------------------------------------===// void transform::TypeConversionCastShapeDynamicDimsOp:: populateTypeMaterializations(TypeConverter &converter) { bool ignoreDynamicInfo = getIgnoreDynamicInfo(); converter.addSourceMaterialization([ignoreDynamicInfo]( OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) { return Value(); } Value input = inputs[0]; if (!ignoreDynamicInfo && !tensor::preservesStaticInformation(resultType, input.getType())) { return Value(); } if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { return Value(); } return tensor::CastOp::create(builder, loc, resultType, input).getResult(); }); converter.addTargetMaterialization([](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) { return Value(); } Value input = inputs[0]; if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { return Value(); } return tensor::CastOp::create(builder, loc, resultType, input).getResult(); }); } //===----------------------------------------------------------------------===// // MakeLoopIndependentOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; Operation *nextOp = target; for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { nextOp = nextOp->getParentOfType(); if (!nextOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find " << i << "-th enclosing loop"; diag.attachNote(target->getLoc()) << "target op"; return diag; } ivs.push_back(cast(nextOp).getInductionVar()); } // Rewrite IR. FailureOr replacement = failure(); if (auto padOp = dyn_cast(target)) { replacement = tensor::buildIndependentOp(rewriter, padOp, ivs); } else if (auto emptyOp = dyn_cast(target)) { replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs); } else { DiagnosedSilenceableFailure diag = emitSilenceableError() << "unsupported target op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (failed(replacement)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not make target op loop-independent"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.replaceOp(target, *replacement); results.push_back(replacement->getDefiningOp()); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { class TensorTransformDialectExtension : public transform::TransformDialectExtension< TensorTransformDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension) using Base::Base; void init() { declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" void mlir::tensor::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }