diff --git a/forte/common/constants.py b/forte/common/constants.py index db28078ff..1032f8310 100644 --- a/forte/common/constants.py +++ b/forte/common/constants.py @@ -36,7 +36,15 @@ # Name of the key to access the attribute dict of an entry type from # ``_type_attributes`` of ``DataStore``. -TYPE_ATTR_KEY = "attributes" +ATTR_INFO_KEY = "attributes" + +# Name of the key to access the type of an attribute from +# ``_type_attributes`` of ``DataStore``. +ATTR_TYPE_KEY = "type" + +# Name of the key to access the index of an attribute from +# ``_type_attributes`` of ``DataStore``. +ATTR_INDEX_KEY = "index" # Name of the key to access a set of parent names of an entry type from # ``_type_attributes`` of ``DataStore``. diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index a519ab7fd..182231d42 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -32,7 +32,8 @@ Iterable, ) from functools import partial -from typing_inspect import get_origin +from inspect import isclass +from typing_inspect import is_forward_ref from packaging.version import Version import jsonpickle @@ -47,6 +48,7 @@ LinkType, FList, FDict, + ENTRY_TYPE_DATA_STRUCTURES, ) from forte.version import ( PACK_VERSION, @@ -455,32 +457,71 @@ def on_entry_creation( # Use the auto-inferred control component. c = self.__control_component - def entry_getter(cls: Entry, attr_name: str, field_type): + def entry_getter(cls: Entry, attr_name: str): """A getter function for dataclass fields of entry object. - When the field contains ``tid``s, we will convert them to entry - object on the fly. + Depending on the value stored in the data store and the type + of the attribute, the method decides how to process the value. + + - Attributes repersented as ``FList`` and ``FDict`` objects are stored + as list and dictionary respectively in the dtaa store entry. These + values are converted to ``FList`` and ``FDict`` objects on the fly. + - When the field contains ``tid``s, we will convert them to entry + object on the fly. This is done by checking the type + information of the attribute in the entry object. If the + attribute is of type ``Entry`` or a ``ForwardRef``, we can + assume that that value stored in the data store entry represents + the entry's ``tid``. + - When values are stored as a tuple, we assume the value represents + a `subentry` stored in a `MultiPack`. + - In all other cases, the values are returned in the forms that they + are stored in the data store entry. Args: cls: An ``Entry`` class object. attr_name: The name of the attribute. - field_type: The type of the attribute. + + Returns: + The value of the required attribute in the form specified + by the corresponding ``Entry`` class object. """ + data_store_ref = ( cls.pack._data_store # pylint: disable=protected-access ) attr_val = data_store_ref.get_attribute( tid=cls.tid, attr_name=attr_name ) - if field_type in (FList, FDict): + attr_type = data_store_ref.get_attr_type( + cls.entry_type(), attr_name + ) + + if attr_type[0] in ENTRY_TYPE_DATA_STRUCTURES: # Generate FList/FDict object on the fly - return field_type(parent_entry=cls, data=attr_val) + return attr_type[0](parent_entry=cls, data=attr_val) try: - # TODO: Find a better solution to determine if a field is Entry - # Will be addressed by https://github.com/asyml/forte/issues/835 - # Convert tid to entry object on the fly - if isinstance(attr_val, int): - # Single pack entry + # Check dataclass attribute value type + # If the attribute was an Entry object, only its tid + # is stored in the DataStore and hence its needs to be converted. + + # Entry objects are stored in data stores by their tid (which is + # of type int). Thus, if we enounter an int value, we check the + # type information which is stored as a tuple. if any entry in this + # tuple is a subclass of Entry or is a ForwardRef to another entry, + # we can infer that this int value represents the tid of an Entry + # object and thus must be converted to an object using get_entry + # before returning. + if ( + isinstance(attr_val, int) + and attr_type[1] + and any( + issubclass(entry, Entry) + if isclass(entry) + else is_forward_ref(entry) + for entry in list(attr_type[1]) + ) + ): return cls.pack.get_entry(tid=attr_val) + # The condition below is to check whether the attribute's value # is a pair of integers - `(pack_id, tid)`. If so we may have # encountered a `tid` that can only be resolved by @@ -497,7 +538,7 @@ def entry_getter(cls: Entry, attr_name: str, field_type): pass return attr_val - def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): + def entry_setter(cls: Entry, value: Any, attr_name: str): """A setter function for dataclass fields of entry object. When the value contains entry objects, we will convert them into ``tid``s before storing to ``DataStore``. @@ -506,16 +547,19 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): cls: An ``Entry`` class object. value: The value to be assigned to the attribute. attr_name: The name of the attribute. - field_type: The type of the attribute. """ attr_value: Any data_store_ref = ( cls.pack._data_store # pylint: disable=protected-access ) + + attr_type = data_store_ref.get_attr_type( + cls.entry_type(), attr_name + ) # Assumption: Users will not assign value to a FList/FDict field. # Only internal methods can set the FList/FDict field, and value's # type has to be Iterator[Entry]/Dict[Any, Entry]. - if field_type is FList: + if attr_type[0] is FList: try: attr_value = [entry.tid for entry in value] except AttributeError as e: @@ -523,7 +567,7 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): "You are trying to assign value to a `FList` field, " "which can only accept an iterator of `Entry` objects." ) from e - elif field_type is FDict: + elif attr_type[0] is FDict: try: attr_value = { key: entry.tid for key, entry in value.items() @@ -554,10 +598,9 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): self._save_entry_to_data_store(entry=entry) # Register property functions for all dataclass fields. - for name, field in entry.__dataclass_fields__.items(): + for name in entry.__dataclass_fields__: # Convert the typing annotation to the original class. # This will be used to determine if a field is FList/FDict. - field_type = get_origin(field.type) setattr( type(entry), name, @@ -566,12 +609,8 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): property( # We need to bound the attribute name and field type here # for the getter and setter of each field. - fget=partial( - entry_getter, attr_name=name, field_type=field_type - ), - fset=partial( - entry_setter, attr_name=name, field_type=field_type - ), + fget=partial(entry_getter, attr_name=name), + fset=partial(entry_setter, attr_name=name), ), ) diff --git a/forte/data/data_store.py b/forte/data/data_store.py index 2a1f56aaf..1f959c774 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -11,14 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy import json -from typing import Dict, List, Iterator, Tuple, Optional, Any, Type +import sys +from typing import Dict, List, Iterator, Set, Tuple, Optional, Any, Type import uuid import logging from heapq import heappush, heappop from sortedcontainers import SortedList -from typing_inspect import get_origin +from typing_inspect import get_origin, get_args, is_generic_type from forte.utils import get_class from forte.utils.utils import get_full_module_name @@ -37,7 +39,12 @@ MultiPackGroup, MultiPackLink, ) -from forte.data.ontology.core import Entry, FList, FDict +from forte.data.ontology.core import ( + Entry, + FList, + FDict, + ENTRY_TYPE_DATA_STRUCTURES, +) from forte.common import constants @@ -157,12 +164,22 @@ def __init__( """ The ``_type_attributes`` is a private dictionary that provides - ``type_name``, their parent entry, and the order of corresponding attributes. - The keys are fully qualified names of every type; The value is a - dictionary with two keys. Key ``attribute`` provides an inner dictionary - with all valid attributes for this type and the indices of attributes - among these lists. Key ``parent_class`` is a string representing the - ancestors of this type. + ``type_name`` as the key, and the metadata of the entry represented by + ``type_name``. This metadata includes the order and type information of + attributes stored in the data store entry; The value is a dictionary with + two keys. + + 1) Key ``attribute`` has its value as a dictionary + with all valid attributes for this entry type as keys and their metadata. + as values. The metadata is represented as another inner dictionary + that has two keys: + a) the first key is ``index`` whose value determines the position + of where the attribute is stored in the data store entry. + b) The second key is type, which is a tuple of two elements that provides the + type information of a given attribute. + i) The first element is the `unsubscripted` version of the attribute's type + ii) the second element is the type arguments for the same. + 2) Key ``parent_class`` is a string representing the ancestors of this type. This structure is supposed to be built dynamically. When a user adds new entries, `DataStore` will check unknown types and add them to @@ -172,23 +189,33 @@ def __init__( .. code-block:: python - # DataStore._type_attributes is: - # { - # "ft.onto.base_ontology.Token": { - # "attributes": {"pos": 4, "ud_xpos": 5, - # "lemma": 6, "chunk": 7, "ner": 8, "sense": 9, - # "is_root": 10, "ud_features": 11, "ud_misc": 12}, - # "parent_class": set("forte.data.ontology.top.Annotation"), }, - # "ft.onto.base_ontology.Document": { - # "attributes": {"document_class": 4, - # "sentiment": 5, "classifications": 6}, - # "parent_class": set("forte.data.ontology.top.Annotation"), }, - # "ft.onto.base_ontology.Sentence": { - # "attributes": {"speaker": 4, - # "part_id": 5, "sentiment": 6, - # "classification": 7, "classifications": 8}, - # "parent_class": set(), } - # } + DataStore._type_attributes is: + { + "ft.onto.base_ontology.Document": { + "attributes": { + "document_class": {"index": 4, "type": (list, (str,))}, + "sentiment": {"index": 5, "type": (dict, (str, float))}, + "classifications": { + "index": 6, + "type":(FDict,(str, Classification)) + } + }, + "parent_class": set(), + }, + "ft.onto.base_ontology.Sentence": { + "attributes": { + "speaker": {"index": 4, "type": (Union, (str, type(None)))}, + "part_id": {"index": 5, "type": (Union, (int, type(None)))}, + "sentiment": {"index": 6, "type": (dict, (str, float))}, + "classification": {"index": 7, "type": (dict, (str, float))}, + "classifications": { + "index": 8, + "type": (FDict,(str, Classification)) + }, + }, + "parent_class": set(), + }, + } """ self._init_top_to_core_entries() if self._onto_file_path: @@ -258,18 +285,28 @@ def __getstate__(self): """ state = super().__getstate__() state["_DataStore__elements"] = {} + + # Make a copy of the updated type_attributes + type_attributes = deepcopy(DataStore._type_attributes) + state["fields"] = DataStore._type_attributes + for k in self.__elements: # build the full `_type_attributes` self._get_type_info(k) state["_DataStore__elements"][k] = list(self.__elements[k]) + state.pop("_DataStore__tid_ref_dict") state.pop("_DataStore__tid_idx_dict") state.pop("_DataStore__deletion_count") state["entries"] = state.pop("_DataStore__elements") - state["fields"] = self._type_attributes for _, v in state["fields"].items(): if constants.PARENT_CLASS_KEY in v: v.pop(constants.PARENT_CLASS_KEY) + + if constants.ATTR_INFO_KEY in v: + for _, info in v[constants.ATTR_INFO_KEY].items(): + info.pop(constants.ATTR_TYPE_KEY) + DataStore._type_attributes = type_attributes return state def __setstate__(self, state): @@ -286,6 +323,27 @@ def __setstate__(self, state): self._DataStore__tid_idx_dict = {} self._DataStore__deletion_count = {} + # Update `_type_attributes` to store the types of each + # entry attribute as well. + for tn in self._type_attributes: + entry_type = self.fetch_entry_type_data(tn) + for attr, type_val in entry_type.items(): + try: + info_dict = self._type_attributes[tn][ + constants.ATTR_INFO_KEY + ][attr] + + # If in case there is an attribute of entry + # referenced by tn which is defined in the + # _type_attributes dict of DataStore but not + # in the serialized data of _type_attributes, + # we dont need to add type information for + # that attribute. + except KeyError: + continue + if constants.ATTR_TYPE_KEY not in info_dict: + info_dict[constants.ATTR_TYPE_KEY] = type_val + reset_index = {} for k in self.__elements: if self._is_annotation(k): @@ -394,22 +452,32 @@ def check_fields(store): # If a field only occurs in the serialized object but not in # the current class, it will not be detected. # Instead, it will be dropped later. - diff = set(v[constants.TYPE_ATTR_KEY].items()) - set( - store._type_attributes[t][constants.TYPE_ATTR_KEY].items() + + # This lambda function is used to get a temporary + # representation of type_attributes with only the + # name and index + get_temp_rep = lambda entry: set( + (attr, val[constants.ATTR_INDEX_KEY]) + for attr, val in entry[constants.ATTR_INFO_KEY].items() ) + + temp_cls_rep = get_temp_rep(v) + temp_obj_rep = get_temp_rep(store._type_attributes[t]) + + diff = temp_cls_rep - temp_obj_rep for f in diff: # if fields appear in both the current class and the # serialized objects but have different orders, switch # fields to match the order of the current class. if ( f[0] - in store._type_attributes[t][constants.TYPE_ATTR_KEY] + in store._type_attributes[t][constants.ATTR_INFO_KEY] ): # record indices of the same field in the class and # objects. Save different indices to a dictionary. change_map[f[1]] = store._type_attributes[t][ - constants.TYPE_ATTR_KEY - ][f[0]] + constants.ATTR_INFO_KEY + ][f[0]][constants.ATTR_INDEX_KEY] # record indices of fields that only appear in the # current class. We want to fill them with None. else: @@ -432,7 +500,13 @@ def check_fields(store): # throw fields that are redundant/only appear in # the serialized object for i in range( - max(v[constants.TYPE_ATTR_KEY].values()) + 1 + max( + info[constants.ATTR_INDEX_KEY] + for info in v[ + constants.ATTR_INFO_KEY + ].values() + ) + + 1 ) ] if len(contradict_loc) > 0: @@ -491,7 +565,8 @@ def _get_type_info(self, type_name: str) -> Dict[str, Any]: ``DataStore._type_attributes``. If the ``type_name`` does not currently exists and dynamic import is enabled, this function will add a new key-value pair into ``DataStore._type_attributes``. The value consists - of a full attribute-to-index dictionary and an empty parent set. + a dictionary which stores the name and the type information of every + attribute of the entry and an empty parent set. This function returns a dictionary containing an attribute dict and a set of parent entries of the given type. For example: @@ -522,7 +597,7 @@ def _get_type_info(self, type_name: str) -> Dict[str, Any]: # check if type is in dictionary if ( type_name in DataStore._type_attributes - and constants.TYPE_ATTR_KEY in DataStore._type_attributes[type_name] + and constants.ATTR_INFO_KEY in DataStore._type_attributes[type_name] ): return DataStore._type_attributes[type_name] if not self._dynamically_add_type: @@ -537,18 +612,24 @@ def _get_type_info(self, type_name: str) -> Dict[str, Any]: attr_dict = {} attr_idx = constants.ENTRY_TYPE_INDEX + 1 + type_dict = self.fetch_entry_type_data(type_name) + for attr_name in attributes: - attr_dict[attr_name] = attr_idx + attr_dict[attr_name] = { + constants.ATTR_TYPE_KEY: type_dict[attr_name], + constants.ATTR_INDEX_KEY: attr_idx, + } attr_idx += 1 new_entry_info = { - constants.TYPE_ATTR_KEY: attr_dict, + constants.ATTR_INFO_KEY: attr_dict, constants.PARENT_CLASS_KEY: set(), } DataStore._type_attributes[type_name] = new_entry_info + return new_entry_info - def _get_type_attribute_dict(self, type_name: str) -> Dict[str, int]: + def _get_type_attribute_dict(self, type_name: str) -> Dict[str, Dict]: """Get the attribute dict of an entry type. The attribute dict maps attribute names to a list of consecutive integers as indices. For example: .. code-block:: python @@ -566,7 +647,7 @@ def _get_type_attribute_dict(self, type_name: str) -> Dict[str, int]: Returns: attr_dict (dict): The attribute-to-index dictionary of an entry. """ - return self._get_type_info(type_name)[constants.TYPE_ATTR_KEY] + return self._get_type_info(type_name)[constants.ATTR_INFO_KEY] def _get_type_parent(self, type_name: str) -> str: """Get a set of parent names of an entry type. The set is a subset of all @@ -594,9 +675,11 @@ def _default_attributes_for_type(self, type_name: str) -> List: attr_dict: Dict = self._get_type_attribute_dict(type_name) attr_fields: Dict = self._get_entry_attributes_by_class(type_name) attr_list: List = [None] * len(attr_dict) - for attr_name, attr_id in attr_dict.items(): + for attr_name, attr_info in attr_dict.items(): # TODO: We should keep a record of the attribute class instead of # inspecting the class on the fly. + attr_id = attr_info[constants.ATTR_INDEX_KEY] + attr_class = get_origin(attr_fields[attr_name].type) if attr_class in (FList, list, List): attr_list[attr_id - constants.ATTR_BEGIN_INDEX] = [] @@ -604,6 +687,106 @@ def _default_attributes_for_type(self, type_name: str) -> List: attr_list[attr_id - constants.ATTR_BEGIN_INDEX] = {} return attr_list + def fetch_entry_type_data( + self, type_name: str, attributes: Optional[Set[Tuple[str, str]]] = None + ) -> Dict[str, Tuple]: + r"""This function takes a fully qualified ``type_name`` class name + and a set of tuples representing an attribute and its required type + (only in the case where the ``type_name`` class name represents an + entry being added from a user defined ontology) and creates a + dictionary where the key is attribute of the entry and value is + the type information of that attribute. + + There are two cases in which a fully qualified ``type_name`` class + name can be handled: + + 1) If the class being added is of an existing entry: This means + that there is information present about this entry through + its `dataclass` attributes and their respective types. Thus, + we use the `_get_entry_attributes_by_class` method to fetch + this information. + 2) If the class being added is of a user defined entry: In this + case, we fetch the information about the entry's attributes + and their types from the ``attributes`` argument. + + Args: + type_name: A fully qualified name of an entry class. + attributes: This argument is used when parsing ontology + files. The entries in the set are a tuples of two + elements. + + .. code-block:: python + + attributes = { + ('passage_id', 'str'), + ('author', 'str') + } + Returns: A dictionary representing attributes as key and type + information as value. For each attribute, the type information is + represented by a tuple of two elements. The first element is the + `unsubscripted` version of the attribute's type and the second + element is the type arguments for the same. The `type_dict` is used + to populate the type information for attributes of an entry + specified by ``type_name`` in `_type_attributes`. For example, + + .. code-block:: python + + type_dict = { + "document_class": (list, (str,)), + "sentiment": (dict, (str, float)), + "classifications": (FDict, (str, Classification)) + } + """ + type_dict = {} + attr_class: Any + attr_args: Tuple + + if attributes: + for attr, type_val in attributes: + # the type_dict only stores the type of each + # attribute class. When attributes and their + # types are defined in ontology files, these + # values are stored in attr_args. attr_class + # is empty in this case and has a value of + # None. But to maintain the consistency of + # type_dict, we only store the type of every + # value, even None. + attr_class = type(None) + attr_args = tuple([get_class(type_val)]) + type_dict[attr] = tuple([attr_class, attr_args]) + + else: + attr_fields: Dict = self._get_entry_attributes_by_class(type_name) + for attr_name, attr_info in attr_fields.items(): + + attr_class = get_origin(attr_info.type) + attr_args = get_args(attr_info.type) + + # Prior to Python 3.7, fetching generic type + # aliases resulted in actual type objects whereas from + # Python 3.7, they were converted to their primitive + # form. For example, typing.List and typing.Dict + # is converted to primitive forms of list and + # dict. We handle them separately here + if ( + is_generic_type(attr_info.type) + and hasattr(attr_info.type, "__extra__") + and sys.version_info[:3] < (3, 7, 0) + and attr_class not in ENTRY_TYPE_DATA_STRUCTURES + ): + # if python version is < 3.7, thr primitive form + # of generic types are stored in the __extra__ + # attribute. This attribute is not present in + # generic types from 3.7. + try: + attr_class = attr_info.type.__extra__ + except AttributeError: + pass + + type_dict[attr_name] = tuple([attr_class, attr_args]) + + return type_dict + def _is_subclass( self, type_name: str, cls, no_dynamic_subclass: bool = False ) -> bool: @@ -690,6 +873,34 @@ def _is_annotation(self, type_name: str) -> bool: for entry_class in (Annotation, AudioAnnotation) ) + def get_attr_type( + self, type_name: str, attr_name: str + ) -> Tuple[Any, Tuple]: + """ + Retrieve the type information of a given attribute ``attr_name`` + in an entry of type ``type_name`` + + Args: + type_name (str): The type name of the entry whose attribute entry + type needs to be fetched + attr_name (str): The name of the attribute in the entry whose type + information needs to be fetched. + + Returns: + The type information of the required attribute. This information is + stored in the ``_type_attributes`` dictionary of the Data Store. + """ + try: + return DataStore._type_attributes[type_name][ + constants.ATTR_INFO_KEY + ][attr_name][constants.ATTR_TYPE_KEY] + except KeyError as e: + raise KeyError( + f"Attribute {attr_name} does not have type " + f"information provided or attribute {attr_name}" + f"is not a valid attribute of entry {type_name}" + ) from e + def all_entries(self, entry_type_name: str) -> Iterator[List]: """ Retrieve all entry data of entry type ``entry_type_name`` and @@ -946,7 +1157,9 @@ def set_attribute(self, tid: int, attr_name: str, attr_value: Any): entry, entry_type = self.get_entry(tid) try: - attr_id = self._get_type_attribute_dict(entry_type)[attr_name] + attr_id = self._get_type_attribute_dict(entry_type)[attr_name][ + constants.ATTR_INDEX_KEY + ] except KeyError as e: raise KeyError(f"{entry_type} has no {attr_name} attribute.") from e @@ -984,7 +1197,9 @@ def get_attribute(self, tid: int, attr_name: str) -> Any: entry, entry_type = self.get_entry(tid) try: - attr_id = self._get_type_attribute_dict(entry_type)[attr_name] + attr_id = self._get_type_attribute_dict(entry_type)[attr_name][ + constants.ATTR_INDEX_KEY + ] except KeyError as e: raise KeyError(f"{entry_type} has no {attr_name} attribute.") from e @@ -1630,6 +1845,12 @@ def _parse_onto_file(self): children = entry_tree.root.children while len(children) > 0: + # entry_node represents a node in the ontology tree + # generated by parsing an existing ontology file. + # The entry_node the information of the entry + # represented by this node. It also stores the name + # and the type information of the attributes of the + # entry represented by this node. entry_node = children.pop(0) children.extend(entry_node.children) @@ -1639,17 +1860,23 @@ def _parse_onto_file(self): attr_dict = {} idx = constants.ATTR_BEGIN_INDEX + type_dict = self.fetch_entry_type_data( + entry_name, entry_node.attributes + ) + # sort the attribute dictionary for d in sorted(entry_node.attributes): - name = d - attr_dict[name] = idx + name = d[0] + attr_dict[name] = { + constants.ATTR_INDEX_KEY: idx, + constants.ATTR_TYPE_KEY: type_dict[name], + } idx += 1 entry_dict = {} entry_dict[constants.PARENT_CLASS_KEY] = set() entry_dict[constants.PARENT_CLASS_KEY].add(entry_node.parent.name) - entry_dict[constants.TYPE_ATTR_KEY] = attr_dict - + entry_dict[constants.ATTR_INFO_KEY] = attr_dict DataStore._type_attributes[entry_name] = entry_dict def _init_top_to_core_entries(self): diff --git a/forte/data/ontology/code_generation_objects.py b/forte/data/ontology/code_generation_objects.py index 2098fbd0a..f8315ab02 100644 --- a/forte/data/ontology/code_generation_objects.py +++ b/forte/data/ontology/code_generation_objects.py @@ -17,7 +17,7 @@ import warnings from abc import ABC from pathlib import Path -from typing import Optional, Any, List, Dict, Set, Tuple +from typing import Optional, Any, List, Dict, Set, Tuple, cast from numpy import ndarray from forte.data.ontology.code_generation_exceptions import ( @@ -797,7 +797,7 @@ def __init__(self, name: str): self.children: List[EntryTreeNode] = [] self.parent: Optional[EntryTreeNode] = None self.name: str = name - self.attributes: Set[str] = set() + self.attributes: Set[Tuple[str, str]] = set() def __repr__(self): r"""for printing purpose.""" @@ -817,7 +817,7 @@ def add_node( self, curr_entry_name: str, parent_entry_name: str, - curr_entry_attr: Set[str], + curr_entry_attr: Set[Tuple[str, str]], ): r"""Add a tree node with `curr_entry_name` as a child to `parent_entry_name` in the tree, the attributes `curr_entry_attr` @@ -856,7 +856,9 @@ def collect_parents(self, node_dict: Dict[str, Set[str]]): Args: node_dict: the nodes dictionary of nodes to collect parent nodes - for. + for. The entry represented by nodes in this dictionary do not store + type information of its attributes. This dictionary does not store + the type information of the nodes. """ input_node_dict = node_dict.copy() @@ -864,9 +866,9 @@ def collect_parents(self, node_dict: Dict[str, Set[str]]): found_node = search(self.root, search_node_name=node_name) if found_node is not None: while found_node.parent.name != "root": - node_dict[ - found_node.parent.name - ] = found_node.parent.attributes + node_dict[found_node.parent.name] = set( + val[0] for val in found_node.parent.attributes + ) found_node = found_node.parent def todict(self) -> Dict[str, Any]: @@ -906,12 +908,18 @@ def fromdict( if parent_entry_name is None: self.root = EntryTreeNode(name=tree_dict["name"]) - self.root.attributes = set(tree_dict["attributes"]) + self.root.attributes = set( + cast(Tuple[str, str], tuple(attr)) + for attr in tree_dict["attributes"] + ) else: self.add_node( curr_entry_name=tree_dict["name"], parent_entry_name=parent_entry_name, - curr_entry_attr=set(tree_dict["attributes"]), + curr_entry_attr=set( + cast(Tuple[str, str], tuple(attr)) + for attr in tree_dict["attributes"] + ), ) for child in tree_dict["children"]: self.fromdict(child, tree_dict["name"]) diff --git a/forte/data/ontology/core.py b/forte/data/ontology/core.py index 63802bdfb..22ecbba96 100644 --- a/forte/data/ontology/core.py +++ b/forte/data/ontology/core.py @@ -864,3 +864,5 @@ def __hash__(self): GroupType = TypeVar("GroupType", bound=BaseGroup) LinkType = TypeVar("LinkType", bound=BaseLink) + +ENTRY_TYPE_DATA_STRUCTURES = (FDict, FList) diff --git a/forte/data/ontology/ontology_code_generator.py b/forte/data/ontology/ontology_code_generator.py index b43125cc9..80d0db784 100644 --- a/forte/data/ontology/ontology_code_generator.py +++ b/forte/data/ontology/ontology_code_generator.py @@ -270,6 +270,7 @@ def __init__( # and their attributes (if any) in order to validate the attribute # types. self.allowed_types_tree: Dict[str, Set] = {} + for type_str in ALL_INBUILT_TYPES: self.allowed_types_tree[type_str] = set() @@ -811,7 +812,8 @@ def parse_schema( module_writer.add_entry(en, entry_item) # Adding entry attributes to the allowed types for validation. - for property_name in properties: + for property in properties: + property_name = property[0] # Check if the name is allowed. if not property_name.isidentifier(): raise InvalidIdentifierException( @@ -819,14 +821,16 @@ def parse_schema( f"python identifier." ) - if property_name in self.allowed_types_tree[en.class_name]: + if property_name in set( + val[0] for val in self.allowed_types_tree[en.class_name] + ): warnings.warn( f"Attribute type for the entry {en.class_name} " f"and the attribute {property_name} already present in " f"the ontology, will be overridden", DuplicatedAttributesWarning, ) - self.allowed_types_tree[en.class_name].add(property_name) + self.allowed_types_tree[en.class_name].add(property) # populate the entry tree based on information if merged_entry_tree is not None: curr_entry_name = en.class_name @@ -967,15 +971,16 @@ def construct_init(self, entry_name: EntryName, base_entry: str): def parse_entry( self, entry_name: EntryName, schema: Dict - ) -> Tuple[EntryDefinition, List[str]]: + ) -> Tuple[EntryDefinition, List[Tuple[str, str]]]: """ Args: entry_name: Object holds various name form of the entry. schema: Dictionary containing specifications for an entry. Returns: extracted entry information: entry package string, entry - filename, entry class entry_name, generated entry code and entry - attribute names. + filename, entry class entry_name, generated entry code and a list + of tuples where each element in the list represents the an attribute + in the entry and its corresponding type. """ this_manager = self.import_managers.get(entry_name.module_name) @@ -1032,16 +1037,21 @@ def parse_entry( property_items, property_names = [], [] for prop_schema in properties: # TODO: add test - prop_name = prop_schema["name"] - if prop_name in RESERVED_ATTRIBUTE_NAMES: + + # the prop attributes will store the properties of each attribute + # of the the entry defined by the ontology. The properties are + # the name of the attribute and its data type. + prop = (prop_schema["name"], prop_schema["type"]) + + if prop_schema["name"] in RESERVED_ATTRIBUTE_NAMES: raise InvalidIdentifierException( - f"The attribute name {prop_name} is reserved and cannot be " + f"The attribute name {prop_schema['name']} is reserved and cannot be " f"used, please consider changed the name. The list of " f"reserved name strings are " f"{RESERVED_ATTRIBUTE_NAMES}" ) - property_names.append(prop_schema["name"]) + property_names.append(prop) property_items.append(self.parse_property(entry_name, prop_schema)) # For special classes that requires a constraint. diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index f2e6f8de7..b19224873 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -1250,6 +1250,12 @@ def __getstate__(self): # Entry store is being integrated into DataStore state = self.__dict__.copy() state["_modality"] = self._modality.name + + if isinstance(state["_cache"], np.ndarray): + state["_cache"] = list(self._cache.tolist()) + if isinstance(state["_embedding"], np.ndarray): + state["_embedding"] = list(self._embedding.tolist()) + return state def __setstate__(self, state): @@ -1261,6 +1267,20 @@ def __setstate__(self, state): self.__dict__.update(state) self._modality = getattr(Modality, state["_modality"]) + # During de-serialization, convert the list back to numpy array. + if "_embedding" in state: + state["_embedding"] = np.array(state["_embedding"]) + else: + state["_embedding"] = np.empty(0) + + # Here we assume that if the payload is not text (in which case + # cache is stored a string), cache will always be stored as a + # numpy array (which is converted to a list during serialization). + # This check can be made more comprehensive when new types of + # payloads are introduced. + if "_cache" in state and isinstance(state["_cache"], list): + state["_cache"] = np.array(state["_cache"]) + SinglePackEntries = ( Link, diff --git a/tests/forte/data/data_store_serialization_test.py b/tests/forte/data/data_store_serialization_test.py index f5048f5c2..5fb33ce72 100644 --- a/tests/forte/data/data_store_serialization_test.py +++ b/tests/forte/data/data_store_serialization_test.py @@ -16,11 +16,14 @@ """ import logging +from typing import Union import unittest import tempfile import os from sortedcontainers import SortedList from forte.data.data_store import DataStore +from forte.data.ontology.core import FDict +from ft.onto.base_ontology import Classification logging.basicConfig(level=logging.DEBUG) @@ -36,19 +39,28 @@ def setUp(self) -> None: DataStore._type_attributes = { "ft.onto.base_ontology.Document": { "attributes": { - "sentiment": 4, - "classifications": 5, + "sentiment": {"index": 4, "type": (dict, (str, float))}, + "classifications": { + "index": 5, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, "ft.onto.base_ontology.Sentence": { "attributes": { - "sentiment": 4, - "speaker": 5, - "part_id": 6, - "classification_test": 7, - "classifications": 8, - "temp": 9, + "sentiment": {"index": 4, "type": (dict, (str, float))}, + "speaker": {"index": 5, "type": (Union, (str, type(None)))}, + "part_id": {"index": 6, "type": (Union, (int, type(None)))}, + "classification_test": { + "index": 7, + "type": (dict, (str, float)), + }, + "classifications": { + "index": 8, + "type": (FDict, (str, Classification)), + }, + "temp": {"index": 9, "type": (Union, (str, type(None)))}, }, "parent_entry": "forte.data.ontology.top.Annotation", }, @@ -101,7 +113,7 @@ def setUp(self) -> None: ], [ 40, - 55, + 55, 7890, "ft.onto.base_ontology.Document", "Very Positive", @@ -236,19 +248,34 @@ def test_save_attribute_pickle(self): DataStore._type_attributes = { "ft.onto.base_ontology.Document": { "attributes": { - "document_class": 4, - "sentiment": 5, - "classifications": 6, + "document_class": {"index": 4, "type": (list, (str,))}, + "sentiment": {"index": 5, "type": (dict, (str, float))}, + "classifications": { + "index": 6, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, "ft.onto.base_ontology.Sentence": { "attributes": { - "speaker": 4, - "part_id": 5, - "sentiment": 6, - "classification": 7, - "classifications": 8, + "speaker": { + "index": 4, + "type": (Union, (str, type(None))), + }, + "part_id": { + "index": 5, + "type": (Union, (int, type(None))), + }, + "sentiment": {"index": 6, "type": (dict, (str, float))}, + "classification": { + "index": 7, + "type": (dict, (str, float)), + }, + "classifications": { + "index": 8, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, @@ -423,14 +450,14 @@ def test_save_attribute_pickle(self): ][3], }, ) - + self.assertEqual( temp._DataStore__tid_idx_dict, { 10123: ["forte.data.ontology.top.Group", 0], 23456: ["forte.data.ontology.top.Group", 1], 88888: ["forte.data.ontology.top.Link", 0], - } + }, ) temp = DataStore.deserialize( @@ -491,7 +518,7 @@ def test_save_attribute_pickle(self): 9, 9999, "ft.onto.base_ontology.Sentence", - "Positive", + "Positive", "teacher", 1, None, @@ -509,6 +536,7 @@ def test_save_attribute_pickle(self): "Class C", "Class D", "abc", + ], [ 60, @@ -534,7 +562,7 @@ def test_save_attribute_pickle(self): "class2", "good", ], - ] + ], ), "forte.data.ontology.top.Group": [ [ @@ -557,7 +585,7 @@ def test_save_attribute_pickle(self): 88888, "forte.data.ontology.top.Link", ], - ], + ], }, ) self.assertEqual( @@ -599,7 +627,7 @@ def test_save_attribute_pickle(self): 10123: ["forte.data.ontology.top.Group", 0], 23456: ["forte.data.ontology.top.Group", 1], 88888: ["forte.data.ontology.top.Link", 0], - } + }, ) # test check_attribute with accept_unknown_attribute = False @@ -626,19 +654,34 @@ def test_fast_pickle(self): DataStore._type_attributes = { "ft.onto.base_ontology.Document": { "attributes": { - "document_class": 4, - "sentiment": 5, - "classifications": 6, + "document_class": {"index": 4, "type": (list, (str,))}, + "sentiment": {"index": 5, "type": (dict, (str, float))}, + "classifications": { + "index": 6, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, "ft.onto.base_ontology.Sentence": { "attributes": { - "speaker": 4, - "part_id": 5, - "sentiment": 6, - "classification": 7, - "classifications": 8, + "speaker": { + "index": 4, + "type": (Union, (str, type(None))), + }, + "part_id": { + "index": 5, + "type": (Union, (int, type(None))), + }, + "sentiment": {"index": 6, "type": (dict, (str, float))}, + "classification": { + "index": 7, + "type": (dict, (str, float)), + }, + "classifications": { + "index": 8, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, @@ -711,19 +754,34 @@ def test_delete_serialize(self): DataStore._type_attributes = { "ft.onto.base_ontology.Document": { "attributes": { - "document_class": 4, - "sentiment": 5, - "classifications": 6, + "document_class": {"index": 4, "type": (list, (str,))}, + "sentiment": {"index": 5, "type": (dict, (str, float))}, + "classifications": { + "index": 6, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, "ft.onto.base_ontology.Sentence": { "attributes": { - "speaker": 4, - "part_id": 5, - "sentiment": 6, - "classification": 7, - "classifications": 8, + "speaker": { + "index": 4, + "type": (Union, (str, type(None))), + }, + "part_id": { + "index": 5, + "type": (Union, (int, type(None))), + }, + "sentiment": {"index": 6, "type": (dict, (str, float))}, + "classification": { + "index": 7, + "type": (dict, (str, float)), + }, + "classifications": { + "index": 8, + "type": (FDict, (str, Classification)), + }, }, "parent_entry": "forte.data.ontology.top.Annotation", }, @@ -828,6 +886,7 @@ def test_delete_serialize(self): "Positive", None, "class2", + ], ], ), @@ -878,13 +937,13 @@ def test_delete_serialize(self): ][3], }, ) - + self.assertEqual( temp._DataStore__tid_idx_dict, { 23456: ["forte.data.ontology.top.Group", 0], 88888: ["forte.data.ontology.top.Link", 0], - } + }, ) diff --git a/tests/forte/data/data_store_test.py b/tests/forte/data/data_store_test.py index 7e155f19a..87cde3f34 100644 --- a/tests/forte/data/data_store_test.py +++ b/tests/forte/data/data_store_test.py @@ -20,7 +20,7 @@ import unittest import copy from sortedcontainers import SortedList -from typing import Optional, Dict +from typing import List, Optional, Dict, Union from dataclasses import dataclass from forte.data.data_store import DataStore from forte.data.ontology.top import ( @@ -35,6 +35,8 @@ ) from forte.data.data_pack import DataPack from forte.common import constants +from forte.data.ontology.core import FDict +from ft.onto.base_ontology import Classification logging.basicConfig(level=logging.DEBUG) @@ -126,19 +128,34 @@ def setUp(self) -> None: self.reference_type_attributes = { "ft.onto.base_ontology.Document": { "attributes": { - "document_class": 4, - "sentiment": 5, - "classifications": 6, + "document_class": {"index": 4, "type": (list, (str,))}, + "sentiment": {"index": 5, "type": (dict, (str, float))}, + "classifications": { + "index": 6, + "type": (FDict, (str, Classification)), + }, }, "parent_class": set(), }, "ft.onto.base_ontology.Sentence": { "attributes": { - "speaker": 4, - "part_id": 5, - "sentiment": 6, - "classification": 7, - "classifications": 8, + "speaker": { + "index": 4, + "type": (Union, (str, type(None))), + }, + "part_id": { + "index": 5, + "type": (Union, (int, type(None))), + }, + "sentiment": {"index": 6, "type": (dict, (str, float))}, + "classification": { + "index": 7, + "type": (dict, (str, float)), + }, + "classifications": { + "index": 8, + "type": (FDict, (str, Classification)), + }, }, "parent_class": set(), }, @@ -168,20 +185,32 @@ def setUp(self) -> None: DataStore._type_attributes["ft.onto.base_ontology.Document"] = { "attributes": { - "document_class": 4, - "sentiment": 5, - "classifications": 6, + "document_class": {"index": 4, "type": (list, (str,))}, + "sentiment": {"index": 5, "type": (dict, (str, float))}, + "classifications": { + "index": 6, + "type": (FDict, (str, Classification)), + }, }, "parent_class": set(), } DataStore._type_attributes["ft.onto.base_ontology.Sentence"] = { "attributes": { - "speaker": 4, - "part_id": 5, - "sentiment": 6, - "classification": 7, - "classifications": 8, + "speaker": { + "index": 4, + "type": (Union, (str, type(None))), + }, + "part_id": { + "index": 5, + "type": (Union, (int, type(None))), + }, + "sentiment": {"index": 6, "type": (dict, (str, float))}, + "classification": {"index": 7, "type": (dict, (str, float))}, + "classifications": { + "index": 8, + "type": (FDict, (str, Classification)), + }, }, "parent_class": set(), } @@ -306,6 +335,7 @@ def test_get_type_info(self): ) empty_data_store._get_type_info("ft.onto.base_ontology.Sentence") self.assertEqual(len(empty_data_store._DataStore__elements), 0) + self.assertEqual( DataStore._type_attributes["ft.onto.base_ontology.Sentence"], self.reference_type_attributes["ft.onto.base_ontology.Sentence"], @@ -1269,21 +1299,21 @@ def test_check_onto_file(self): expected_type_attributes = { "ft.onto.test.Description": { "attributes": { - "author": 4, - "passage_id": 5, + "author": {"index": 4, "type": (type(None), (str,))}, + "passage_id": {"index": 5, "type": (type(None), (str,))}, }, "parent_class": {"forte.data.ontology.top.Annotation"}, }, "ft.onto.test.EntityMention": { "attributes": { - "ner_type": 4, + "ner_type": {"index": 4, "type": (type(None), (str,))}, }, "parent_class": {"forte.data.ontology.top.Annotation"}, }, "ft.onto.test.MedicalEntityMention": { "attributes": { - "umls_entities": 4, - "umls_link": 5, + "umls_entities": {"index": 4, "type": (type(None), (int,))}, + "umls_link": {"index": 5, "type": (type(None), (str,))}, }, "parent_class": {"ft.onto.test.EntityMention"}, },