summaryrefslogtreecommitdiff
path: root/mlir/lib/CAPI/Transforms/Rewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/CAPI/Transforms/Rewrite.cpp')
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp104
1 files changed, 104 insertions, 0 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a1..8ee6308cadf8 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -302,6 +304,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
}
//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+ assert(rewriter.ptr && "unexpected null rewriter");
+ return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+}
+
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+ return {rewriter};
+}
+
+//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
@@ -331,4 +346,93 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
op.ptr = nullptr;
return wrap(m);
}
+
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+ assert(value.ptr && "unexpected null PDL value");
+ return static_cast<const mlir::PDLValue *>(value.ptr);
+}
+
+inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
+
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+ assert(results.ptr && "unexpected null PDL results");
+ return static_cast<mlir::PDLResultList *>(results.ptr);
+}
+
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+ return {results};
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+ return wrap(unwrap(value)->dyn_cast<mlir::Value>());
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+ return wrap(unwrap(value)->dyn_cast<mlir::Type>());
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+ return wrap(unwrap(value)->dyn_cast<mlir::Operation *>());
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+ return wrap(unwrap(value)->dyn_cast<mlir::Attribute>());
+}
+
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+ MlirValue value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+ MlirOperation value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+ MlirAttribute value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
+ std::vector<MlirPDLValue> mlirValues;
+ mlirValues.reserve(values.size());
+ for (auto &value : values) {
+ mlirValues.push_back(wrap(&value));
+ }
+ return mlirValues;
+}
+
+void mlirPDLPatternModuleRegisterRewriteFunction(
+ MlirPDLPatternModule pdlModule, MlirStringRef name,
+ MlirPDLRewriteFunction rewriteFn, void *userData) {
+ unwrap(pdlModule)->registerRewriteFunction(
+ unwrap(name),
+ [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+ ArrayRef<PDLValue> values) -> LogicalResult {
+ std::vector<MlirPDLValue> mlirValues = wrap(values);
+ return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+ mlirValues.size(), mlirValues.data(),
+ userData));
+ });
+}
+
+void mlirPDLPatternModuleRegisterConstraintFunction(
+ MlirPDLPatternModule pdlModule, MlirStringRef name,
+ MlirPDLConstraintFunction constraintFn, void *userData) {
+ unwrap(pdlModule)->registerConstraintFunction(
+ unwrap(name),
+ [userData, constraintFn](PatternRewriter &rewriter,
+ PDLResultList &results,
+ ArrayRef<PDLValue> values) -> LogicalResult {
+ std::vector<MlirPDLValue> mlirValues = wrap(values);
+ return unwrap(constraintFn(wrap(&rewriter), wrap(&results),
+ mlirValues.size(), mlirValues.data(),
+ userData));
+ });
+}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH