From 169718e40901ddf5dec7f591c0e576d507d15794 Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang <32000378+zhanyuanucb@users.noreply.github.com> Date: Tue, 2 Nov 2021 08:55:28 -0700 Subject: [PATCH] added static methods to infer pack type (#553) * added static methods to infer pack type * added unit tests for datapack type inference. Co-authored-by: Zhanyuan Zhang --- forte/data/base_reader.py | 18 ++++----- forte/data/caster.py | 16 ++++++++ forte/data/readers/misc_readers.py | 8 ++-- tests/forte/data/datapack_type_infer_test.py | 39 ++++++++++++++++++++ 4 files changed, 68 insertions(+), 13 deletions(-) create mode 100644 tests/forte/data/datapack_type_infer_test.py diff --git a/forte/data/base_reader.py b/forte/data/base_reader.py index 9605137e1..426cdda45 100644 --- a/forte/data/base_reader.py +++ b/forte/data/base_reader.py @@ -106,8 +106,8 @@ def default_configs(cls): """ return {"zip_pack": False, "serialize_method": "jsonpickle"} - @property - def pack_type(self): + @staticmethod + def pack_type(): raise NotImplementedError @abstractmethod @@ -210,7 +210,7 @@ def _lazy_iter(self, *args, **kwargs): if self._cache_directory is not None: self.cache_data(collection, pack, not_first) - if not isinstance(pack, self.pack_type): + if not isinstance(pack, self.pack_type()): raise ValueError( f"No Pack object read from the given " f"collection {collection}, returned {type(pack)}." @@ -355,10 +355,10 @@ def read_from_cache( with open(cache_filename, "r", encoding="utf-8") as cache_file: for line in cache_file: pack = DataPack.from_string(line.strip()) - if not isinstance(pack, self.pack_type): + if not isinstance(pack, self.pack_type()): raise TypeError( f"Pack deserialized from {cache_filename} " - f"is {type(pack)}, but expect {self.pack_type}" + f"is {type(pack)}, but expect {self.pack_type()}" ) yield pack @@ -380,8 +380,8 @@ def set_text(self, pack: DataPack, text: str): class PackReader(BaseReader[DataPack], ABC): r"""A Pack Reader reads data into :class:`DataPack`.""" - @property - def pack_type(self): + @staticmethod + def pack_type(): return DataPack @@ -390,6 +390,6 @@ class MultiPackReader(BaseReader[MultiPack], ABC): data readers which return :class:`MultiPack`. """ - @property - def pack_type(self): + @staticmethod + def pack_type(): return MultiPack diff --git a/forte/data/caster.py b/forte/data/caster.py index d798205f7..fb47ec9ca 100644 --- a/forte/data/caster.py +++ b/forte/data/caster.py @@ -37,6 +37,14 @@ class Caster( def cast(self, pack: InputPackType) -> OutputPackType: raise NotImplementedError + @staticmethod + def input_pack_type(): + raise NotImplementedError + + @staticmethod + def output_pack_type(): + raise NotImplementedError + class MultiPackBoxer(Caster[DataPack, MultiPack]): """ @@ -62,3 +70,11 @@ def cast(self, pack: DataPack) -> MultiPack: @classmethod def default_configs(cls): return {"pack_name": "default"} + + @staticmethod + def input_pack_type(): + return DataPack + + @staticmethod + def output_pack_type(): + return MultiPack diff --git a/forte/data/readers/misc_readers.py b/forte/data/readers/misc_readers.py index 4d2c8a0e7..ea0155b62 100644 --- a/forte/data/readers/misc_readers.py +++ b/forte/data/readers/misc_readers.py @@ -98,12 +98,12 @@ def _parse_pack(self, pack: PackType) -> Iterator[PackType]: class RawPackReader(BaseRawPackReader): - @property - def pack_type(self): + @staticmethod + def pack_type(): return DataPack class RawMultiPackReader(BaseRawPackReader): - @property - def pack_type(self): + @staticmethod + def pack_type(): return MultiPack diff --git a/tests/forte/data/datapack_type_infer_test.py b/tests/forte/data/datapack_type_infer_test.py new file mode 100644 index 000000000..6598fc684 --- /dev/null +++ b/tests/forte/data/datapack_type_infer_test.py @@ -0,0 +1,39 @@ +import unittest +from ddt import data, ddt + +from forte.data.caster import MultiPackBoxer +from forte.data.data_pack import DataPack +from forte.data.multi_pack import MultiPack +from forte.data.readers.misc_readers import RawPackReader, RawMultiPackReader +from forte.data.readers.multipack_sentence_reader import MultiPackSentenceReader +from forte.data.readers.multipack_terminal_reader import MultiPackTerminalReader +from forte.data.readers.plaintext_reader import PlainTextReader + + +@ddt +class DataPackTypeInferTest(unittest.TestCase): + + @data( + PlainTextReader, + RawPackReader, + ) + def test_datapack_reader(self, component): + reader = component() + self.assertTrue(reader.pack_type() is DataPack) + + @data( + MultiPackSentenceReader, + MultiPackTerminalReader, + RawMultiPackReader, + ) + def test_multipack_reader(self, component): + reader = component() + self.assertTrue(reader.pack_type() is MultiPack) + + @data( + MultiPackBoxer, + ) + def test_multipack_boxer(self, component): + caster = component() + self.assertTrue(caster.input_pack_type() is DataPack) + self.assertTrue(caster.output_pack_type() is MultiPack)