Type Key Removal in Configs #554

Merged: 2 commits, Nov 4, 2021
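The change pattern throughout this PR: the data-augmentation configs previously wrapped every dict-valued option in Texar-style `{"type": ..., "kwargs": {...}}` nesting, and the processors read the inner `["kwargs"]` dict. The PR flattens these options so the mapping itself is the config value. A minimal before/after sketch of the shape change (the variable names here are illustrative, not from the codebase):

# Before: Texar-hparams style nesting around the actual mapping.
old_style = {
    "other_entry_policy": {
        "type": "",
        "kwargs": {"ft.onto.base_ontology.Document": "auto_align"},
    },
}

# After: the mapping is passed directly.
new_style = {
    "other_entry_policy": {"ft.onto.base_ontology.Document": "auto_align"},
}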
forte/processors/data_augment/algorithms/eda_processors.py: 17 additions & 34 deletions
@@ -278,16 +278,12 @@ def default_configs(cls):
                 "other_entry_policy": {
                     # to use Texar hyperparams 'kwargs' must
                     # accompany with 'type'
-                    "type": "",
-                    "kwargs": {
-                        "ft.onto.base_ontology.Document": "auto_align",
-                        "ft.onto.base_ontology.Sentence": "auto_align",
-                    },
+                    "ft.onto.base_ontology.Document": "auto_align",
+                    "ft.onto.base_ontology.Sentence": "auto_align",
                 },
                 "alpha": 0.1,
                 "augment_pack_names": {
-                    "type": "",
-                    "kwargs": {"input_src": "augmented_input_src"},
+                    "input_src": "augmented_input_src",
                 },
             }

@@ -308,9 +304,7 @@ def initialize(self, resources: Resources, configs: Config):
     def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
         replacement_op = create_class_with_kwargs(
             self.configs["data_aug_op"],
-            class_args={
-                "configs": self.configs["data_aug_op_config"]["kwargs"]
-            },
+            class_args={"configs": self.configs["data_aug_op_config"]},
         )
         augment_entry = get_class(self.configs["augment_entry"])
@@ -352,29 +346,22 @@ def default_configs(cls):
             {
                 "augment_entry": "ft.onto.base_ontology.Token",
                 "other_entry_policy": {
-                    "type": "",
-                    "kwargs": {
-                        "ft.onto.base_ontology.Document": "auto_align",
-                        "ft.onto.base_ontology.Sentence": "auto_align",
-                    },
+                    "ft.onto.base_ontology.Document": "auto_align",
+                    "ft.onto.base_ontology.Sentence": "auto_align",
                 },
                 "data_aug_op": "forte.processors.data_augment.algorithms."
                 "dictionary_replacement_op.DictionaryReplacementOp",
                 "data_aug_op_config": {
-                    "type": "",
-                    "kwargs": {
-                        "dictionary_class": (
-                            "forte.processors.data_augment."
-                            "algorithms.dictionary.WordnetDictionary"
-                        ),
-                        "prob": 1.0,
-                        "lang": "eng",
-                    },
+                    "dictionary_class": (
+                        "forte.processors.data_augment."
+                        "algorithms.dictionary.WordnetDictionary"
+                    ),
+                    "prob": 1.0,
+                    "lang": "eng",
                 },
                 "alpha": 0.1,
                 "augment_pack_names": {
-                    "type": "",
-                    "kwargs": {"input_src": "augmented_input_src"},
+                    "input_src": "augmented_input_src",
                 },
                 "stopwords": english_stopwords,
             }
@@ -411,17 +398,13 @@ def default_configs(cls):
             {
                 "augment_entry": "ft.onto.base_ontology.Token",
                 "other_entry_policy": {
-                    "type": "",
-                    "kwargs": {
-                        "ft.onto.base_ontology.Document": "auto_align",
-                        "ft.onto.base_ontology.Sentence": "auto_align",
-                    },
+                    "ft.onto.base_ontology.Document": "auto_align",
+                    "ft.onto.base_ontology.Sentence": "auto_align",
                 },
-                "data_aug_op_config": {"type": "", "kwargs": {}},
+                "data_aug_op_config": {},
                 "alpha": 0.1,
                 "augment_pack_names": {
-                    "type": "",
-                    "kwargs": {"input_src": "augmented_input_src"},
+                    "input_src": "augmented_input_src",
                 },
             }
         )
@@ -149,16 +149,12 @@ def default_configs(cls):
             {
                 "augment_entry": "ft.onto.base_ontology.Token",
                 "other_entry_policy": {
-                    "type": "",
-                    "kwargs": {
-                        "ft.onto.base_ontology.Document": "auto_align",
-                        "ft.onto.base_ontology.Sentence": "auto_align",
-                    },
+                    "ft.onto.base_ontology.Document": "auto_align",
+                    "ft.onto.base_ontology.Sentence": "auto_align",
                 },
                 "alpha": 0.1,
                 "augment_pack_names": {
-                    "type": "",
-                    "kwargs": {"input_src": "augmented_input_src"},
+                    "input_src": "augmented_input_src",
                 },
             }
         )
forte/processors/data_augment/base_data_augment_processor.py: 15 additions & 12 deletions
@@ -279,7 +279,7 @@ def __init__(self):

     def initialize(self, resources: Resources, configs: Config):
         super().initialize(resources, configs)
-        self._other_entry_policy = self.configs["other_entry_policy"]["kwargs"]
+        self._other_entry_policy = self.configs["other_entry_policy"]

     def _overlap_with_existing(self, pid: int, begin: int, end: int) -> bool:
         r"""
@@ -743,9 +743,7 @@ def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
         """
         replacement_op = create_class_with_kwargs(
             self.configs["data_aug_op"],
-            class_args={
-                "configs": self.configs["data_aug_op_config"]["kwargs"]
-            },
+            class_args={"configs": self.configs["data_aug_op_config"]},
         )
         augment_entry = get_class(self.configs["augment_entry"])

@@ -760,20 +758,20 @@ def _process(self, input_pack: MultiPack):
         aug_pack_names: List[str] = []

         # Check if the DataPack exists.
-        for pack_name in self.configs["augment_pack_names"]["kwargs"].keys():
+        for pack_name in self.configs["augment_pack_names"].keys():
             if pack_name in input_pack.pack_names:
                 aug_pack_names.append(pack_name)

-        if len(self.configs["augment_pack_names"]["kwargs"].keys()) == 0:
+        if len(self.configs["augment_pack_names"].keys()) == 0:
             # Augment all the DataPacks if not specified.
             aug_pack_names = list(input_pack.pack_names)

         self._augment(input_pack, aug_pack_names)
         new_packs: List[Tuple[str, DataPack]] = []
         for aug_pack_name in aug_pack_names:
-            new_pack_name: str = self.configs["augment_pack_names"][
-                "kwargs"
-            ].get(aug_pack_name, "augmented_" + aug_pack_name)
+            new_pack_name: str = self.configs["augment_pack_names"].get(
+                aug_pack_name, "augmented_" + aug_pack_name
+            )
             data_pack = input_pack.get_pack(aug_pack_name)
             new_pack = self._auto_align_annotations(
                 data_pack=data_pack,
@@ -875,9 +873,14 @@ def default_configs(cls):
         """
         return {
             "augment_entry": "ft.onto.base_ontology.Sentence",
-            "other_entry_policy": {"type": "", "kwargs": {}},
+            "other_entry_policy": {},
             "type": "data_augmentation_op",
             "data_aug_op": "",
-            "data_aug_op_config": {"type": "", "kwargs": {}},
-            "augment_pack_names": {"type": "", "kwargs": {}},
+            "data_aug_op_config": {},
+            "augment_pack_names": {},
+            "@no_typecheck": [
+                "other_entry_policy",
+                "data_aug_op_config",
+                "augment_pack_names",
+            ],
         }
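The new `@no_typecheck` entry is presumably needed because these three defaults are now empty dicts: under Texar-style hyperparameter validation, an override key that is absent from the default would otherwise be rejected, and these are exactly the options whose keys are free-form (ontology type names, pack names). A small sketch of a user override that relies on this, with values taken from the defaults shown elsewhere in this diff (the validation behavior itself is an assumption based on this config):

# Keys inside these dicts are open-ended, hence their exemption above.
user_config = {
    "other_entry_policy": {"ft.onto.base_ontology.Document": "auto_align"},
    "augment_pack_names": {"input_src": "augmented_input_src"},
}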
@@ -144,16 +144,13 @@ def test_pipeline(self, texts, expected_outputs, expected_tokens):
         processor_config = {
             "augment_entry": "ft.onto.base_ontology.Token",
             "other_entry_policy": {
-                "type": "",
-                "kwargs": {
-                    "ft.onto.base_ontology.Document": "auto_align",
-                    "ft.onto.base_ontology.Sentence": "auto_align",
-                },
+                "ft.onto.base_ontology.Document": "auto_align",
+                "ft.onto.base_ontology.Sentence": "auto_align",
             },
             "type": "data_augmentation_op",
             "data_aug_op": replacer_op,
-            "data_aug_op_config": {"type": "", "kwargs": {}},
-            "augment_pack_names": {"kwargs": {"input": "augmented_input"}},
+            "data_aug_op_config": {},
+            "augment_pack_names": {},
         }

         nlp.set_reader(reader=StringReader())
@@ -230,12 +227,12 @@ def test_replace_token(
         processor_config = {
             "augment_entry": "ft.onto.base_ontology.Token",
             "other_entry_policy": {
-                "kwargs": {"ft.onto.base_ontology.Sentence": "auto_align"}
+                "ft.onto.base_ontology.Sentence": "auto_align"
             },
             "type": "data_augmentation_op",
             "data_aug_op": replacer_op,
-            "data_aug_op_config": {"kwargs": {}},
-            "augment_pack_names": {"kwargs": {}},
+            "data_aug_op_config": {},
+            "augment_pack_names": {},
         }

         nlp.initialize()
@@ -17,8 +17,8 @@

 import unittest
 import random
-from forte.data.data_pack import DataPack
 from ft.onto.base_ontology import Sentence
+from forte.data.data_pack import DataPack
 from forte.processors.data_augment.algorithms.back_translation_op import (
     BackTranslationOp,
 )
@@ -95,25 +95,19 @@ def test_pipeline(self, texts, expected_outputs):
         processor_config = {
             "augment_entry": "ft.onto.base_ontology.Token",
             "other_entry_policy": {
-                "type": "",
-                "kwargs": {
-                    "ft.onto.base_ontology.Document": "auto_align",
-                    "ft.onto.base_ontology.Sentence": "auto_align",
-                },
+                "ft.onto.base_ontology.Document": "auto_align",
+                "ft.onto.base_ontology.Sentence": "auto_align",
             },
             "type": "data_augmentation_op",
             "data_aug_op": "forte.processors.data_augment.algorithms"
             ".embedding_similarity_replacement_op."
             "EmbeddingSimilarityReplacementOp",
             "data_aug_op_config": {
-                "type": "",
-                "kwargs": {
-                    "vocab_path": self.abs_vocab_path,
-                    "embed_hparams": self.embed_hparams,
-                    "top_k": 1,
-                },
+                "vocab_path": self.abs_vocab_path,
+                "embed_hparams": self.embed_hparams,
+                "top_k": 1,
             },
-            "augment_pack_names": {"kwargs": {"input": "augmented_input"}},
+            "augment_pack_names": {"input": "augmented_input"},
         }
         nlp.add(
             component=ReplacementDataAugmentProcessor(), config=processor_config
@@ -141,7 +141,7 @@ def test_word_splitting_processor(
     ):
         entity_config = {
             "other_entry_policy": {
-                "kwargs": {"ft.onto.base_ontology.EntityMention": "auto_align"}
+                "ft.onto.base_ontology.EntityMention": "auto_align"
             }
         }
         self.nlp.add(