diff options
Diffstat (limited to 'mlir/lib/CAPI/Transforms/Rewrite.cpp')
| -rw-r--r-- | mlir/lib/CAPI/Transforms/Rewrite.cpp | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 6f85357a14a1..8ee6308cadf8 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -13,6 +13,8 @@ #include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -302,6 +304,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, } //===----------------------------------------------------------------------===// +/// PatternRewriter API +//===----------------------------------------------------------------------===// + +inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { + assert(rewriter.ptr && "unexpected null rewriter"); + return static_cast<mlir::PatternRewriter *>(rewriter.ptr); +} + +inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { + return {rewriter}; +} + +//===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// @@ -331,4 +346,93 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { op.ptr = nullptr; return wrap(m); } + +inline const mlir::PDLValue *unwrap(MlirPDLValue value) { + assert(value.ptr && "unexpected null PDL value"); + return static_cast<const mlir::PDLValue *>(value.ptr); +} + +inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } + +inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { + assert(results.ptr && "unexpected null PDL results"); + return static_cast<mlir::PDLResultList *>(results.ptr); +} + +inline MlirPDLResultList wrap(mlir::PDLResultList *results) { + return {results}; +} + +MlirValue mlirPDLValueAsValue(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast<mlir::Value>()); +} + +MlirType mlirPDLValueAsType(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast<mlir::Type>()); +} + +MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast<mlir::Operation *>()); +} + +MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast<mlir::Attribute>()); +} + +void mlirPDLResultListPushBackValue(MlirPDLResultList results, + MlirValue value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value) { + unwrap(results)->push_back(unwrap(value)); +} + +inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) { + std::vector<MlirPDLValue> mlirValues; + mlirValues.reserve(values.size()); + for (auto &value : values) { + mlirValues.push_back(wrap(&value)); + } + return mlirValues; +} + +void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData) { + unwrap(pdlModule)->registerRewriteFunction( + unwrap(name), + [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, + ArrayRef<PDLValue> values) -> LogicalResult { + std::vector<MlirPDLValue> mlirValues = wrap(values); + return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} + +void mlirPDLPatternModuleRegisterConstraintFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLConstraintFunction constraintFn, void *userData) { + unwrap(pdlModule)->registerConstraintFunction( + unwrap(name), + [userData, constraintFn](PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef<PDLValue> values) -> LogicalResult { + std::vector<MlirPDLValue> mlirValues = wrap(values); + return unwrap(constraintFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
