summaryrefslogtreecommitdiff
path: root/mlir/tools/mlir-tblgen/OpFormatGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/OpFormatGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/OpFormatGen.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0a9d14d6603a..11edf2523f1a 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1365,7 +1365,7 @@ if (attr && ::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
auto &propStorage = prop.{0};
auto {0}AttrName = StringAttr::get(ctx, "{0}");
auto attr = dict.get({0}AttrName);
-usedKeys.insert(StringAttr::get(ctx, "{1}"));
+usedKeys.insert({0}AttrName);
if (attr || /*isRequired=*/{1}) {{
if (!attr) {{
emitError() << "expected key entry for {0} in DictionaryAttr to set "
@@ -2787,6 +2787,11 @@ private:
void handleTypesMatchConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
+ /// Check for inferable type resolution based on
+ /// `ShapedTypeMatchesElementCountAndTypes` constraint.
+ void handleShapedTypeMatchesElementCountAndTypesConstraint(
+ StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
+
/// Returns an argument or attribute with the given name that has been seen
/// within the format.
ConstArgument findSeenArg(StringRef name);
@@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
+ } else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
+ handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
+ def);
} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
// DeclareOpInterfaceMethods<InferTypeOpInterface>
@@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
variableTyResolver[rhsName] = {arg, transformer};
}
+void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
+ StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
+ StringRef shapedArg = def.getValueAsString("shaped");
+ StringRef elementsArg = def.getValueAsString("elements");
+
+ // Check if the 'shaped' argument is seen, then we can infer the 'elements'
+ // types.
+ if (ConstArgument arg = findSeenArg(shapedArg)) {
+ variableTyResolver[elementsArg] = {
+ arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
+ "ShapedType>($_self).getNumElements(), "
+ "::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
+ }
+
+ // Type inference in the opposite direction is not possible as the actual
+ // shaped type can't be inferred from the variadic elements.
+}
+
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;