improve generation dataset readers (#122)
* improve generation dataset readers

* clean up

* format
epwalsh authored Sep 11, 2020
1 parent 48bceca commit 09fe0b0
Showing 4 changed files with 41 additions and 14 deletions.
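
All three generation readers get the same treatment: the constructor opts into manual sharding, _read wraps its raw iterable in shard_iterable() so each worker only tokenizes its own slice of the data, text_to_instance builds TextFields without token indexers, and a new apply_token_indexers override attaches the indexers afterwards. Below is a minimal sketch of that pattern, assuming the allennlp DatasetReader API these diffs are written against; the reader class, registration key, and tab-separated input format are illustrative only, not part of this commit.

    from typing import Dict

    from allennlp.data.dataset_readers import DatasetReader
    from allennlp.data.fields import TextField
    from allennlp.data.instance import Instance
    from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
    from allennlp.data.tokenizers import SpacyTokenizer, Tokenizer


    @DatasetReader.register("my_seq2seq")  # illustrative name, not part of this commit
    class MySeq2SeqReader(DatasetReader):
        def __init__(
            self,
            source_tokenizer: Tokenizer = None,
            source_token_indexers: Dict[str, TokenIndexer] = None,
            **kwargs,
        ) -> None:
            # Tell the base class this reader handles distributed / multi-process
            # sharding itself instead of letting every worker read everything.
            super().__init__(
                manual_distributed_sharding=True, manual_multi_process_sharding=True, **kwargs
            )
            self._source_tokenizer = source_tokenizer or SpacyTokenizer()
            self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}

        def _read(self, file_path: str):
            with open(file_path, "r") as data_file:
                # shard_iterable() skips the lines that belong to other workers,
                # so expensive tokenization only runs on this worker's share.
                for line in self.shard_iterable(data_file):
                    source, target = line.rstrip("\n").split("\t")
                    yield self.text_to_instance(source, target)

        def text_to_instance(self, source: str, target: str = None) -> Instance:  # type: ignore
            # No token indexers here: instances stay lightweight until
            # apply_token_indexers() is called on them later.
            fields = {"source_tokens": TextField(self._source_tokenizer.tokenize(source))}
            if target is not None:
                fields["target_tokens"] = TextField(self._source_tokenizer.tokenize(target))
            return Instance(fields)

        def apply_token_indexers(self, instance: Instance) -> None:
            instance.fields["source_tokens"]._token_indexers = self._source_token_indexers  # type: ignore
            if "target_tokens" in instance.fields:
                instance.fields["target_tokens"]._token_indexers = self._source_token_indexers  # type: ignore

The first three diffs below apply exactly this pattern to the CNN/DailyMail, CopyNet, and plain seq2seq readers; the qangaroo.py change is only a signature reformat.
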
16 changes: 12 additions & 4 deletions allennlp_models/generation/dataset_readers/cnn_dm.py
@@ -58,7 +58,9 @@ def __init__(
         target_max_tokens: Optional[int] = None,
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(
+            manual_distributed_sharding=True, manual_multi_process_sharding=True, **kwargs
+        )
         self._source_tokenizer = source_tokenizer or SpacyTokenizer()
         self._target_tokenizer = target_tokenizer or self._source_tokenizer
         self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
@@ -127,7 +129,7 @@ def _read(self, file_path: str):
         dm_stories = {Path(s).stem for s in glob.glob(os.path.join(dm_stories_path, "*.story"))}
 
         with open(url_file_path, "r") as url_file:
-            for url in url_file:
+            for url in self.shard_iterable(url_file):
                 url = url.strip()
 
                 url_hash = self._hashhex(url.encode("utf-8"))
@@ -157,15 +159,21 @@ def text_to_instance(
         if self._source_max_tokens is not None and len(tokenized_source) > self._source_max_tokens:
             tokenized_source = tokenized_source[: self._source_max_tokens]
 
-        source_field = TextField(tokenized_source, self._source_token_indexers)
+        source_field = TextField(tokenized_source)
         if target_sequence is not None:
             tokenized_target = self._target_tokenizer.tokenize(target_sequence)
             if (
                 self._target_max_tokens is not None
                 and len(tokenized_target) > self._target_max_tokens
             ):
                 tokenized_target = tokenized_target[: self._target_max_tokens]
-            target_field = TextField(tokenized_target, self._target_token_indexers)
+            target_field = TextField(tokenized_target)
             return Instance({"source_tokens": source_field, "target_tokens": target_field})
         else:
             return Instance({"source_tokens": source_field})
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        instance.fields["source_tokens"]._token_indexers = self._source_token_indexers  # type: ignore
+        if "target_tokens" in instance.fields:
+            instance.fields["target_tokens"]._token_indexers = self._target_token_indexers  # type: ignore
16 changes: 12 additions & 4 deletions allennlp_models/generation/dataset_readers/copynet_seq2seq.py
@@ -95,7 +95,9 @@ def __init__(
         source_token_indexers: Dict[str, TokenIndexer] = None,
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(
+            manual_distributed_sharding=True, manual_multi_process_sharding=True, **kwargs
+        )
         self._target_namespace = target_namespace
         self._source_tokenizer = source_tokenizer or SpacyTokenizer()
         self._target_tokenizer = target_tokenizer or self._source_tokenizer
@@ -120,7 +122,7 @@ def __init__(
     def _read(self, file_path):
         with open(cached_path(file_path), "r") as data_file:
             logger.info("Reading instances from lines in file at: %s", file_path)
-            for line_num, line in enumerate(data_file):
+            for line_num, line in self.shard_iterable(enumerate(data_file)):
                 line = line.strip("\n")
                 if not line:
                     continue
@@ -164,7 +166,7 @@ def text_to_instance(
         if not tokenized_source:
             # If the tokenized source is empty, it will cause issues downstream.
             raise ValueError(f"source tokenizer produced no tokens from source '{source_string}'")
-        source_field = TextField(tokenized_source, self._source_token_indexers)
+        source_field = TextField(tokenized_source)
 
         # For each token in the source sentence, we keep track of the matching token
         # in the target sentence (which will be the OOV symbol if there is no match).
@@ -177,7 +179,7 @@ def text_to_instance(
             tokenized_target = self._target_tokenizer.tokenize(target_string)
             tokenized_target.insert(0, Token(START_SYMBOL))
             tokenized_target.append(Token(END_SYMBOL))
-            target_field = TextField(tokenized_target, self._target_token_indexers)
+            target_field = TextField(tokenized_target)
 
             fields_dict["target_tokens"] = target_field
             meta_fields["target_tokens"] = [y.text for y in tokenized_target[1:-1]]
@@ -193,3 +195,9 @@ def text_to_instance(
         fields_dict["metadata"] = MetadataField(meta_fields)
 
         return Instance(fields_dict)
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        instance.fields["source_tokens"]._token_indexers = self._source_token_indexers  # type: ignore
+        if "target_tokens" in instance.fields:
+            instance.fields["target_tokens"]._token_indexers = self._target_token_indexers  # type: ignore
18 changes: 13 additions & 5 deletions allennlp_models/generation/dataset_readers/seq2seq.py
@@ -73,7 +73,9 @@ def __init__(
         quoting: int = csv.QUOTE_MINIMAL,
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(
+            manual_distributed_sharding=True, manual_multi_process_sharding=True, **kwargs
+        )
         self._source_tokenizer = source_tokenizer or SpacyTokenizer()
         self._target_tokenizer = target_tokenizer or self._source_tokenizer
         self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
@@ -99,8 +101,8 @@ def _read(self, file_path: str):
         self._target_max_exceeded = 0
         with open(cached_path(file_path), "r") as data_file:
             logger.info("Reading instances from lines in file at: %s", file_path)
-            for line_num, row in enumerate(
-                csv.reader(data_file, delimiter=self._delimiter, quoting=self.quoting)
+            for line_num, row in self.shard_iterable(
+                enumerate(csv.reader(data_file, delimiter=self._delimiter, quoting=self.quoting))
             ):
                 if len(row) != 2:
                     raise ConfigurationError(
@@ -135,7 +137,7 @@ def text_to_instance(
             tokenized_source.insert(0, copy.deepcopy(self._start_token))
         if self._source_add_end_token:
             tokenized_source.append(copy.deepcopy(self._end_token))
-        source_field = TextField(tokenized_source, self._source_token_indexers)
+        source_field = TextField(tokenized_source)
         if target_string is not None:
             tokenized_target = self._target_tokenizer.tokenize(target_string)
             if self._target_max_tokens and len(tokenized_target) > self._target_max_tokens:
@@ -145,7 +147,13 @@
                 tokenized_target.insert(0, copy.deepcopy(self._start_token))
             if self._target_add_end_token:
                 tokenized_target.append(copy.deepcopy(self._end_token))
-            target_field = TextField(tokenized_target, self._target_token_indexers)
+            target_field = TextField(tokenized_target)
             return Instance({"source_tokens": source_field, "target_tokens": target_field})
         else:
             return Instance({"source_tokens": source_field})
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        instance.fields["source_tokens"]._token_indexers = self._source_token_indexers  # type: ignore
+        if "target_tokens" in instance.fields:
+            instance.fields["target_tokens"]._token_indexers = self._target_token_indexers  # type: ignore
5 changes: 4 additions & 1 deletion allennlp_models/rc/dataset_readers/qangaroo.py
@@ -34,7 +34,10 @@ class QangarooReader(DatasetReader):
     """
 
     def __init__(
-        self, tokenizer: Tokenizer = None, token_indexers: Dict[str, TokenIndexer] = None, **kwargs,
+        self,
+        tokenizer: Tokenizer = None,
+        token_indexers: Dict[str, TokenIndexer] = None,
+        **kwargs,
     ) -> None:
 
         super().__init__(**kwargs)
