Skip to content

Commit

Permalink
Image size access from ImageAnnotation (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
hepengfe authored Aug 26, 2022
1 parent 588d6bc commit 6b29ed2
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 34 deletions.
46 changes: 27 additions & 19 deletions docs/notebook_tutorial/ocr.ipynb

Large diffs are not rendered by default.

77 changes: 62 additions & 15 deletions forte/data/ontology/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,26 +748,41 @@ def image(self):

@property
def max_x(self):
    """
    Largest valid x (column) index of the image.

    Returns:
        The image width minus one, where the width is the second entry of
        ``self.image_shape``.
    """
    # The bound is derived from ``image_shape`` rather than from a
    # separately stored ``_image_width``, which would be unset unless
    # something assigned it explicitly.
    return self.image_shape[1] - 1

@property
def max_y(self):
    """
    Largest valid y (row) index of the image.

    Returns:
        The image height minus one, where the height is the first entry of
        ``self.image_shape``.
    """
    # The bound is derived from ``image_shape`` rather than from a
    # separately stored ``_image_height``, which would be unset unless
    # something assigned it explicitly.
    return self.image_shape[0] - 1

def set_image_shape(self, width, height):
@property
def image_shape(self):
    """
    Returns the shape of the image this annotation refers to.

    The shape is read from the image payload stored in the data pack at
    index ``self._image_payload_idx``.

    Raises:
        ValueError: if the image annotation is not attached to any data
            pack.
        ValueError: if the image shape is not valid. It must be either
            2D or 3D.

    Returns:
        The shape of the image, either ``[height, width]`` or
        ``[height, width, channel]``.
    """
    if self.pack is None:
        raise ValueError(
            "Cannot get image because image annotation is not "
            "attached to any data pack."
        )
    payload = self.pack.get_payload_at(
        Modality.Image, self._image_payload_idx
    )
    # Read the shape through the payload's ``cache_shape`` property. The
    # payload's ``image_shape`` is a plain method, so accessing it without
    # calling it would hand a bound method to ``len`` below.
    # NOTE(review): assumes ``get_payload_at`` returns the payload class
    # that defines ``cache_shape`` -- confirm.
    image_shape = payload.cache_shape
    # ``cache_shape`` may be ``None`` (it is declared Optional), which is
    # reported as an invalid shape instead of failing inside ``len``.
    if image_shape is None or not 2 <= len(image_shape) <= 3:
        raise ValueError(
            "Image shape is not valid. "
            "It should be 2D ([height, width])"
            " or 3D ([height, width, channel])."
        )
    return image_shape

def __eq__(self, other):
if other is None:
Expand Down Expand Up @@ -1045,6 +1060,7 @@ def __init__(

super().__init__(pack)
self._cache: Union[str, np.ndarray] = ""
self._cache_shape: Optional[Sequence[int]] = None
self.replace_back_operations: Sequence[Tuple] = []
self.processed_original_spans: Sequence[Tuple] = []
self.orig_text_len: int = 0
Expand All @@ -1068,15 +1084,43 @@ def cache(self) -> Union[str, np.ndarray]:
def payload_index(self) -> int:
return self.payload_idx

def set_cache(self, data: Union[str, np.ndarray]):
@property
def cache_shape(self) -> Optional[Sequence[int]]:
    """
    Shape recorded for the cached data.

    Returns:
        The value stored in ``self._cache_shape``; it may be ``None``.
    """
    shape = self._cache_shape
    return shape

def set_cache(
    self,
    data: Union[str, np.ndarray],
    cache_shape: Optional[Sequence[int]] = None,
):
    """
    Load cache data into the payload and record the shape of the data.

    The shape is always derived from ``data`` itself: ``data.shape`` for a
    numpy array and ``(len(data),)`` for a string.

    Args:
        data: data to be set in the payload. It can be str for text data or
            numpy array for audio or image data.
        cache_shape: accepted for interface compatibility only. Any value
            passed here is overwritten by the shape derived from ``data``.

    Raises:
        ValueError: if ``data`` is neither a ``str`` nor a
            ``numpy.ndarray``. The payload is left unchanged in that case.
    """
    # Validate and derive the shape before touching any state, so that an
    # unsupported value cannot leave the payload with a new cache and a
    # stale shape.
    if isinstance(data, np.ndarray):
        cache_shape = data.shape
    elif isinstance(data, str):
        cache_shape = (len(data),)
    else:
        raise ValueError(
            f"Unsupported data type {type(data)} for cache. "
            f"Currently we only support str and numpy.ndarray."
        )
    self._cache = data
    self._cache_shape = cache_shape

def set_payload_index(self, payload_index: int):
"""
Expand All @@ -1087,6 +1131,9 @@ def set_payload_index(self, payload_index: int):
"""
self.payload_idx = payload_index

def image_shape(self):
    """
    Return the shape recorded for the cached data.

    Returns:
        The value stored in ``self._cache_shape``.
    """
    shape = self._cache_shape
    return shape

def __getstate__(self):
r"""
Convert ``modality`` ``Enum`` object to str format for serialization.
Expand Down
24 changes: 24 additions & 0 deletions tests/forte/image_annotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,33 @@ def setUp(self):
self.line[2, 2] = 1
self.line[3, 3] = 1
self.line[4, 4] = 1
ip = ImagePayload(self.datapack, 0)
ip.set_cache(self.line)
self.img_ann = ImageAnnotation(self.datapack)

def test_image_annotation_shape(self):
    """
    Check image set/get/add on a data pack and the shape reported by the
    image annotation.
    """
    # This test has its own name so that the ``test_datapack_image_operation``
    # method defined after it does not replace it on the class, which would
    # keep these assertions from ever running.
    datapack = DataPack("image2")
    datapack.set_image(self.line, 0)
    self.assertTrue(np.array_equal(datapack.image, self.datapack.image))

    # invalid image index
    with self.assertRaises(ProcessExecutionException):
        datapack.set_image(self.line, 2)

    # invalid image index
    with self.assertRaises(ProcessExecutionException):
        datapack.get_image(1)

    datapack.add_image(self.line)
    self.assertTrue(np.array_equal(datapack.get_image(1), self.line))

    self.assertEqual(self.img_ann.image_shape, (6, 12))

    self.datapack.set_image(self.line, 0)
    ImageAnnotation(self.datapack)


def test_datapack_image_operation(self):
datapack = DataPack("image2")

Expand Down

0 comments on commit 6b29ed2

Please sign in to comment.