//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/AttributeSupport.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/StorageUniquerSupport.h" #include "mlir/Support/LogicalResult.h" using namespace mlir; //===----------------------------------------------------------------------===// // Dynamic types and attributes shared functions //===----------------------------------------------------------------------===// /// Default parser for dynamic attribute or type parameters. /// Parse in the format '(<>)?' or ''. static LogicalResult typeOrAttrParser(AsmParser &parser, SmallVectorImpl &parsedParams) { // No parameters if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) return success(); Attribute attr; if (parser.parseAttribute(attr)) return failure(); parsedParams.push_back(attr); while (parser.parseOptionalGreater()) { Attribute attr; if (parser.parseComma() || parser.parseAttribute(attr)) return failure(); parsedParams.push_back(attr); } return success(); } /// Default printer for dynamic attribute or type parameters. /// Print in the format '(<>)?' or ''. static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef params) { if (params.empty()) return; printer << "<"; interleaveComma(params, printer.getStream()); printer << ">"; } //===----------------------------------------------------------------------===// // Dynamic type //===----------------------------------------------------------------------===// std::unique_ptr DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier) { return DynamicTypeDefinition::get(name, dialect, std::move(verifier), typeOrAttrParser, typeOrAttrPrinter); } std::unique_ptr DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier, ParserFn &&parser, PrinterFn &&printer) { return std::unique_ptr( new DynamicTypeDefinition(name, dialect, std::move(verifier), std::move(parser), std::move(printer))); } DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef, ExtensibleDialect *dialect, VerifierFn &&verifier, ParserFn &&parser, PrinterFn &&printer) : name(nameRef), dialect(dialect), verifier(std::move(verifier)), parser(std::move(parser)), printer(std::move(printer)), ctx(dialect->getContext()) {} DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect, StringRef nameRef) : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {} void DynamicTypeDefinition::registerInTypeUniquer() { detail::TypeUniquer::registerType(&getContext(), getTypeID()); } namespace mlir { namespace detail { /// Storage of DynamicType. /// Contains a pointer to the type definition and type parameters. struct DynamicTypeStorage : public TypeStorage { using KeyTy = std::pair>; explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, ArrayRef params) : typeDef(typeDef), params(params) {} bool operator==(const KeyTy &key) const { return typeDef == key.first && params == key.second; } static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, const KeyTy &key) { return new (alloc.allocate()) DynamicTypeStorage(key.first, alloc.copyInto(key.second)); } /// Definition of the type. DynamicTypeDefinition *typeDef; /// The type parameters. ArrayRef params; }; } // namespace detail } // namespace mlir DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, ArrayRef params) { auto &ctx = typeDef->getContext(); auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx); assert(succeeded(typeDef->verify(emitError, params))); return detail::TypeUniquer::getWithTypeID( &ctx, typeDef->getTypeID(), typeDef, params); } DynamicType DynamicType::getChecked(function_ref emitError, DynamicTypeDefinition *typeDef, ArrayRef params) { if (failed(typeDef->verify(emitError, params))) return {}; auto &ctx = typeDef->getContext(); return detail::TypeUniquer::getWithTypeID( &ctx, typeDef->getTypeID(), typeDef, params); } DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } ArrayRef DynamicType::getParams() { return getImpl()->params; } bool DynamicType::classof(Type type) { return type.hasTrait(); } ParseResult DynamicType::parse(AsmParser &parser, DynamicTypeDefinition *typeDef, DynamicType &parsedType) { SmallVector params; if (failed(typeDef->parser(parser, params))) return failure(); parsedType = parser.getChecked(typeDef, params); if (!parsedType) return failure(); return success(); } void DynamicType::print(AsmPrinter &printer) { printer << getTypeDef()->getName(); getTypeDef()->printer(printer, getParams()); } //===----------------------------------------------------------------------===// // Dynamic attribute //===----------------------------------------------------------------------===// std::unique_ptr DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier) { return DynamicAttrDefinition::get(name, dialect, std::move(verifier), typeOrAttrParser, typeOrAttrPrinter); } std::unique_ptr DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier, ParserFn &&parser, PrinterFn &&printer) { return std::unique_ptr( new DynamicAttrDefinition(name, dialect, std::move(verifier), std::move(parser), std::move(printer))); } DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef, ExtensibleDialect *dialect, VerifierFn &&verifier, ParserFn &&parser, PrinterFn &&printer) : name(nameRef), dialect(dialect), verifier(std::move(verifier)), parser(std::move(parser)), printer(std::move(printer)), ctx(dialect->getContext()) {} DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect, StringRef nameRef) : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {} void DynamicAttrDefinition::registerInAttrUniquer() { detail::AttributeUniquer::registerAttribute(&getContext(), getTypeID()); } namespace mlir { namespace detail { /// Storage of DynamicAttr. /// Contains a pointer to the attribute definition and attribute parameters. struct DynamicAttrStorage : public AttributeStorage { using KeyTy = std::pair>; explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef, ArrayRef params) : attrDef(attrDef), params(params) {} bool operator==(const KeyTy &key) const { return attrDef == key.first && params == key.second; } static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc, const KeyTy &key) { return new (alloc.allocate()) DynamicAttrStorage(key.first, alloc.copyInto(key.second)); } /// Definition of the type. DynamicAttrDefinition *attrDef; /// The type parameters. ArrayRef params; }; } // namespace detail } // namespace mlir DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef, ArrayRef params) { auto &ctx = attrDef->getContext(); return detail::AttributeUniquer::getWithTypeID( &ctx, attrDef->getTypeID(), attrDef, params); } DynamicAttr DynamicAttr::getChecked(function_ref emitError, DynamicAttrDefinition *attrDef, ArrayRef params) { if (failed(attrDef->verify(emitError, params))) return {}; return get(attrDef, params); } DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; } ArrayRef DynamicAttr::getParams() { return getImpl()->params; } bool DynamicAttr::classof(Attribute attr) { return attr.hasTrait(); } ParseResult DynamicAttr::parse(AsmParser &parser, DynamicAttrDefinition *attrDef, DynamicAttr &parsedAttr) { SmallVector params; if (failed(attrDef->parser(parser, params))) return failure(); parsedAttr = parser.getChecked(attrDef, params); if (!parsedAttr) return failure(); return success(); } void DynamicAttr::print(AsmPrinter &printer) { printer << getAttrDef()->getName(); getAttrDef()->printer(printer, getParams()); } //===----------------------------------------------------------------------===// // Dynamic operation //===----------------------------------------------------------------------===// DynamicOpDefinition::DynamicOpDefinition( StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, OperationName::ParseAssemblyFn &&parseFn, OperationName::PrintAssemblyFn &&printFn, OperationName::FoldHookFn &&foldHookFn, OperationName::GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn, OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) : typeID(dialect->allocateTypeID()), name((dialect->getNamespace() + "." + name).str()), dialect(dialect), verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)), parseFn(std::move(parseFn)), printFn(std::move(printFn)), foldHookFn(std::move(foldHookFn)), getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)), populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {} std::unique_ptr DynamicOpDefinition::get( StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) { auto parseFn = [](OpAsmParser &parser, OperationState &result) { return parser.emitError( parser.getCurrentLocation(), "dynamic operation do not define any parser function"); }; auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) { printer.printGenericOp(op); }; return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), std::move(verifyRegionFn), std::move(parseFn), std::move(printFn)); } std::unique_ptr DynamicOpDefinition::get( StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, OperationName::ParseAssemblyFn &&parseFn, OperationName::PrintAssemblyFn &&printFn) { auto foldHookFn = [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return failure(); }; auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) { }; auto populateDefaultAttrsFn = [](const RegisteredOperationName &, NamedAttrList &) {}; return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), std::move(verifyRegionFn), std::move(parseFn), std::move(printFn), std::move(foldHookFn), std::move(getCanonicalizationPatternsFn), std::move(populateDefaultAttrsFn)); } std::unique_ptr DynamicOpDefinition::get( StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyInvariantsFn &&verifyRegionFn, OperationName::ParseAssemblyFn &&parseFn, OperationName::PrintAssemblyFn &&printFn, OperationName::FoldHookFn &&foldHookFn, OperationName::GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn, OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) { return std::unique_ptr(new DynamicOpDefinition( name, dialect, std::move(verifyFn), std::move(verifyRegionFn), std::move(parseFn), std::move(printFn), std::move(foldHookFn), std::move(getCanonicalizationPatternsFn), std::move(populateDefaultAttrsFn))); } //===----------------------------------------------------------------------===// // Extensible dialect //===----------------------------------------------------------------------===// namespace { /// Interface that can only be implemented by extensible dialects. /// The interface is used to check if a dialect is extensible or not. class IsExtensibleDialect : public DialectInterface::Base { public: IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect) }; } // namespace ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID) : Dialect(name, ctx, typeID) { addInterfaces(); } void ExtensibleDialect::registerDynamicType( std::unique_ptr &&type) { DynamicTypeDefinition *typePtr = type.get(); TypeID typeID = type->getTypeID(); StringRef name = type->getName(); ExtensibleDialect *dialect = type->getDialect(); assert(dialect == this && "trying to register a dynamic type in the wrong dialect"); // If a type with the same name is already defined, fail. auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; (void)registered; assert(registered && "type TypeID was not unique"); registered = nameToDynTypes.insert({name, typePtr}).second; (void)registered; assert(registered && "Trying to create a new dynamic type with an existing name"); auto abstractType = AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(), typeID); /// Add the type to the dialect and the type uniquer. addType(typeID, std::move(abstractType)); typePtr->registerInTypeUniquer(); } void ExtensibleDialect::registerDynamicAttr( std::unique_ptr &&attr) { auto *attrPtr = attr.get(); auto typeID = attr->getTypeID(); auto name = attr->getName(); auto *dialect = attr->getDialect(); assert(dialect == this && "trying to register a dynamic attribute in the wrong dialect"); // If an attribute with the same name is already defined, fail. auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second; (void)registered; assert(registered && "attribute TypeID was not unique"); registered = nameToDynAttrs.insert({name, attrPtr}).second; (void)registered; assert(registered && "Trying to create a new dynamic attribute with an existing name"); auto abstractAttr = AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(), typeID); /// Add the type to the dialect and the type uniquer. addAttribute(typeID, std::move(abstractAttr)); attrPtr->registerInAttrUniquer(); } void ExtensibleDialect::registerDynamicOp( std::unique_ptr &&op) { assert(op->dialect == this && "trying to register a dynamic op in the wrong dialect"); auto hasTraitFn = [](TypeID traitId) { return false; }; RegisteredOperationName::insert( op->name, *op->dialect, op->typeID, std::move(op->parseFn), std::move(op->printFn), std::move(op->verifyFn), std::move(op->verifyRegionFn), std::move(op->foldHookFn), std::move(op->getCanonicalizationPatternsFn), detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}, std::move(op->populateDefaultAttrsFn)); } bool ExtensibleDialect::classof(const Dialect *dialect) { return const_cast(dialect) ->getRegisteredInterface(); } OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( StringRef typeName, AsmParser &parser, Type &resultType) const { DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName); if (!typeDef) return llvm::None; DynamicType dynType; if (DynamicType::parse(parser, typeDef, dynType)) return failure(); resultType = dynType; return success(); } LogicalResult ExtensibleDialect::printIfDynamicType(Type type, AsmPrinter &printer) { if (auto dynType = type.dyn_cast()) { dynType.print(printer); return success(); } return failure(); } OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr( StringRef attrName, AsmParser &parser, Attribute &resultAttr) const { DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName); if (!attrDef) return llvm::None; DynamicAttr dynAttr; if (DynamicAttr::parse(parser, attrDef, dynAttr)) return failure(); resultAttr = dynAttr; return success(); } LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, AsmPrinter &printer) { if (auto dynAttr = attribute.dyn_cast()) { dynAttr.print(printer); return success(); } return failure(); } //===----------------------------------------------------------------------===// // Dynamic dialect //===----------------------------------------------------------------------===// namespace { /// Interface that can only be implemented by extensible dialects. /// The interface is used to check if a dialect is extensible or not. class IsDynamicDialect : public DialectInterface::Base { public: IsDynamicDialect(Dialect *dialect) : Base(dialect) {} MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect) }; } // namespace DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx) : SelfOwningTypeID(), ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) { addInterfaces(); } bool DynamicDialect::classof(const Dialect *dialect) { return const_cast(dialect) ->getRegisteredInterface(); } Type DynamicDialect::parseType(DialectAsmParser &parser) const { auto loc = parser.getCurrentLocation(); StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); { Type dynType; auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); if (parseResult.has_value()) { if (succeeded(parseResult.value())) return dynType; return Type(); } } parser.emitError(loc, "expected dynamic type"); return Type(); } void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const { auto wasDynamic = printIfDynamicType(type, printer); (void)wasDynamic; assert(succeeded(wasDynamic) && "non-dynamic type defined in dynamic dialect"); } Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser, Type type) const { auto loc = parser.getCurrentLocation(); StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Attribute(); { Attribute dynAttr; auto parseResult = parseOptionalDynamicAttr(typeTag, parser, dynAttr); if (parseResult.has_value()) { if (succeeded(parseResult.value())) return dynAttr; return Attribute(); } } parser.emitError(loc, "expected dynamic attribute"); return Attribute(); } void DynamicDialect::printAttribute(Attribute attr, DialectAsmPrinter &printer) const { auto wasDynamic = printIfDynamicAttr(attr, printer); (void)wasDynamic; assert(succeeded(wasDynamic) && "non-dynamic attribute defined in dynamic dialect"); }