diff options
Diffstat (limited to 'mlir/tools/mlir-tblgen/OpFormatGen.cpp')
| -rw-r--r-- | mlir/tools/mlir-tblgen/OpFormatGen.cpp | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 0a9d14d6603a..11edf2523f1a 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1365,7 +1365,7 @@ if (attr && ::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) auto &propStorage = prop.{0}; auto {0}AttrName = StringAttr::get(ctx, "{0}"); auto attr = dict.get({0}AttrName); -usedKeys.insert(StringAttr::get(ctx, "{1}")); +usedKeys.insert({0}AttrName); if (attr || /*isRequired=*/{1}) {{ if (!attr) {{ emitError() << "expected key entry for {0} in DictionaryAttr to set " @@ -2787,6 +2787,11 @@ private: void handleTypesMatchConstraint( StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def); + /// Check for inferable type resolution based on + /// `ShapedTypeMatchesElementCountAndTypes` constraint. + void handleShapedTypeMatchesElementCountAndTypesConstraint( + StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def); + /// Returns an argument or attribute with the given name that has been seen /// within the format. ConstArgument findSeenArg(StringRef name); @@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc, handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { handleTypesMatchConstraint(variableTyResolver, def); + } else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) { + handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver, + def); } else if (!op.allResultTypesKnown()) { // This doesn't check the name directly to handle // DeclareOpInterfaceMethods<InferTypeOpInterface> @@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint( variableTyResolver[rhsName] = {arg, transformer}; } +void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint( + StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) { + StringRef shapedArg = def.getValueAsString("shaped"); + StringRef elementsArg = def.getValueAsString("elements"); + + // Check if the 'shaped' argument is seen, then we can infer the 'elements' + // types. + if (ConstArgument arg = findSeenArg(shapedArg)) { + variableTyResolver[elementsArg] = { + arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::" + "ShapedType>($_self).getNumElements(), " + "::llvm::cast<::mlir::ShapedType>($_self).getElementType())"}; + } + + // Type inference in the opposite direction is not possible as the actual + // shaped type can't be inferred from the variadic elements. +} + ConstArgument OpFormatParser::findSeenArg(StringRef name) { if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; |
