diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index ac3743d39..acb059af0 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -39,6 +39,7 @@ from forte.data.container import EntryContainer from forte.data.index import BaseIndex from forte.data.ontology.core import Entry, EntryType, GroupType, LinkType +from forte.version import PACK_VERSION, DEFAULT_PACK_VERSION __all__ = ["BasePack", "BaseMeta", "PackType"] @@ -98,6 +99,7 @@ def __init__(self, pack_name: Optional[str] = None): super().__init__() self.links: List[LinkType] = [] self.groups: List[GroupType] = [] + self.pack_version: str = PACK_VERSION self._meta: BaseMeta = self._init_meta(pack_name) self._index: BaseIndex = BaseIndex() @@ -199,11 +201,18 @@ def _deserialize( with _open(data_source, mode="rb") as f: # type: ignore pack = pickle.load(f) + if not hasattr(pack, "pack_version"): + pack.pack_version = DEFAULT_PACK_VERSION + return pack # type: ignore @classmethod def from_string(cls, data_content: str) -> "BasePack": - return jsonpickle.decode(data_content) + pack = jsonpickle.decode(data_content) + if not hasattr(pack, "pack_version"): + pack.pack_version = DEFAULT_PACK_VERSION + + return pack @abstractmethod def delete_entry(self, entry: EntryType): diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py index e1998f7fc..1e5c99046 100644 --- a/forte/data/multi_pack.py +++ b/forte/data/multi_pack.py @@ -14,10 +14,14 @@ import copy import logging + from pathlib import Path from typing import Dict, List, Set, Union, Iterator, Optional, Type, Any, Tuple +import jsonpickle + from sortedcontainers import SortedList +from packaging.version import Version from forte.common import ProcessExecutionException from forte.data.base_pack import BaseMeta, BasePack @@ -34,6 +38,8 @@ ) from forte.data.types import DataRequest from forte.utils import get_class +from forte.version import DEFAULT_PACK_VERSION, PACK_ID_COMPATIBLE_VERSION + logger = logging.getLogger(__name__) @@ 
-168,7 +174,33 @@ def _validate(self, entry: EntryType) -> bool: # TODO: get_subentry maybe useless def get_subentry(self, pack_idx: int, entry_id: int): - return self.get_pack_at(pack_idx).get_entry(entry_id) + r""" + Get sub_entry from multi pack. This method uses `pack_id` (a unique + identifier assigned to datapack) to get a pack from multi pack, + and then return its sub_entry with entry_id. Note that this differs from + the way of accessing such pack before the PACK_ID_COMPATIBLE_VERSION, + in which the `pack_idx` was used as list index number to access/reference + a pack within the multi pack (and in this case then get the sub_entry). + + Args: + pack_idx (int): The pack_id for the data_pack in the + multi pack. + entry_id (int): the id for the entry from the pack with pack_id + + Returns: + sub-entry of the pack with id = `pack_idx` + + """ + pack_array_index: int = pack_idx # the old way + # the following checks if the pack version is higher than the (backward) + # compatible version in which pack_idx is the pack_id not list index + if Version(self.pack_version) >= Version(PACK_ID_COMPATIBLE_VERSION): + pack_array_index = self.get_pack_index( + pack_idx + ) # the new way: using pack_id instead of array index + + return self._packs[pack_array_index].get_entry(entry_id) + # return self.get_pack_at(pack_idx).get_entry(entry_id) #old version def get_span_text(self, begin: int, end: int): raise ValueError( @@ -469,10 +501,10 @@ def get_pack_index(self, pack_id: int) -> int: """ try: return self._inverse_pack_ref[pack_id] - except KeyError as e: + except KeyError as ke: raise ProcessExecutionException( f"Pack {pack_id} is not in this multi-pack." - ) from e + ) from ke def get_pack(self, name: str) -> DataPack: """ @@ -856,7 +888,32 @@ def deserialize( Returns: An data pack object deserialized from the string. 
""" - return cls._deserialize(data_path, serialize_method, zip_pack) + # pylint: disable=protected-access + mp: MultiPack = cls._deserialize(data_path, serialize_method, zip_pack) + + # (fix 595) change the dictionary's key after deserialization from str back to int + mp._inverse_pack_ref = { + int(k): v for k, v in mp._inverse_pack_ref.items() + } + + return mp + + @classmethod + def from_string(cls, data_content: str): + # pylint: disable=protected-access + # can not use explicit type hint for mp as pylint does not allow type change + # from base_pack to multi_pack which is problematic so use jsonpickle instead + + mp = jsonpickle.decode(data_content) + if not hasattr(mp, "pack_version"): + mp.pack_version = DEFAULT_PACK_VERSION + # (fix 595) change the dictionary's key after deserialization from str back to int + mp._inverse_pack_ref = { # pylint: disable=no-member + int(k): v + for k, v in mp._inverse_pack_ref.items() # pylint: disable=no-member + } + + return mp def _add_entry(self, entry: EntryType) -> EntryType: r"""Force add an :class:`forte.data.ontology.core.Entry` object to the diff --git a/forte/data/ontology/core.py b/forte/data/ontology/core.py index ce1c35d74..9426ba2c7 100644 --- a/forte/data/ontology/core.py +++ b/forte/data/ontology/core.py @@ -16,6 +16,7 @@ representation system. 
""" import uuid + from abc import abstractmethod, ABC from collections.abc import MutableSequence, MutableMapping from dataclasses import dataclass @@ -35,6 +36,8 @@ import numpy as np +from packaging.version import Version + from forte.data.container import ContainerType, BasePointer __all__ = [ @@ -53,6 +56,7 @@ ] from forte.utils import get_full_module_name +from forte.version import DEFAULT_PACK_VERSION, PACK_ID_COMPATIBLE_VERSION default_entry_fields = [ "_Entry__pack", @@ -277,7 +281,9 @@ def as_pointer(self, from_entry: "Entry"): """ if isinstance(from_entry, MultiEntry): return MpPointer( - from_entry.pack.get_pack_index(self.pack_id), self.tid + # bug fix/enhancement 559: change pack index to pack_id for multi-entry/multi-pack + self.pack_id, + self.tid, # from_entry.pack.get_pack_index(self.pack_id) ) elif isinstance(from_entry, Entry): return Pointer(self.tid) @@ -429,7 +435,20 @@ def _resolve_pointer(self, ptr: BasePointer) -> Entry: if isinstance(ptr, Pointer): return self.pack.get_entry(ptr.tid) elif isinstance(ptr, MpPointer): - return self.pack.packs[ptr.pack_index].get_entry(ptr.tid) + # bugfix/new feature 559: in new version pack_index will be using pack_id internally + pack_array_index = ptr.pack_index # old version + pack_version = "" + try: + pack_version = self.pack.pack_version + except AttributeError: + pack_version = DEFAULT_PACK_VERSION # set to default if lacking version attribute + + if Version(pack_version) >= Version(PACK_ID_COMPATIBLE_VERSION): + pack_array_index = self.pack.get_pack_index( + ptr.pack_index + ) # default: new version + + return self.pack.packs[pack_array_index].get_entry(ptr.tid) else: raise TypeError(f"Unknown pointer type {ptr.__class__}") diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index b4aee603f..276e7ffa4 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -431,7 +431,9 @@ def set_parent(self, parent: Entry): f"The parent of {type(self)} should be an " f"instance 
of {self.ParentType}, but get {type(parent)}" ) - self._parent = self.pack.get_pack_index(parent.pack_id), parent.tid + # fix bug/enhancement #559: using pack_id instead of index + # self._parent = self.pack.get_pack_index(parent.pack_id), parent.tid + self._parent = parent.pack_id, parent.tid def set_child(self, child: Entry): r"""This will set the `child` of the current instance with given Entry. @@ -449,7 +451,9 @@ def set_child(self, child: Entry): f"The child of {type(self)} should be an " f"instance of {self.ChildType}, but get {type(child)}" ) - self._child = self.pack.get_pack_index(child.pack_id), child.tid + # fix bug/enhancement #559: using pack_id instead of index + # self._child = self.pack.get_pack_index(child.pack_id), child.tid + self._child = child.pack_id, child.tid def get_parent(self) -> Entry: r"""Get the parent entry of the link. @@ -499,7 +503,8 @@ def add_member(self, member: Entry): ) self._members.append( - (self.pack.get_pack_index(member.pack_id), member.tid) + # fix bug/enhancement 559: use pack_id instead of index + (member.pack_id, member.tid) # self.pack.get_pack_index(..) 
) def get_members(self) -> List[Entry]: diff --git a/forte/version.py b/forte/version.py index 7bea629cf..a06dccb73 100644 --- a/forte/version.py +++ b/forte/version.py @@ -19,3 +19,6 @@ VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION) FORTE_IR_VERSION = "0.0.1" +PACK_VERSION = "0.0.1" +DEFAULT_PACK_VERSION = "0.0.0" +PACK_ID_COMPATIBLE_VERSION = "0.0.1" diff --git a/requirements.txt b/requirements.txt index 0567dc95f..486793788 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ typing-inspect>=0.6.0 dataclasses~=0.7; python_version <'3.7' importlib-resources==5.1.4;python_version<'3.7' fastapi==0.65.2 -uvicorn==0.14.0 \ No newline at end of file +uvicorn==0.14.0 +packaging~=21.2 diff --git a/tests/forte/data/entry_data_structures_test.py b/tests/forte/data/entry_data_structures_test.py index 8e5e22fbe..aa68787c4 100644 --- a/tests/forte/data/entry_data_structures_test.py +++ b/tests/forte/data/entry_data_structures_test.py @@ -178,6 +178,54 @@ def test_entry_attribute_mp_pointer(self): self.assertEqual(re_mpe.refer_entry.tid, mpe.refer_entry.tid) self.assertEqual(re_mpe.tid, mpe.tid) + def test_mp_pointer_with_version(self): + old_serialized_mp = """{"py/object": "forte.data.multi_pack.MultiPack", "py/state": {"_creation_records": {}, + "_field_records": {}, "links": [], "groups": [], "_meta": {"py/object": "forte.data.multi_pack.MultiPackMeta", + "py/state": {"pack_name": "doc1", "_pack_id": 181242127422469546094667436428172965279, "record": {}}}, + "_pack_ref": [339609801674405881625808524240847417793, 2921617025007791382061014912332775176], + "_inverse_pack_ref": {"339609801674405881625808524240847417793": 0, "2921617025007791382061014912332775176": 1}, + "_pack_names": ["pack1", "pack2"], "_name_index": {"pack1": 0, "pack2": 1}, "generics": [{"py/object": + "entry_data_structures_test.ExampleMPEntry", "py/state": {"_tid": 47726154965183551280893968259456773646, + "refer_entry": 
{"py/object": "forte.data.ontology.core.MpPointer", "py/state": + {"_pack_index": 0, "_tid": 75914137358482571607300906707755792037}}}}], + "_MultiPack__default_pack_prefix": "_pack"}}""" + + recovered_mp = MultiPack.from_string(old_serialized_mp) + from forte.version import DEFAULT_PACK_VERSION + self.assertEqual(recovered_mp.pack_version, DEFAULT_PACK_VERSION) + + s_packs: List[str] = ["""{"py/object": "forte.data.data_pack.DataPack", "py/state": {"_creation_records": {}, + "_field_records": {}, "links": [], "groups": [], "_meta": {"py/object": "forte.data.data_pack.Meta", + "py/state": {"pack_name": null, "_pack_id": 339609801674405881625808524240847417793, "record": {}, + "language": "eng", "span_unit": "character", "info": {}}}, "_text": "", "annotations": [], "generics": + [{"py/object": "entry_data_structures_test.ExampleEntry", "py/state": + {"_tid": 75914137358482571607300906707755792037, "secret_number": 1}}], + "_DataPack__replace_back_operations": [], "_DataPack__processed_original_spans": [], + "_DataPack__orig_text_len": 0}}""", """{"py/object": "forte.data.data_pack.DataPack", "py/state": + {"_creation_records": {}, "_field_records": {}, "links": [], "groups": [], "_meta": {"py/object": + "forte.data.data_pack.Meta", "py/state": {"pack_name": null, "_pack_id": 2921617025007791382061014912332775176, + "record": {}, "language": "eng", "span_unit": "character", "info": {}}}, "_text": "", "annotations": [], + "generics": [{"py/object": "entry_data_structures_test.ExampleEntry", "py/state": + {"_tid": 242133944929228462168174254535391188929}}], "_DataPack__replace_back_operations": [], + "_DataPack__processed_original_spans": [], "_DataPack__orig_text_len": 0}}"""] + + recovered_packs = [DataPack.from_string(s) for s in s_packs] + + recovered_mp.relink(recovered_packs) + + re_mpe: ExampleMPEntry = recovered_mp.get_single(ExampleMPEntry) + self.assertIsInstance(re_mpe.refer_entry, ExampleEntry) + + def test_multipack_deserialized_dictionary_recover(self): 
+ serialized_mp = self.pack.to_string(drop_record=True) + recovered_mp = MultiPack.from_string(serialized_mp) + + s_packs = [p.to_string() for p in self.pack.packs] + recovered_packs = [DataPack.from_string(s) for s in s_packs] + pid = recovered_packs[0].pack_id + self.assertEqual(recovered_mp._inverse_pack_ref[pid], 0) + recovered_mp.relink(recovered_packs) + class EntryDataStructure(unittest.TestCase): def setUp(self): @@ -257,10 +305,10 @@ def test_entry_dict(self): def test_entry_key_memories(self): pack = ( Pipeline[MultiPack]() - .set_reader(EmptyReader()) - .add(ChildEntryAnnotator()) - .initialize() - .process(["pack1", "pack2"]) + .set_reader(EmptyReader()) + .add(ChildEntryAnnotator()) + .initialize() + .process(["pack1", "pack2"]) ) DataPack.from_string(pack.to_string(True)) diff --git a/tests/forte/data/multi_pack_test.py b/tests/forte/data/multi_pack_test.py index a1c339da7..14c5eb0fc 100644 --- a/tests/forte/data/multi_pack_test.py +++ b/tests/forte/data/multi_pack_test.py @@ -176,6 +176,45 @@ def test_multipack_entries(self): ], ) + # fix bug 559: additional test for index to pack_id changes + serialized_mp = self.multi_pack.to_string(drop_record=False) + recovered_mp = MultiPack.from_string(serialized_mp) + s_packs = [p.to_string() for p in self.multi_pack.packs] + recovered_packs = [DataPack.from_string(s) for s in s_packs] + + # 1st verify recovered_packs + left_tokens_recovered = [t.text for t in recovered_packs[0].get(Token)] + right_tokens_recovered = [t.text for t in recovered_packs[1].get(Token)] + + self.assertListEqual( + left_tokens_recovered, ["This", "pack", "contains", "some", "sample", "data."] + ) + self.assertListEqual( + right_tokens_recovered, + ["This", "pack", "contains", "some", "other", "sample", "data."], + ) + + recovered_mp.relink(recovered_packs) + + # then verify the links are ok (restored correctly) + linked_tokens_recovered = [] + for link in recovered_mp.all_links: + parent_text = link.get_parent().text + child_text = 
link.get_child().text + linked_tokens_recovered.append((parent_text, child_text)) + + self.assertListEqual( + linked_tokens_recovered, + [ + ("This", "This"), + ("pack", "pack"), + ("contains", "contains"), + ("some", "some"), + ("sample", "sample"), + ("data.", "data."), + ], + ) + # 3. Test deletion # Delete the second link.