Skip to content

Commit

Permalink
added static methods to infer pack type (#553)
Browse files Browse the repository at this point in the history
* added static methods to infer pack type

* added unit tests for datapack type inference.

Co-authored-by: Zhanyuan Zhang <[email protected]>
  • Loading branch information
zhanyuanucb and Zhanyuan Zhang authored Nov 2, 2021
1 parent 2af5029 commit 169718e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 13 deletions.
18 changes: 9 additions & 9 deletions forte/data/base_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def default_configs(cls):
"""
return {"zip_pack": False, "serialize_method": "jsonpickle"}

@property
def pack_type(self):
@staticmethod
def pack_type():
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -210,7 +210,7 @@ def _lazy_iter(self, *args, **kwargs):
if self._cache_directory is not None:
self.cache_data(collection, pack, not_first)

if not isinstance(pack, self.pack_type):
if not isinstance(pack, self.pack_type()):
raise ValueError(
f"No Pack object read from the given "
f"collection {collection}, returned {type(pack)}."
Expand Down Expand Up @@ -355,10 +355,10 @@ def read_from_cache(
with open(cache_filename, "r", encoding="utf-8") as cache_file:
for line in cache_file:
pack = DataPack.from_string(line.strip())
if not isinstance(pack, self.pack_type):
if not isinstance(pack, self.pack_type()):
raise TypeError(
f"Pack deserialized from {cache_filename} "
f"is {type(pack)}, but expect {self.pack_type}"
f"is {type(pack)}, but expect {self.pack_type()}"
)
yield pack

Expand All @@ -380,8 +380,8 @@ def set_text(self, pack: DataPack, text: str):
class PackReader(BaseReader[DataPack], ABC):
r"""A Pack Reader reads data into :class:`DataPack`."""

@property
def pack_type(self):
@staticmethod
def pack_type():
return DataPack


Expand All @@ -390,6 +390,6 @@ class MultiPackReader(BaseReader[MultiPack], ABC):
data readers which return :class:`MultiPack`.
"""

@property
def pack_type(self):
@staticmethod
def pack_type():
return MultiPack
16 changes: 16 additions & 0 deletions forte/data/caster.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ class Caster(
def cast(self, pack: InputPackType) -> OutputPackType:
raise NotImplementedError

@staticmethod
def input_pack_type():
raise NotImplementedError

@staticmethod
def output_pack_type():
raise NotImplementedError


class MultiPackBoxer(Caster[DataPack, MultiPack]):
"""
Expand All @@ -62,3 +70,11 @@ def cast(self, pack: DataPack) -> MultiPack:
@classmethod
def default_configs(cls):
return {"pack_name": "default"}

@staticmethod
def input_pack_type():
return DataPack

@staticmethod
def output_pack_type():
return MultiPack
8 changes: 4 additions & 4 deletions forte/data/readers/misc_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def _parse_pack(self, pack: PackType) -> Iterator[PackType]:


class RawPackReader(BaseRawPackReader):
@property
def pack_type(self):
@staticmethod
def pack_type():
return DataPack


class RawMultiPackReader(BaseRawPackReader):
@property
def pack_type(self):
@staticmethod
def pack_type():
return MultiPack
39 changes: 39 additions & 0 deletions tests/forte/data/datapack_type_infer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest
from ddt import data, ddt

from forte.data.caster import MultiPackBoxer
from forte.data.data_pack import DataPack
from forte.data.multi_pack import MultiPack
from forte.data.readers.misc_readers import RawPackReader, RawMultiPackReader
from forte.data.readers.multipack_sentence_reader import MultiPackSentenceReader
from forte.data.readers.multipack_terminal_reader import MultiPackTerminalReader
from forte.data.readers.plaintext_reader import PlainTextReader


@ddt
class DataPackTypeInferTest(unittest.TestCase):

@data(
PlainTextReader,
RawPackReader,
)
def test_datapack_reader(self, component):
reader = component()
self.assertTrue(reader.pack_type() is DataPack)

@data(
MultiPackSentenceReader,
MultiPackTerminalReader,
RawMultiPackReader,
)
def test_multipack_reader(self, component):
reader = component()
self.assertTrue(reader.pack_type() is MultiPack)

@data(
MultiPackBoxer,
)
def test_multipack_boxer(self, component):
caster = component()
self.assertTrue(caster.input_pack_type() is DataPack)
self.assertTrue(caster.output_pack_type() is MultiPack)

0 comments on commit 169718e

Please sign in to comment.