From b890a7f1292c6445b894486ffcc2d13a689d7af7 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Tue, 10 Aug 2021 00:16:42 -0400 Subject: [PATCH 01/19] Add pipeline states to ir --- forte/pipeline.py | 41 +++++++++++ forte/processors/base/base_processor.py | 5 -- .../forte/processors/remote_processor_test.py | 69 +++++++++++++++---- 3 files changed, 98 insertions(+), 17 deletions(-) diff --git a/forte/pipeline.py b/forte/pipeline.py index 288012abc..622d2b13e 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -34,6 +34,7 @@ Set, ) +import pickle import yaml import uvicorn from fastapi import FastAPI @@ -257,6 +258,19 @@ def init_from_config(self, configs: List): is_first: bool = True for component_config in configs: + + if component_config["type"] == "PIPELINE_STATES": + state_configs: Dict[str, Dict] = component_config["configs"] + for attr, val in state_configs["attribute"].items(): + setattr(self, attr, val) + self.resource.update( + **{ + field: pickle.loads(val.encode("latin1")) + for field, val in state_configs["resource"].items() + } + ) + continue + component = create_class_with_kwargs( class_name=component_config["type"], class_args=component_config.get("kwargs", {}), @@ -299,6 +313,33 @@ def _dump_to_config(self): "configs": config.todict(), } ) + + # Serialize current states of pipeline + configs.append( + { + "type": "PIPELINE_STATES", + "configs": { + "attribute": { + attr: getattr(self, attr) + for attr in ( + "_initialized", + "_enable_profiling", + "_check_type_consistency", + "_do_init_type_check", + ) + if hasattr(self, attr) + }, + "resource": { + field: pickle.dumps(self.resource.get(field)).decode( + "latin1" + ) + for field in ("onto_specs_dict", "merged_entry_tree") + if self.resource.contains(field) + }, + }, + } + ) + return configs def save(self, path: str): diff --git a/forte/processors/base/base_processor.py b/forte/processors/base/base_processor.py index 691c0a5a0..87c672abf 100644 --- a/forte/processors/base/base_processor.py +++ b/forte/processors/base/base_processor.py @@ -130,11 +130,6 @@ def default_configs(cls) -> Dict[str, Any]: config = super().default_configs() config.update( { - "selector": { - "type": "forte.data.selector.DummySelector", - "args": None, - "kwargs": {}, - }, "overwrite": False, } ) diff --git a/tests/forte/processors/remote_processor_test.py b/tests/forte/processors/remote_processor_test.py index 4d676556f..0773563d8 100644 --- a/tests/forte/processors/remote_processor_test.py +++ b/tests/forte/processors/remote_processor_test.py @@ -16,20 +16,18 @@ """ import os -import sys -import json import unittest from ddt import ddt, data -from typing import Any, Dict, Iterator, Optional, Type, Set, List -from forte.common import ProcessorConfigError +from typing import Dict, Set from forte.data.data_pack import DataPack -from forte.pipeline import Pipeline, serve +from forte.pipeline import Pipeline from forte.processors.base import PackProcessor from forte.processors.nlp import ElizaProcessor from forte.processors.misc import RemoteProcessor from forte.data.readers import RawDataDeserializeReader, StringReader from forte.data.common_entry_utils import create_utterance, get_last_utterance +from forte.data.ontology.code_generation_objects import EntryTreeNode from ft.onto.base_ontology import Utterance @@ -89,6 +87,13 @@ class TestRemoteProcessor(unittest.TestCase): and all the testcases below are refactored from `./eliza_test.py`. """ + def setUp(self) -> None: + dir_path: str = os.path.dirname(os.path.abspath(__file__)) + self._pl_config_path: str = os.path.join(dir_path, "eliza_pl_ir.yaml") + self._onto_path: str = os.path.join( + dir_path, "../data/ontology/test_specs/base_ontology.json" + ) + @data( [ "I would like to have a chat bot.", @@ -101,23 +106,46 @@ def test_ir(self, input_output_pair): Verify the intermediate representation of pipeline. """ i_str, o_str = input_output_pair - pl_config_path: str = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "eliza_pl_ir.yaml" - ) # Build eliza pipeline - eliza_pl: Pipeline[DataPack] = Pipeline[DataPack]() + eliza_pl: Pipeline[DataPack] = Pipeline[DataPack]( + ontology_file=self._onto_path, + enforce_consistency=True, + do_init_type_check=True + ) eliza_pl.set_reader(StringReader()) eliza_pl.add(UserSimulator(), config={"user_input": i_str}) eliza_pl.add(ElizaProcessor()) - eliza_pl.save(pl_config_path) + eliza_pl.set_profiling() + eliza_pl.initialize() + eliza_pl.save(self._pl_config_path) # Build test pipeline test_pl: Pipeline[DataPack] = Pipeline[DataPack]() - test_pl.init_from_config_path(pl_config_path) - test_pl.initialize() + test_pl.init_from_config_path(self._pl_config_path) + + # Verify pipeline states + self.assertListEqual(*map( + lambda pl: [ + getattr(pl, attr) for attr in ( + "_initialized", + "_enable_profiling", + "_check_type_consistency", + "_do_init_type_check" + ) if hasattr(pl, attr) + ], (eliza_pl, test_pl) + )) + self.assertDictEqual( + eliza_pl.resource.get("onto_specs_dict"), + test_pl.resource.get("onto_specs_dict") + ) + self._assertEntryTreeEqual( + eliza_pl.resource.get("merged_entry_tree").root, + test_pl.resource.get("merged_entry_tree").root + ) # Verify output + test_pl.initialize() res: DataPack = test_pl.process("") utterance = get_last_utterance(res, "ai") self.assertEqual(len([_ for _ in res.get(Utterance)]), 2) @@ -180,6 +208,23 @@ def test_remote_processor(self, input_output_pair): self.assertEqual(len([_ for _ in res.get(Utterance)]), 2) self.assertEqual(utterance.text, o_str) + def _assertEntryTreeEqual(self, root1: EntryTreeNode, root2: EntryTreeNode): + """ + Test if two `EntryTreeNode` objects are recursively equivalent + """ + self.assertEqual(root1.name, root2.name) + self.assertSetEqual(root1.attributes, root2.attributes) + self.assertEqual(len(root1.children), len(root2.children)) + for i in range(len(root1.children)): + self._assertEntryTreeEqual(root1.children[i], root2.children[i]) + + def tearDown(self) -> None: + """ + Remove the IR file of pipeline if necessary. + """ + if os.path.exists(self._pl_config_path): + os.remove(self._pl_config_path) + if __name__ == "__main__": unittest.main() From 32e2afc1be1fad03dc00b60a76ec8cdc72022c60 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Tue, 10 Aug 2021 10:44:19 -0400 Subject: [PATCH 02/19] Add doc for ir --- forte/pipeline.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/forte/pipeline.py b/forte/pipeline.py index 622d2b13e..8c20ec524 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -253,12 +253,20 @@ def init_from_config(self, configs: List): given configurations. Args: - configs: The configs used to initialize the pipeline. + configs: The configs used to initialize the pipeline. It should be + a list of dictionary that contains `"type"` and `"configs"`. + `"type"` indicates the class of pipeline components and + `"configs"` stores the corresponding component's configs. One + exception is that when `"type"` is set to `"PIPELINE_STATES"`, + `"configs"` will be used to update the pipeline states + based on the fields specified in `configs.attribute` and + `configs.resource`. """ is_first: bool = True for component_config in configs: + # Set pipeline states and resources if component_config["type"] == "PIPELINE_STATES": state_configs: Dict[str, Dict] = component_config["configs"] for attr, val in state_configs["attribute"].items(): @@ -293,7 +301,7 @@ def _dump_to_config(self): a pipeline. Returns: - dict: A dictionary storing IR. + list: A list of dictionary storing IR. """ configs: List[Dict] = [] configs.append( From 94eb8afe2ab10227a61587cfaf72e62e71f8fc62 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Tue, 10 Aug 2021 20:49:29 -0400 Subject: [PATCH 03/19] Restructure IR --- .../data/ontology/code_generation_objects.py | 48 +++++++ forte/pipeline.py | 128 +++++++++++------- ...ssor_test.py => advanced_pipeline_test.py} | 93 ++++++++++--- 3 files changed, 199 insertions(+), 70 deletions(-) rename tests/forte/{processors/remote_processor_test.py => advanced_pipeline_test.py} (71%) diff --git a/forte/data/ontology/code_generation_objects.py b/forte/data/ontology/code_generation_objects.py index 130a14757..80986b50b 100644 --- a/forte/data/ontology/code_generation_objects.py +++ b/forte/data/ontology/code_generation_objects.py @@ -785,6 +785,54 @@ def collect_parents(self, node_dict: Dict[str, Set[str]]): ] = found_node.parent.attributes found_node = found_node.parent + def todict(self) -> Dict[str, Any]: + r"""Dump the EntryTree structure to a dictionary. + + Returns: + dict: A dictionary storing the EntryTree. + """ + + def node_to_dict(node: EntryTreeNode): + return ( + None + if not node + else { + "name": node.name, + "attributes": list(node.attributes), + "children": [ + node_to_dict(child) for child in node.children + ], + } + ) + + return node_to_dict(self.root) + + def fromdict( + self, tree_dict: Dict[str, Any], parent_entry_name: Optional[str] = None + ) -> Optional["EntryTree"]: + r"""Load the EntryTree structure from a dictionary. + + Args: + tree_dict: A dictionary storing the EntryTree. + parent_entry_name: The type name of the parent of the node to be + built. Default value is None. + """ + if not tree_dict: + return None + + if parent_entry_name is None: + self.root = EntryTreeNode(name=tree_dict["name"]) + self.root.attributes = set(tree_dict["attributes"]) + else: + self.add_node( + curr_entry_name=tree_dict["name"], + parent_entry_name=parent_entry_name, + curr_entry_attr=set(tree_dict["attributes"]), + ) + for child in tree_dict["children"]: + self.fromdict(child, tree_dict["name"]) + return self + def search(node: EntryTreeNode, search_node_name: str): if node.name == search_node_name: diff --git a/forte/pipeline.py b/forte/pipeline.py index 8c20ec524..b75ca6c4f 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -34,7 +34,6 @@ Set, ) -import pickle import yaml import uvicorn from fastapi import FastAPI @@ -220,6 +219,9 @@ def __init__( # Indicate whether do type checking during pipeline initialization self._do_init_type_check: bool = do_init_type_check + # The version of intermediate representation format + self.FORTE_IR_VERSION: str = "0.0.1" + def enforce_consistency(self, enforce: bool = True): r"""This function determines whether the pipeline will check the content expectations specified in each pipeline component. This @@ -248,37 +250,32 @@ def init_from_config_path(self, config_path): configs = yaml.safe_load(open(config_path)) self.init_from_config(configs) - def init_from_config(self, configs: List): + def init_from_config(self, configs: Dict[str, Any]): r"""Initialized the pipeline (ontology and processors) from the given configurations. Args: configs: The configs used to initialize the pipeline. It should be - a list of dictionary that contains `"type"` and `"configs"`. - `"type"` indicates the class of pipeline components and - `"configs"` stores the corresponding component's configs. One - exception is that when `"type"` is set to `"PIPELINE_STATES"`, - `"configs"` will be used to update the pipeline states - based on the fields specified in `configs.attribute` and - `configs.resource`. + a dictionary that contains `forte_ir_version`, `components` + and `states`. `forte_ir_version` is a string used to validate + input format. `components` is a list of dictionary that + contains `type` (the class of pipeline components), + `configs` (the corresponding component's configs) and + `selector`. `states` will be used to update the pipeline states + based on the fields specified in `states.attribute` and + `states.resource`. """ + # Validate IR version + if configs.get("forte_ir_version") != self.FORTE_IR_VERSION: + raise ProcessorConfigError( + f"forte_ir_version={configs.get('forte_ir_version')} not " + "supported. Please make sure the format of input IR complies " + f"with forte_ir_version={self.FORTE_IR_VERSION}." + ) + # Add components from IR is_first: bool = True - for component_config in configs: - - # Set pipeline states and resources - if component_config["type"] == "PIPELINE_STATES": - state_configs: Dict[str, Dict] = component_config["configs"] - for attr, val in state_configs["attribute"].items(): - setattr(self, attr, val) - self.resource.update( - **{ - field: pickle.loads(val.encode("latin1")) - for field, val in state_configs["resource"].items() - } - ) - continue - + for component_config in configs["components"]: component = create_class_with_kwargs( class_name=component_config["type"], class_args=component_config.get("kwargs", {}), @@ -293,7 +290,26 @@ def init_from_config(self, configs: List): is_first = False else: # Can be processor, caster, or evaluator - self.add(component, component_config.get("configs", {})) + selector = create_class_with_kwargs( + class_name=component_config["selector"]["type"], + class_args=component_config["selector"].get("kwargs", {}), + ) + self.add( + component=component, + config=component_config.get("configs", {}), + selector=selector, + ) + + # Set pipeline states and resources + states_config: Dict[str, Dict] = configs["states"] + for attr, val in states_config["attribute"].items(): + setattr(self, attr, val) + self.resource.update( + onto_specs_dict=states_config["resource"]["onto_specs_dict"], + merged_entry_tree=EntryTree().fromdict( + states_config["resource"]["merged_entry_tree"] + ), + ) def _dump_to_config(self): r"""Serialize the pipeline to an IR(intermediate representation). @@ -301,10 +317,16 @@ def _dump_to_config(self): a pipeline. Returns: - list: A list of dictionary storing IR. + dict: A dictionary storing IR. """ - configs: List[Dict] = [] - configs.append( + configs: Dict = { + "forte_ir_version": self.FORTE_IR_VERSION, + "components": list(), + "states": dict(), + } + + # Serialize pipeline components + configs["components"].append( { "type": ".".join( [self._reader.__module__, type(self._reader).__name__] @@ -312,38 +334,44 @@ def _dump_to_config(self): "configs": self._reader_config.todict(), } ) - for component, config in zip(self.components, self.component_configs): - configs.append( + for component, config, selector in zip( + self.components, self.component_configs, self._selectors + ): + configs["components"].append( { "type": ".".join( [component.__module__, type(component).__name__] ), "configs": config.todict(), + "selector": { + "type": ".".join( + [selector.__module__, type(selector).__name__] + ), + # TODO: This presumes that class attributes' names are + # the same as the paramaters' names passed to + # selector's constructor, which may not be always true. + "kwargs": selector.__dict__ or None, + }, } ) # Serialize current states of pipeline - configs.append( + configs["states"].update( { - "type": "PIPELINE_STATES", - "configs": { - "attribute": { - attr: getattr(self, attr) - for attr in ( - "_initialized", - "_enable_profiling", - "_check_type_consistency", - "_do_init_type_check", - ) - if hasattr(self, attr) - }, - "resource": { - field: pickle.dumps(self.resource.get(field)).decode( - "latin1" - ) - for field in ("onto_specs_dict", "merged_entry_tree") - if self.resource.contains(field) - }, + "attribute": { + attr: getattr(self, attr) + for attr in ( + "_initialized", + "_enable_profiling", + "_check_type_consistency", + "_do_init_type_check", + ) + if hasattr(self, attr) + }, + "resource": { + "onto_specs_dict": self.resource.get("onto_specs_dict"), + "merged_entry_tree": self.resource.get("merged_entry_tree") + and self.resource.get("merged_entry_tree").todict(), }, } ) diff --git a/tests/forte/processors/remote_processor_test.py b/tests/forte/advanced_pipeline_test.py similarity index 71% rename from tests/forte/processors/remote_processor_test.py rename to tests/forte/advanced_pipeline_test.py index 0773563d8..0adf7c350 100644 --- a/tests/forte/processors/remote_processor_test.py +++ b/tests/forte/advanced_pipeline_test.py @@ -19,9 +19,12 @@ import unittest from ddt import ddt, data -from typing import Dict, Set +from typing import Dict, Set, Any, Iterator from forte.data.data_pack import DataPack +from forte.data.multi_pack import MultiPack +from forte.data.selector import RegexNameMatchSelector from forte.pipeline import Pipeline +from forte.data.base_reader import MultiPackReader from forte.processors.base import PackProcessor from forte.processors.nlp import ElizaProcessor from forte.processors.misc import RemoteProcessor @@ -29,6 +32,7 @@ from forte.data.common_entry_utils import create_utterance, get_last_utterance from forte.data.ontology.code_generation_objects import EntryTreeNode from ft.onto.base_ontology import Utterance +from forte.data.ontology.top import Generics TEST_RECORDS_1 = { @@ -56,6 +60,22 @@ def default_configs(cls): return config +class DummyMultiPackReader(MultiPackReader): + def _collect(self, *args: Any, **kwargs: Any) -> Iterator[Any]: + yield 0 + + def _parse_pack(self, collection: Any) -> Iterator[MultiPack]: + multi_pack: MultiPack = MultiPack() + data_pack1 = multi_pack.add_pack(ref_name="pack1") + data_pack2 = multi_pack.add_pack(ref_name="pack2") + data_pack3 = multi_pack.add_pack(ref_name="pack_three") + + data_pack1.pack_name = "1" + data_pack2.pack_name = "2" + data_pack3.pack_name = "Three" + yield multi_pack + + class DummyProcessor(PackProcessor): """ A dummpy Processor to check the expected/output records from the remote @@ -71,7 +91,11 @@ def __init__( self._output_records: Dict[str, Set[str]] = output_records def _process(self, input_pack: DataPack): - pass + entries = list(input_pack.get_entries_of(Generics)) + if len(entries) == 0: + Generics(pack=input_pack) + else: + entry = entries[0] def expected_types_and_attributes(self): return self._expected_records @@ -81,17 +105,18 @@ def record(self, record_meta: Dict[str, Set[str]]): @ddt -class TestRemoteProcessor(unittest.TestCase): +class AdvancedPipelineTest(unittest.TestCase): """ - Test RemoteProcessor. Here we use eliza pipeline as an example, - and all the testcases below are refactored from `./eliza_test.py`. + Test intermediate representation and RemoteProcessor. Here we use eliza + pipeline as an example, and all the testcases below are refactored from + `eliza_test.py`. """ def setUp(self) -> None: dir_path: str = os.path.dirname(os.path.abspath(__file__)) self._pl_config_path: str = os.path.join(dir_path, "eliza_pl_ir.yaml") self._onto_path: str = os.path.join( - dir_path, "../data/ontology/test_specs/base_ontology.json" + dir_path, "data/ontology/test_specs/base_ontology.json" ) @data( @@ -101,7 +126,7 @@ def setUp(self) -> None: ], ["bye", "Goodbye. Thank you for talking to me."], ) - def test_ir(self, input_output_pair): + def test_ir_basic(self, input_output_pair): """ Verify the intermediate representation of pipeline. """ @@ -111,7 +136,7 @@ def test_ir(self, input_output_pair): eliza_pl: Pipeline[DataPack] = Pipeline[DataPack]( ontology_file=self._onto_path, enforce_consistency=True, - do_init_type_check=True + do_init_type_check=True, ) eliza_pl.set_reader(StringReader()) eliza_pl.add(UserSimulator(), config={"user_input": i_str}) @@ -125,23 +150,28 @@ def test_ir(self, input_output_pair): test_pl.init_from_config_path(self._pl_config_path) # Verify pipeline states - self.assertListEqual(*map( - lambda pl: [ - getattr(pl, attr) for attr in ( - "_initialized", - "_enable_profiling", - "_check_type_consistency", - "_do_init_type_check" - ) if hasattr(pl, attr) - ], (eliza_pl, test_pl) - )) + self.assertListEqual( + *map( + lambda pl: [ + getattr(pl, attr) + for attr in ( + "_initialized", + "_enable_profiling", + "_check_type_consistency", + "_do_init_type_check", + ) + if hasattr(pl, attr) + ], + (eliza_pl, test_pl), + ) + ) self.assertDictEqual( eliza_pl.resource.get("onto_specs_dict"), - test_pl.resource.get("onto_specs_dict") + test_pl.resource.get("onto_specs_dict"), ) self._assertEntryTreeEqual( eliza_pl.resource.get("merged_entry_tree").root, - test_pl.resource.get("merged_entry_tree").root + test_pl.resource.get("merged_entry_tree").root, ) # Verify output @@ -151,6 +181,29 @@ def test_ir(self, input_output_pair): self.assertEqual(len([_ for _ in res.get(Utterance)]), 2) self.assertEqual(utterance.text, o_str) + def test_ir_selector(self): + """ + Test the intermediate representation of selector. + """ + # Build original pipeline with RegexNameMatchSelector + pl: Pipeline = Pipeline[MultiPack]() + pl.set_reader(DummyMultiPackReader()) + pl.add( + DummyProcessor(), + selector=RegexNameMatchSelector(select_name="^.*\\d$"), + ) + pl.save(self._pl_config_path) + + # Verify the selector from IR + test_pl: Pipeline = Pipeline[MultiPack]() + test_pl.init_from_config_path(self._pl_config_path) + test_pl.initialize() + for multi_pack in test_pl.process_dataset(): + for _, pack in multi_pack.iter_packs(): + self.assertEqual( + pack.num_generics_entries, int(pack.pack_name in ("1", "2")) + ) + @data( [ "I would like to have a chat bot.", From 0bb40866685244766ad970ae9caa2b240e7ca3fc Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Wed, 11 Aug 2021 13:41:21 -0400 Subject: [PATCH 04/19] Update IR parse --- forte/pipeline.py | 64 ++++++++++++++++++++++++++++------------------- forte/version.py | 1 + 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/forte/pipeline.py b/forte/pipeline.py index b75ca6c4f..5fec9a9ce 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -60,6 +60,7 @@ from forte.processors.base.batch_processor import BaseBatchProcessor from forte.utils import create_class_with_kwargs from forte.utils.utils_processor import record_types_and_attributes_check +from forte.version import FORTE_IR_VERSION if sys.version_info < (3, 7): import importlib_resources as resources @@ -219,9 +220,6 @@ def __init__( # Indicate whether do type checking during pipeline initialization self._do_init_type_check: bool = do_init_type_check - # The version of intermediate representation format - self.FORTE_IR_VERSION: str = "0.0.1" - def enforce_consistency(self, enforce: bool = True): r"""This function determines whether the pipeline will check the content expectations specified in each pipeline component. This @@ -266,16 +264,16 @@ def init_from_config(self, configs: Dict[str, Any]): `states.resource`. """ # Validate IR version - if configs.get("forte_ir_version") != self.FORTE_IR_VERSION: + if configs.get("forte_ir_version") != FORTE_IR_VERSION: raise ProcessorConfigError( f"forte_ir_version={configs.get('forte_ir_version')} not " "supported. Please make sure the format of input IR complies " - f"with forte_ir_version={self.FORTE_IR_VERSION}." + f"with forte_ir_version={FORTE_IR_VERSION}." ) # Add components from IR is_first: bool = True - for component_config in configs["components"]: + for component_config in configs.get("components", []): component = create_class_with_kwargs( class_name=component_config["type"], class_args=component_config.get("kwargs", {}), @@ -290,26 +288,32 @@ def init_from_config(self, configs: Dict[str, Any]): is_first = False else: # Can be processor, caster, or evaluator - selector = create_class_with_kwargs( - class_name=component_config["selector"]["type"], - class_args=component_config["selector"].get("kwargs", {}), - ) + selector_config = component_config.get("selector") self.add( component=component, config=component_config.get("configs", {}), - selector=selector, + selector=selector_config + and create_class_with_kwargs( + class_name=selector_config["type"], + class_args=selector_config.get("kwargs", {}), + ), ) # Set pipeline states and resources - states_config: Dict[str, Dict] = configs["states"] - for attr, val in states_config["attribute"].items(): + states_config: Dict[str, Dict] = configs.get("states", {}) + for attr, val in states_config.get("attribute", {}).items(): setattr(self, attr, val) - self.resource.update( - onto_specs_dict=states_config["resource"]["onto_specs_dict"], - merged_entry_tree=EntryTree().fromdict( - states_config["resource"]["merged_entry_tree"] - ), - ) + resource_config: Dict[str, Dict] = states_config.get("resource", {}) + if "onto_specs_dict" in resource_config: + self.resource.update( + onto_specs_dict=resource_config["onto_specs_dict"] + ) + if "merged_entry_tree" in resource_config: + self.resource.update( + merged_entry_tree=EntryTree().fromdict( + resource_config["merged_entry_tree"] + ), + ) def _dump_to_config(self): r"""Serialize the pipeline to an IR(intermediate representation). @@ -320,7 +324,7 @@ def _dump_to_config(self): dict: A dictionary storing IR. """ configs: Dict = { - "forte_ir_version": self.FORTE_IR_VERSION, + "forte_ir_version": FORTE_IR_VERSION, "components": list(), "states": dict(), } @@ -350,7 +354,7 @@ def _dump_to_config(self): # TODO: This presumes that class attributes' names are # the same as the paramaters' names passed to # selector's constructor, which may not be always true. - "kwargs": selector.__dict__ or None, + "kwargs": selector.__dict__ or {}, }, } ) @@ -368,13 +372,21 @@ def _dump_to_config(self): ) if hasattr(self, attr) }, - "resource": { - "onto_specs_dict": self.resource.get("onto_specs_dict"), - "merged_entry_tree": self.resource.get("merged_entry_tree") - and self.resource.get("merged_entry_tree").todict(), - }, + "resource": dict(), } ) + if self.resource.contains("onto_specs_dict"): + configs["states"]["resource"].update( + {"onto_specs_dict": self.resource.get("onto_specs_dict")} + ) + if self.resource.contains("merged_entry_tree"): + configs["states"]["resource"].update( + { + "merged_entry_tree": self.resource.get( + "merged_entry_tree" + ).todict() + } + ) return configs diff --git a/forte/version.py b/forte/version.py index bb8be5c60..c5d45fbb0 100644 --- a/forte/version.py +++ b/forte/version.py @@ -18,3 +18,4 @@ VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION) +FORTE_IR_VERSION = "0.0.1" From 66e3042b2e88666bcb26bc8a862591715c35bdf3 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Sat, 14 Aug 2021 17:53:47 -0400 Subject: [PATCH 05/19] Update selector serialization --- forte/data/selector.py | 6 +++--- forte/pipeline.py | 42 ++++++++++++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index bd12a569e..5a7a36e84 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -39,7 +39,7 @@ class Selector(Generic[InputPackType, OutputPackType]): def __init__(self, **kwargs): - pass + self._stored_kwargs = kwargs def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError @@ -69,7 +69,7 @@ class NameMatchSelector(SinglePackSelector): """ def __init__(self, select_name: str): - super().__init__() + super().__init__(select_name=select_name) assert select_name is not None self.select_name: str = select_name @@ -90,7 +90,7 @@ class RegexNameMatchSelector(SinglePackSelector): r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex.""" def __init__(self, select_name: str): - super().__init__() + super().__init__(select_name=select_name) assert select_name is not None self.select_name: str = select_name diff --git a/forte/pipeline.py b/forte/pipeline.py index 5fec9a9ce..baf67ba4f 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -323,6 +323,22 @@ def _dump_to_config(self): Returns: dict: A dictionary storing IR. """ + + def get_type(instance) -> str: + r"""Get full module name of an instance""" + return instance.__module__ + "." + type(instance).__name__ + + def test_jsonable(test_dict: Dict, type_name: str = ""): + r"""Check if a dictionary is JSON serializable""" + try: + json.dumps(test_dict) + return test_dict + except (TypeError, OverflowError) as e: + raise ProcessorConfigError( + f"{type_name} is not JSON serializable. Please double " + "check the configuration or arguments" + ) from e + configs: Dict = { "forte_ir_version": FORTE_IR_VERSION, "components": list(), @@ -332,10 +348,11 @@ def _dump_to_config(self): # Serialize pipeline components configs["components"].append( { - "type": ".".join( - [self._reader.__module__, type(self._reader).__name__] + "type": get_type(self._reader), + "configs": test_jsonable( + test_dict=self._reader_config.todict(), + type_name=f"Configuration of {get_type(self._reader)}", ), - "configs": self._reader_config.todict(), } ) for component, config, selector in zip( @@ -343,18 +360,19 @@ def _dump_to_config(self): ): configs["components"].append( { - "type": ".".join( - [component.__module__, type(component).__name__] + "type": get_type(component), + "configs": test_jsonable( + test_dict=config.todict(), + type_name=f"Configuration of {get_type(component)}", ), - "configs": config.todict(), "selector": { - "type": ".".join( - [selector.__module__, type(selector).__name__] + "type": get_type(selector), + "kwargs": test_jsonable( + # pylint: disable=protected-access + test_dict=selector._stored_kwargs, + # pylint: enable=protected-access + type_name=f"kwargs of {get_type(selector)}", ), - # TODO: This presumes that class attributes' names are - # the same as the paramaters' names passed to - # selector's constructor, which may not be always true. - "kwargs": selector.__dict__ or {}, }, } ) From e856c823d2bf9b372788650b5e5c7845f197e8d3 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Sun, 15 Aug 2021 16:48:39 -0400 Subject: [PATCH 06/19] Add informative error message --- forte/pipeline.py | 63 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/forte/pipeline.py b/forte/pipeline.py index baf67ba4f..5d4a1563b 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -58,7 +58,7 @@ from forte.process_manager import ProcessManager, ProcessJobStatus from forte.processors.base import BaseProcessor from forte.processors.base.batch_processor import BaseBatchProcessor -from forte.utils import create_class_with_kwargs +from forte.utils import create_class_with_kwargs, get_full_module_name from forte.utils.utils_processor import record_types_and_attributes_check from forte.version import FORTE_IR_VERSION @@ -324,20 +324,51 @@ def _dump_to_config(self): dict: A dictionary storing IR. """ - def get_type(instance) -> str: - r"""Get full module name of an instance""" - return instance.__module__ + "." + type(instance).__name__ - - def test_jsonable(test_dict: Dict, type_name: str = ""): + def test_jsonable(test_dict: Dict, err_msg: str): r"""Check if a dictionary is JSON serializable""" try: json.dumps(test_dict) return test_dict except (TypeError, OverflowError) as e: - raise ProcessorConfigError( - f"{type_name} is not JSON serializable. Please double " - "check the configuration or arguments" - ) from e + raise ProcessorConfigError(err_msg) from e + + get_err_msg: Dict = { + "reader": lambda reader: ( + "The reader of the pipeline cannot be JSON serialized. This is" + " likely due to some parameters in the configuration of the " + f"reader {get_full_module_name(reader)} cannot be serialized " + "in JSON. To resolve this issue, you can consider implementing" + " a JSON serialization for that parameter type or changing the" + " parameters of this reader. Note that in order for the reader" + " to be serialized in JSON, all the variables defined in both " + "the default_configs and the configuration passed in during " + "pipeline.set_reader() need to be JSON-serializable. You can " + "find in the stack trace the type of the un-serializable " + "parameter." + ), + "component": lambda component: ( + "One component of the pipeline cannot be JSON serialized. This" + " is likely due to some parameters in the configuration of the" + f" component {get_full_module_name(component)} cannot be " + "serialized in JSON. To resolve this issue, you can consider " + "implementing a JSON serialization for that parameter type or " + "changing the parameters of the component. Note that in order " + "for the component to be serialized in JSON, all the variables" + " defined in both the default_configs and the configuration " + "passed in during pipeline.add() need to be JSON-serializable." + " You can find in the stack trace the type of the " + "un-serializable parameter." + ), + "selector": lambda selector: ( + "A selector cannot be JSON serialized. This is likely due to " + "some __init__ parameters for class " + f"{get_full_module_name(selector)} cannot be serialized in " + "JSON. To resolve this issue, you can consider implementing a " + "JSON serialization for that parameter type or changing the " + "signature of the __init__ function. You can find in the stack" + " trace the type of the un-serializable parameter." + ), + } configs: Dict = { "forte_ir_version": FORTE_IR_VERSION, @@ -348,10 +379,10 @@ def test_jsonable(test_dict: Dict, type_name: str = ""): # Serialize pipeline components configs["components"].append( { - "type": get_type(self._reader), + "type": get_full_module_name(self._reader), "configs": test_jsonable( test_dict=self._reader_config.todict(), - type_name=f"Configuration of {get_type(self._reader)}", + err_msg=get_err_msg["reader"](self._reader), ), } ) @@ -360,18 +391,18 @@ def test_jsonable(test_dict: Dict, type_name: str = ""): ): configs["components"].append( { - "type": get_type(component), + "type": get_full_module_name(component), "configs": test_jsonable( test_dict=config.todict(), - type_name=f"Configuration of {get_type(component)}", + err_msg=get_err_msg["component"](component), ), "selector": { - "type": get_type(selector), + "type": get_full_module_name(selector), "kwargs": test_jsonable( # pylint: disable=protected-access test_dict=selector._stored_kwargs, # pylint: enable=protected-access - type_name=f"kwargs of {get_type(selector)}", + err_msg=get_err_msg["selector"](selector), ), }, } From e9902bc19c9a6068a554f6d807353c102fec7301 Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Tue, 7 Sep 2021 15:36:54 -0400 Subject: [PATCH 07/19] init implement of Selector serialization/deserialization --- forte/common/exception.py | 6 +++ forte/data/selector.py | 108 ++++++++++++++++++++++++++++++++++---- forte/pipeline.py | 15 +++--- 3 files changed, 112 insertions(+), 17 deletions(-) diff --git a/forte/common/exception.py b/forte/common/exception.py index 4d70d2e10..b5011a896 100644 --- a/forte/common/exception.py +++ b/forte/common/exception.py @@ -20,6 +20,7 @@ "IncompleteEntryError", "EntryNotFoundError", "ProcessorConfigError", + "SelectorConfigError", "PackDataException", "ProcessFlowException", "ProcessExecutionException", @@ -60,6 +61,11 @@ class ResourceError(ValueError): """ pass +class SelectorConfigError(ValueError): + r"""Raise this error when the there is a problem with the processor + configuration. + """ + pass class PackDataException(Exception): r"""Raise this error when the data in pack is wrong.""" diff --git a/forte/data/selector.py b/forte/data/selector.py index 5a7a36e84..9284c162c 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -15,13 +15,17 @@ This defines some selector interface used as glue to combine DataPack/multiPack processors and Pipeline. """ -from typing import Generic, Iterator, TypeVar +from typing import Generic, Iterator, TypeVar, Optional, Union, Dict, Any import re +import yaml +from forte.common.configuration import Config +from forte.common import SelectorConfigError from forte.data.base_pack import BasePack from forte.data.data_pack import DataPack from forte.data.multi_pack import MultiPack +from forte.utils import get_full_module_name InputPackType = TypeVar("InputPackType", bound=BasePack) OutputPackType = TypeVar("OutputPackType", bound=BasePack) @@ -38,12 +42,68 @@ class Selector(Generic[InputPackType, OutputPackType]): - def __init__(self, **kwargs): - self._stored_kwargs = kwargs + def __init__(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + self.configs = self.make_configs(configs) def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError + @classmethod + def make_configs( + cls, configs: Optional[Union[Config, Dict[str, Any]]] + ) -> Config: + """ + Create the component configuration for this class, by merging the + provided config with the ``default_configs()``. + + The following config conventions are expected: + - The top level key can be a special `config_path`. + - `config_path` should be point to a file system path, which will + be a YAML file containing configurations. + - Other key values in the configs will be considered as parameters. + + Args: + configs: The input config to be merged with the default config. + + Returns: + The merged configuration. + """ + merged_configs: Dict = {} + + if configs is not None: + if isinstance(configs, Config): + configs = configs.todict() + + if "config_path" in configs and not configs["config_path"] is None: + filebased_configs = yaml.safe_load( + open(configs.pop("config_path")) + ) + else: + filebased_configs = {} + + merged_configs.update(filebased_configs) + + merged_configs.update(configs) + + try: + final_configs = Config(merged_configs, cls.default_configs()) + except ValueError as e: + raise SelectorConfigError( + f"Configuration error for the selector " + f"{get_full_module_name(cls)}." + ) from e + + return final_configs + + @classmethod + def default_configs(cls): + r"""Returns a `dict` of configurations of the component with default + values. Used to replace the missing values of input `configs` + during selector construction. + """ + return {} + class DummySelector(Selector[InputPackType, InputPackType]): r"""Do nothing, return the data pack itself, which can be either @@ -59,6 +119,10 @@ class SinglePackSelector(Selector[MultiPack, DataPack]): This is the base class that select a DataPack from MultiPack. """ + def __init__(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + super.__init__(configs) + def select(self, pack: MultiPack) -> Iterator[DataPack]: raise NotImplementedError @@ -68,10 +132,11 @@ class NameMatchSelector(SinglePackSelector): name. """ - def __init__(self, select_name: str): - super().__init__(select_name=select_name) - assert select_name is not None - self.select_name: str = select_name + def __init__(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + super().__init__(configs) + self.select_name = self.configs["select_name"] + assert self.select_name is not None def select(self, m_pack: MultiPack) -> Iterator[DataPack]: matches = 0 @@ -85,14 +150,25 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: f"Pack name {self.select_name}" f" not in the MultiPack" ) + @classmethod + def default_configs(cls): + config = super().default_configs() + config.update( + { + "select_name": None + } + ) + return config + class RegexNameMatchSelector(SinglePackSelector): r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex.""" - def __init__(self, select_name: str): - super().__init__(select_name=select_name) - assert select_name is not None - self.select_name: str = select_name + def __init__(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + super().__init__(configs) + self.select_name = self.configs["select_name"] + assert self.select_name is not None def select(self, m_pack: MultiPack) -> Iterator[DataPack]: if len(m_pack.packs) == 0: @@ -102,6 +178,16 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: if re.match(self.select_name, name): yield pack + @classmethod + def default_configs(cls): + config = super().default_configs() + config.update( + { + "select_name": None + } + ) + return config + class FirstPackSelector(SinglePackSelector): r"""Select the first entry from :class:`MultiPack` and yield it.""" diff --git a/forte/pipeline.py b/forte/pipeline.py index 5d4a1563b..20a64cf11 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -398,12 +398,15 @@ def test_jsonable(test_dict: Dict, err_msg: str): ), "selector": { "type": get_full_module_name(selector), - "kwargs": test_jsonable( - # pylint: disable=protected-access - test_dict=selector._stored_kwargs, - # pylint: enable=protected-access - err_msg=get_err_msg["selector"](selector), - ), + "kwargs": { + "configs": + test_jsonable( + # pylint: disable=protected-access + test_dict=selector.configs.todict(), + # pylint: enable=protected-access + err_msg=get_err_msg["selector"](selector), + ) + } }, } ) From 556900957f0a60025a40bed31cd825fa58ca00bd Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Tue, 7 Sep 2021 22:16:54 -0400 Subject: [PATCH 08/19] passed given unit tests --- forte/data/selector.py | 6 +++--- tests/forte/advanced_pipeline_test.py | 6 +++++- tests/forte/data/selector_test.py | 12 ++++++++++-- tests/forte/pipeline_test.py | 20 +++++++++++++++++--- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 9284c162c..abaa62ac2 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -121,7 +121,7 @@ class SinglePackSelector(Selector[MultiPack, DataPack]): def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): - super.__init__(configs) + super().__init__(configs=configs) def select(self, pack: MultiPack) -> Iterator[DataPack]: raise NotImplementedError @@ -134,7 +134,7 @@ class NameMatchSelector(SinglePackSelector): def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): - super().__init__(configs) + super().__init__(configs=configs) self.select_name = self.configs["select_name"] assert self.select_name is not None @@ -166,7 +166,7 @@ class RegexNameMatchSelector(SinglePackSelector): def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): - super().__init__(configs) + super().__init__(configs=configs) self.select_name = self.configs["select_name"] assert self.select_name is not None diff --git a/tests/forte/advanced_pipeline_test.py b/tests/forte/advanced_pipeline_test.py index 0adf7c350..57d41a34e 100644 --- a/tests/forte/advanced_pipeline_test.py +++ b/tests/forte/advanced_pipeline_test.py @@ -190,7 +190,11 @@ def test_ir_selector(self): pl.set_reader(DummyMultiPackReader()) pl.add( DummyProcessor(), - selector=RegexNameMatchSelector(select_name="^.*\\d$"), + selector=RegexNameMatchSelector( + configs={ + "select_name": "^.*\\d$" + } + ) ) pl.save(self._pl_config_path) diff --git a/tests/forte/data/selector_test.py b/tests/forte/data/selector_test.py index dcaa74934..8ecdf69dc 100644 --- a/tests/forte/data/selector_test.py +++ b/tests/forte/data/selector_test.py @@ -39,14 +39,22 @@ def setUp(self) -> None: data_pack3.pack_name = "Three" def test_name_match_selector(self) -> None: - selector = NameMatchSelector(select_name="pack1") + selector = NameMatchSelector( + configs={ + "select_name": "pack1" + } + ) packs = selector.select(self.multi_pack) doc_ids = ["1"] for doc_id, pack in zip(doc_ids, packs): self.assertEqual(doc_id, pack.pack_name) def test_regex_name_match_selector(self) -> None: - selector = RegexNameMatchSelector(select_name="^.*\\d$") + selector = RegexNameMatchSelector( + configs={ + "select_name": "^.*\\d$" + } + ) packs = selector.select(self.multi_pack) doc_ids = ["1", "2"] for doc_id, pack in zip(doc_ids, packs): diff --git a/tests/forte/pipeline_test.py b/tests/forte/pipeline_test.py index a14b81080..893a04c24 100644 --- a/tests/forte/pipeline_test.py +++ b/tests/forte/pipeline_test.py @@ -634,7 +634,11 @@ def test_process_multi_next(self): nlp.add( DummyRelationExtractor(), config={"batcher": {"batch_size": 5}}, - selector=NameMatchSelector(select_name=pack_name), + selector=NameMatchSelector( + configs={ + "select_name": pack_name + } + ) ) nlp.initialize() @@ -1154,7 +1158,11 @@ def test_reuse_processor(self): nlp.add( dummy, config={"test": "dummy1"}, - selector=NameMatchSelector("default"), + selector=NameMatchSelector( + configs={ + "select_name": "default" + } + ) ) # This will not add the component successfully because the processor is @@ -1163,7 +1171,13 @@ def test_reuse_processor(self): nlp.add(dummy, config={"test": "dummy2"}) # This will add the component, with a different selector - nlp.add(dummy, selector=NameMatchSelector("copy")) + nlp.add(dummy, + selector=NameMatchSelector( + configs={ + "select_name": "copy" + } + ) + ) nlp.initialize() # Check that the two processors have the same name. From 90ecbe0a498447c8f20b16225c87bd4ce371c7ca Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Thu, 9 Sep 2021 03:53:42 -0400 Subject: [PATCH 09/19] Selector inherits Configurable and assures backward compatability --- forte/common/exception.py | 6 -- forte/data/selector.py | 98 +++++++++------------------ tests/forte/advanced_pipeline_test.py | 7 +- tests/forte/data/selector_test.py | 26 +++++++ 4 files changed, 64 insertions(+), 73 deletions(-) diff --git a/forte/common/exception.py b/forte/common/exception.py index 2ecc94d20..fce6c88b8 100644 --- a/forte/common/exception.py +++ b/forte/common/exception.py @@ -20,7 +20,6 @@ "IncompleteEntryError", "EntryNotFoundError", "ProcessorConfigError", - "SelectorConfigError", "PackDataException", "ProcessFlowException", "ProcessExecutionException", @@ -62,11 +61,6 @@ class ResourceError(ValueError): """ pass -class SelectorConfigError(ValueError): - r"""Raise this error when the there is a problem with the processor - configuration. - """ - pass class PackDataException(Exception): r"""Raise this error when the data in pack is wrong.""" diff --git a/forte/data/selector.py b/forte/data/selector.py index abaa62ac2..003328fcc 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -18,14 +18,12 @@ from typing import Generic, Iterator, TypeVar, Optional, Union, Dict, Any import re -import yaml from forte.common.configuration import Config -from forte.common import SelectorConfigError +from forte.common.configurable import Configurable from forte.data.base_pack import BasePack from forte.data.data_pack import DataPack from forte.data.multi_pack import MultiPack -from forte.utils import get_full_module_name InputPackType = TypeVar("InputPackType", bound=BasePack) OutputPackType = TypeVar("OutputPackType", bound=BasePack) @@ -41,69 +39,14 @@ ] -class Selector(Generic[InputPackType, OutputPackType]): - def __init__(self, +class Selector(Generic[InputPackType, OutputPackType], Configurable): + def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): self.configs = self.make_configs(configs) def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError - @classmethod - def make_configs( - cls, configs: Optional[Union[Config, Dict[str, Any]]] - ) -> Config: - """ - Create the component configuration for this class, by merging the - provided config with the ``default_configs()``. - - The following config conventions are expected: - - The top level key can be a special `config_path`. - - `config_path` should be point to a file system path, which will - be a YAML file containing configurations. - - Other key values in the configs will be considered as parameters. - - Args: - configs: The input config to be merged with the default config. - - Returns: - The merged configuration. - """ - merged_configs: Dict = {} - - if configs is not None: - if isinstance(configs, Config): - configs = configs.todict() - - if "config_path" in configs and not configs["config_path"] is None: - filebased_configs = yaml.safe_load( - open(configs.pop("config_path")) - ) - else: - filebased_configs = {} - - merged_configs.update(filebased_configs) - - merged_configs.update(configs) - - try: - final_configs = Config(merged_configs, cls.default_configs()) - except ValueError as e: - raise SelectorConfigError( - f"Configuration error for the selector " - f"{get_full_module_name(cls)}." - ) from e - - return final_configs - - @classmethod - def default_configs(cls): - r"""Returns a `dict` of configurations of the component with default - values. Used to replace the missing values of input `configs` - during selector construction. - """ - return {} - class DummySelector(Selector[InputPackType, InputPackType]): r"""Do nothing, return the data pack itself, which can be either @@ -119,7 +62,7 @@ class SinglePackSelector(Selector[MultiPack, DataPack]): This is the base class that select a DataPack from MultiPack. """ - def __init__(self, + def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): super().__init__(configs=configs) @@ -130,10 +73,27 @@ def select(self, pack: MultiPack) -> Iterator[DataPack]: class NameMatchSelector(SinglePackSelector): r"""Select a :class:`DataPack` from a :class:`MultiPack` with specified name. + Previous: + - selector = NameMatchSelector(select_name="foo") + - selector = NameMatchSelector("foo") + New: + - selector = NameMatchSelector( + configs={ + "select_name": "foo" + } + ) """ - def __init__(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): + def __init__(self, *args, **kwargs): + assert (len(args) == 0) ^ (len(kwargs) == 0) + if args: + configs = {"select_name": args[0]} + else: + assert ("configs" in kwargs) or ("select_name" in kwargs) + if "select_name" in kwargs: + configs = {"select_name": kwargs["select_name"]} + else: + configs = kwargs["configs"] super().__init__(configs=configs) self.select_name = self.configs["select_name"] assert self.select_name is not None @@ -164,8 +124,16 @@ def default_configs(cls): class RegexNameMatchSelector(SinglePackSelector): r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex.""" - def __init__(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): + def __init__(self, *args, **kwargs): + assert (len(args) == 0) ^ (len(kwargs) == 0) + if args: + configs = {"select_name": args[0]} + else: + assert ("configs" in kwargs) or ("select_name" in kwargs) + if "select_name" in kwargs: + configs = {"select_name": kwargs["select_name"]} + else: + configs = kwargs["configs"] super().__init__(configs=configs) self.select_name = self.configs["select_name"] assert self.select_name is not None diff --git a/tests/forte/advanced_pipeline_test.py b/tests/forte/advanced_pipeline_test.py index bbcc4a59d..57d41a34e 100644 --- a/tests/forte/advanced_pipeline_test.py +++ b/tests/forte/advanced_pipeline_test.py @@ -34,6 +34,7 @@ from ft.onto.base_ontology import Utterance from forte.data.ontology.top import Generics + TEST_RECORDS_1 = { "Token": {"1", "2"}, "Document": {"2"}, @@ -54,7 +55,9 @@ def _process(self, input_pack: DataPack): @classmethod def default_configs(cls): - return {"user_input": ""} + config = super().default_configs() + config["user_input"] = "" + return config class DummyMultiPackReader(MultiPackReader): @@ -75,7 +78,7 @@ def _parse_pack(self, collection: Any) -> Iterator[MultiPack]: class DummyProcessor(PackProcessor): """ - A dummy Processor to check the expected/output records from the remote + A dummpy Processor to check the expected/output records from the remote pipeline. """ diff --git a/tests/forte/data/selector_test.py b/tests/forte/data/selector_test.py index 8ecdf69dc..3c9cdbff4 100644 --- a/tests/forte/data/selector_test.py +++ b/tests/forte/data/selector_test.py @@ -49,6 +49,19 @@ def test_name_match_selector(self) -> None: for doc_id, pack in zip(doc_ids, packs): self.assertEqual(doc_id, pack.pack_name) + def test_name_match_selector_backward_compatability(self) -> None: + selector = NameMatchSelector(select_name="pack1") + packs = selector.select(self.multi_pack) + doc_ids = ["1"] + for doc_id, pack in zip(doc_ids, packs): + self.assertEqual(doc_id, pack.pack_name) + + selector = NameMatchSelector("pack1") + packs = selector.select(self.multi_pack) + doc_ids = ["1"] + for doc_id, pack in zip(doc_ids, packs): + self.assertEqual(doc_id, pack.pack_name) + def test_regex_name_match_selector(self) -> None: selector = RegexNameMatchSelector( configs={ @@ -60,6 +73,19 @@ def test_regex_name_match_selector(self) -> None: for doc_id, pack in zip(doc_ids, packs): self.assertEqual(doc_id, pack.pack_name) + def test_regex_name_match_selector_backward_compatability(self) -> None: + selector = RegexNameMatchSelector(select_name="^.*\\d$") + packs = selector.select(self.multi_pack) + doc_ids = ["1", "2"] + for doc_id, pack in zip(doc_ids, packs): + self.assertEqual(doc_id, pack.pack_name) + + selector = RegexNameMatchSelector("^.*\\d$") + packs = selector.select(self.multi_pack) + doc_ids = ["1", "2"] + for doc_id, pack in zip(doc_ids, packs): + self.assertEqual(doc_id, pack.pack_name) + def test_first_pack_selector(self) -> None: selector = FirstPackSelector() packs = list(selector.select(self.multi_pack)) From 6f8a7eb12cef0a3f26f57ba6dc90c17e839bca30 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 9 Sep 2021 14:43:58 -0400 Subject: [PATCH 10/19] Fix pylint err --- forte/pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/forte/pipeline.py b/forte/pipeline.py index 778666314..223fab979 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -373,8 +373,8 @@ def test_jsonable(test_dict: Dict, err_msg: str): configs: Dict = { "forte_ir_version": FORTE_IR_VERSION, - "components": list(), - "states": dict(), + "components": [], + "states": {}, } # Serialize pipeline components @@ -422,7 +422,7 @@ def test_jsonable(test_dict: Dict, err_msg: str): ) if hasattr(self, attr) }, - "resource": dict(), + "resource": {}, } ) if self.resource.contains("onto_specs_dict"): From 7f4df756f8c38a8e4b65b1f94b25c2e914beb3cc Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Thu, 9 Sep 2021 16:40:12 -0400 Subject: [PATCH 11/19] resolved PR comments on Selector 1.0 --- forte/data/selector.py | 46 +++++++++++++-------------- tests/forte/advanced_pipeline_test.py | 2 +- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 003328fcc..2e3744711 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -62,10 +62,6 @@ class SinglePackSelector(Selector[MultiPack, DataPack]): This is the base class that select a DataPack from MultiPack. """ - def __init__(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): - super().__init__(configs=configs) - def select(self, pack: MultiPack) -> Iterator[DataPack]: raise NotImplementedError @@ -73,11 +69,13 @@ def select(self, pack: MultiPack) -> Iterator[DataPack]: class NameMatchSelector(SinglePackSelector): r"""Select a :class:`DataPack` from a :class:`MultiPack` with specified name. - Previous: - - selector = NameMatchSelector(select_name="foo") - - selector = NameMatchSelector("foo") - New: - - selector = NameMatchSelector( + + This implementation takes special care for backward compatability: + Deprecated: + selector = NameMatchSelector(select_name="foo") + selector = NameMatchSelector("foo") + Now: + selector = NameMatchSelector( configs={ "select_name": "foo" } @@ -112,17 +110,23 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: @classmethod def default_configs(cls): - config = super().default_configs() - config.update( - { - "select_name": None - } - ) - return config + return {"select_name": None} class RegexNameMatchSelector(SinglePackSelector): - r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex.""" + r"""Select a :class:`DataPack` from a :class:`MultiPack` using a regex. + + This implementation takes special care for backward compatability: + Deprecated: + selector = RegexNameMatchSelector(select_name="^.*\\d$") + selector = RegexNameMatchSelector("^.*\\d$") + Now: + selector = RegexNameMatchSelector( + configs={ + "select_name": "^.*\\d$" + } + ) + """ def __init__(self, *args, **kwargs): assert (len(args) == 0) ^ (len(kwargs) == 0) @@ -148,13 +152,7 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: @classmethod def default_configs(cls): - config = super().default_configs() - config.update( - { - "select_name": None - } - ) - return config + return {"select_name": None} class FirstPackSelector(SinglePackSelector): diff --git a/tests/forte/advanced_pipeline_test.py b/tests/forte/advanced_pipeline_test.py index 57d41a34e..588aba259 100644 --- a/tests/forte/advanced_pipeline_test.py +++ b/tests/forte/advanced_pipeline_test.py @@ -78,7 +78,7 @@ def _parse_pack(self, collection: Any) -> Iterator[MultiPack]: class DummyProcessor(PackProcessor): """ - A dummpy Processor to check the expected/output records from the remote + A dummy Processor to check the expected/output records from the remote pipeline. """ From e726683d58b27c8dc2ea94bbceae7c2f6bb6d305 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 9 Sep 2021 19:02:17 -0400 Subject: [PATCH 12/19] Fix lint issue --- forte/data/selector.py | 3 +-- forte/pipeline.py | 13 ++++++------- tests/forte/pipeline_test.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 2e3744711..83d63476d 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -40,8 +40,7 @@ class Selector(Generic[InputPackType, OutputPackType], Configurable): - def __init__(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): + def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): self.configs = self.make_configs(configs) def select(self, pack: InputPackType) -> Iterator[OutputPackType]: diff --git a/forte/pipeline.py b/forte/pipeline.py index aa2489f63..1f6cfdfcf 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -400,14 +400,13 @@ def test_jsonable(test_dict: Dict, err_msg: str): "selector": { "type": get_full_module_name(selector), "kwargs": { - "configs": - test_jsonable( - # pylint: disable=protected-access - test_dict=selector.configs.todict(), - # pylint: enable=protected-access - err_msg=get_err_msg["selector"](selector), + "configs": test_jsonable( + # pylint: disable=protected-access + test_dict=selector.configs.todict(), + # pylint: enable=protected-access + err_msg=get_err_msg["selector"](selector), ) - } + }, }, } ) diff --git a/tests/forte/pipeline_test.py b/tests/forte/pipeline_test.py index 014b03361..094a13653 100644 --- a/tests/forte/pipeline_test.py +++ b/tests/forte/pipeline_test.py @@ -1321,7 +1321,7 @@ def test_reuse_processor(self): configs={ "select_name": "copy" } - ) + ) ) nlp.initialize() From 300e1b783da043bcf230715f2ad806be42c4159d Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Tue, 14 Sep 2021 17:40:38 -0400 Subject: [PATCH 13/19] pipeline calls selectors' method --- forte/data/selector.py | 10 ++++++++++ forte/pipeline.py | 14 +++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 83d63476d..d6da58d01 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -42,10 +42,20 @@ class Selector(Generic[InputPackType, OutputPackType], Configurable): def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): self.configs = self.make_configs(configs) + # The flag indicating whether the selector is initialized. + self.__is_initialized: bool = False def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError + def initialize(self): + # Reset selector states + self.__is_initialized = True + + @property + def is_initialized(self) -> bool: + return self.__is_initialized + class DummySelector(Selector[InputPackType, InputPackType]): r"""Do nothing, return the data pack itself, which can be either diff --git a/forte/pipeline.py b/forte/pipeline.py index 1f6cfdfcf..ef6aad183 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -613,8 +613,9 @@ def initialize(self) -> "Pipeline": else: self.reader.enforce_consistency(enforce=False) - # Handle other components. + # Handle other components and their selectors. self.initialize_components() + self.initialize_selectors() self._initialized = True # Create profiler @@ -673,6 +674,17 @@ def initialize_components(self): component.enforce_consistency(enforce=self._check_type_consistency) + def initialize_selectors(self): + """ + This function will reset the states of selectors + """ + for selector in self._selectors: + try: + selector.initialize() + except ValueError as e: + logging.error("Exception occur when initializing selectors") + raise e + def set_reader( self, reader: BaseReader, From 95b602cef180ae1a03a8138a73bd87debce1ab3f Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Tue, 14 Sep 2021 18:21:42 -0400 Subject: [PATCH 14/19] removed is_initialized property from Selector --- forte/data/selector.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index d6da58d01..7ab3c3abb 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -42,19 +42,13 @@ class Selector(Generic[InputPackType, OutputPackType], Configurable): def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): self.configs = self.make_configs(configs) - # The flag indicating whether the selector is initialized. - self.__is_initialized: bool = False def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError def initialize(self): # Reset selector states - self.__is_initialized = True - - @property - def is_initialized(self) -> bool: - return self.__is_initialized + pass class DummySelector(Selector[InputPackType, InputPackType]): From ded754c482bb897a0fb4b918d0680df4e5441415 Mon Sep 17 00:00:00 2001 From: Zhanyuan Zhang Date: Wed, 15 Sep 2021 19:10:33 -0400 Subject: [PATCH 15/19] implemented Selector initialize method --- forte/data/selector.py | 75 +++++++++++++++------------ forte/pipeline.py | 29 ++++++----- tests/forte/advanced_pipeline_test.py | 9 ++-- tests/forte/data/selector_test.py | 10 ++-- tests/forte/pipeline_test.py | 30 +++++------ 5 files changed, 83 insertions(+), 70 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 7ab3c3abb..98af60852 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -40,15 +40,16 @@ class Selector(Generic[InputPackType, OutputPackType], Configurable): - def __init__(self, configs: Optional[Union[Config, Dict[str, Any]]] = None): - self.configs = self.make_configs(configs) + + def __init__(self): + self.configs: Config = Config({}, {}) def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError - def initialize(self): - # Reset selector states - pass + def initialize(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + self.configs = self.make_configs(configs) class DummySelector(Selector[InputPackType, InputPackType]): @@ -78,26 +79,17 @@ class NameMatchSelector(SinglePackSelector): selector = NameMatchSelector(select_name="foo") selector = NameMatchSelector("foo") Now: - selector = NameMatchSelector( + selector = NameMatchSelector() + selector.initialize( configs={ "select_name": "foo" } ) """ - def __init__(self, *args, **kwargs): - assert (len(args) == 0) ^ (len(kwargs) == 0) - if args: - configs = {"select_name": args[0]} - else: - assert ("configs" in kwargs) or ("select_name" in kwargs) - if "select_name" in kwargs: - configs = {"select_name": kwargs["select_name"]} - else: - configs = kwargs["configs"] - super().__init__(configs=configs) - self.select_name = self.configs["select_name"] - assert self.select_name is not None + def __init__(self, select_name: str = None): + super().__init__() + self.select_name = select_name def select(self, m_pack: MultiPack) -> Iterator[DataPack]: matches = 0 @@ -111,6 +103,19 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: f"Pack name {self.select_name}" f" not in the MultiPack" ) + def initialize(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + if self.select_name is not None: + super().initialize( + {"select_name": self.select_name} + ) + else: + super().initialize(configs) + + if self.configs["select_name"] is None: + raise ValueError("select_name shouldn't be None.") + self.select_name = self.configs["select_name"] + @classmethod def default_configs(cls): return {"select_name": None} @@ -124,26 +129,17 @@ class RegexNameMatchSelector(SinglePackSelector): selector = RegexNameMatchSelector(select_name="^.*\\d$") selector = RegexNameMatchSelector("^.*\\d$") Now: - selector = RegexNameMatchSelector( + selector = RegexNameMatchSelector() + selector.initialize( configs={ "select_name": "^.*\\d$" } ) """ - def __init__(self, *args, **kwargs): - assert (len(args) == 0) ^ (len(kwargs) == 0) - if args: - configs = {"select_name": args[0]} - else: - assert ("configs" in kwargs) or ("select_name" in kwargs) - if "select_name" in kwargs: - configs = {"select_name": kwargs["select_name"]} - else: - configs = kwargs["configs"] - super().__init__(configs=configs) - self.select_name = self.configs["select_name"] - assert self.select_name is not None + def __init__(self, select_name: str = None): + super().__init__() + self.select_name = select_name def select(self, m_pack: MultiPack) -> Iterator[DataPack]: if len(m_pack.packs) == 0: @@ -153,6 +149,19 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: if re.match(self.select_name, name): yield pack + def initialize(self, + configs: Optional[Union[Config, Dict[str, Any]]] = None): + if self.select_name is not None: + super().initialize( + {"select_name": self.select_name} + ) + else: + super().initialize(configs) + + if self.configs["select_name"] is None: + raise ValueError("select_name shouldn't be None.") + self.select_name = self.configs["select_name"] + @classmethod def default_configs(cls): return {"select_name": None} diff --git a/forte/pipeline.py b/forte/pipeline.py index ef6aad183..f9d5ed894 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -171,6 +171,7 @@ def __init__( self._components: List[PipelineComponent] = [] self._selectors: List[Selector] = [] self._configs: List[Optional[Config]] = [] + self._selectors_configs: List[Optional[Config]] = [] # Maintain a set of the pipeline components to fast check whether # the component is already there. @@ -210,6 +211,7 @@ def __init__( # Create one copy of the dummy selector to reduce class creation. self.__default_selector: Selector = DummySelector() + self.__default_selector_config: Config = Config({}, {}) # needed for time profiling of pipeline self._enable_profiling: bool = False @@ -298,6 +300,7 @@ def init_from_config(self, configs: Dict[str, Any]): class_name=selector_config["type"], class_args=selector_config.get("kwargs", {}), ), + selector_config=selector_config.get("configs") ) # Set pipeline states and resources @@ -387,8 +390,9 @@ def test_jsonable(test_dict: Dict, err_msg: str): ), } ) - for component, config, selector in zip( - self.components, self.component_configs, self._selectors + for component, config, selector, selector_config in zip( + self.components, self.component_configs, + self._selectors, self._selectors_configs ): configs["components"].append( { @@ -399,14 +403,12 @@ def test_jsonable(test_dict: Dict, err_msg: str): ), "selector": { "type": get_full_module_name(selector), - "kwargs": { - "configs": test_jsonable( - # pylint: disable=protected-access - test_dict=selector.configs.todict(), - # pylint: enable=protected-access - err_msg=get_err_msg["selector"](selector), - ) - }, + "configs": test_jsonable( + # pylint: disable=protected-access + test_dict=selector_config.todict(), + # pylint: enable=protected-access + err_msg=get_err_msg["selector"](selector), + ), }, } ) @@ -678,9 +680,9 @@ def initialize_selectors(self): """ This function will reset the states of selectors """ - for selector in self._selectors: + for selector, config in zip(self._selectors, self._selectors_configs): try: - selector.initialize() + selector.initialize(config) except ValueError as e: logging.error("Exception occur when initializing selectors") raise e @@ -744,6 +746,7 @@ def add( component: PipelineComponent, config: Optional[Union[Config, Dict[str, Any]]] = None, selector: Optional[Selector] = None, + selector_config: Optional[Union[Config, Dict[str, Any]]] = None, ) -> "Pipeline": """ Adds a pipeline component to the pipeline. The pipeline components @@ -813,8 +816,10 @@ def add( if selector is None: self._selectors.append(self.__default_selector) + self._selectors_configs.append(self.__default_selector_config) else: self._selectors.append(selector) + self._selectors_configs.append(selector.make_configs(selector_config)) return self diff --git a/tests/forte/advanced_pipeline_test.py b/tests/forte/advanced_pipeline_test.py index 588aba259..be08e5c40 100644 --- a/tests/forte/advanced_pipeline_test.py +++ b/tests/forte/advanced_pipeline_test.py @@ -190,11 +190,10 @@ def test_ir_selector(self): pl.set_reader(DummyMultiPackReader()) pl.add( DummyProcessor(), - selector=RegexNameMatchSelector( - configs={ - "select_name": "^.*\\d$" - } - ) + selector=RegexNameMatchSelector(), + selector_config={ + "select_name": "^.*\\d$" + }, ) pl.save(self._pl_config_path) diff --git a/tests/forte/data/selector_test.py b/tests/forte/data/selector_test.py index 3c9cdbff4..178be0492 100644 --- a/tests/forte/data/selector_test.py +++ b/tests/forte/data/selector_test.py @@ -39,10 +39,11 @@ def setUp(self) -> None: data_pack3.pack_name = "Three" def test_name_match_selector(self) -> None: - selector = NameMatchSelector( + selector = NameMatchSelector() + selector.initialize( configs={ "select_name": "pack1" - } + }, ) packs = selector.select(self.multi_pack) doc_ids = ["1"] @@ -63,10 +64,11 @@ def test_name_match_selector_backward_compatability(self) -> None: self.assertEqual(doc_id, pack.pack_name) def test_regex_name_match_selector(self) -> None: - selector = RegexNameMatchSelector( + selector = RegexNameMatchSelector() + selector.initialize( configs={ "select_name": "^.*\\d$" - } + }, ) packs = selector.select(self.multi_pack) doc_ids = ["1", "2"] diff --git a/tests/forte/pipeline_test.py b/tests/forte/pipeline_test.py index 094a13653..ab33601d4 100644 --- a/tests/forte/pipeline_test.py +++ b/tests/forte/pipeline_test.py @@ -699,11 +699,10 @@ def test_process_multi_next(self): nlp.add( DummyRelationExtractor(), config={"batcher": {"batch_size": 5}}, - selector=NameMatchSelector( - configs={ - "select_name": pack_name - } - ) + selector=NameMatchSelector(), + selector_config={ + "select_name": pack_name + }, ) nlp.initialize() @@ -1303,11 +1302,10 @@ def test_reuse_processor(self): nlp.add( dummy, config={"test": "dummy1"}, - selector=NameMatchSelector( - configs={ - "select_name": "default" - } - ) + selector=NameMatchSelector(), + selector_config={ + "select_name": "default" + }, ) # This will not add the component successfully because the processor is @@ -1316,12 +1314,12 @@ def test_reuse_processor(self): nlp.add(dummy, config={"test": "dummy2"}) # This will add the component, with a different selector - nlp.add(dummy, - selector=NameMatchSelector( - configs={ - "select_name": "copy" - } - ) + nlp.add( + dummy, + selector=NameMatchSelector(), + selector_config={ + "select_name": "copy" + }, ) nlp.initialize() From 88798f2c092691fd9a601e9e24f9b109d7e5e851 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 16 Sep 2021 18:23:24 -0400 Subject: [PATCH 16/19] Fix black issue --- forte/data/selector.py | 24 +++++++++++------------- forte/pipeline.py | 12 ++++++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 98af60852..e59c47770 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -40,15 +40,15 @@ class Selector(Generic[InputPackType, OutputPackType], Configurable): - def __init__(self): self.configs: Config = Config({}, {}) def select(self, pack: InputPackType) -> Iterator[OutputPackType]: raise NotImplementedError - def initialize(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): + def initialize( + self, configs: Optional[Union[Config, Dict[str, Any]]] = None + ): self.configs = self.make_configs(configs) @@ -103,12 +103,11 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: f"Pack name {self.select_name}" f" not in the MultiPack" ) - def initialize(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): + def initialize( + self, configs: Optional[Union[Config, Dict[str, Any]]] = None + ): if self.select_name is not None: - super().initialize( - {"select_name": self.select_name} - ) + super().initialize({"select_name": self.select_name}) else: super().initialize(configs) @@ -149,12 +148,11 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: if re.match(self.select_name, name): yield pack - def initialize(self, - configs: Optional[Union[Config, Dict[str, Any]]] = None): + def initialize( + self, configs: Optional[Union[Config, Dict[str, Any]]] = None + ): if self.select_name is not None: - super().initialize( - {"select_name": self.select_name} - ) + super().initialize({"select_name": self.select_name}) else: super().initialize(configs) diff --git a/forte/pipeline.py b/forte/pipeline.py index f9d5ed894..8e005ecd5 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -300,7 +300,7 @@ def init_from_config(self, configs: Dict[str, Any]): class_name=selector_config["type"], class_args=selector_config.get("kwargs", {}), ), - selector_config=selector_config.get("configs") + selector_config=selector_config.get("configs"), ) # Set pipeline states and resources @@ -391,8 +391,10 @@ def test_jsonable(test_dict: Dict, err_msg: str): } ) for component, config, selector, selector_config in zip( - self.components, self.component_configs, - self._selectors, self._selectors_configs + self.components, + self.component_configs, + self._selectors, + self._selectors_configs, ): configs["components"].append( { @@ -819,7 +821,9 @@ def add( self._selectors_configs.append(self.__default_selector_config) else: self._selectors.append(selector) - self._selectors_configs.append(selector.make_configs(selector_config)) + self._selectors_configs.append( + selector.make_configs(selector_config) + ) return self From 7268cd449bd0f0b48086cba703c0475f01558e40 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 16 Sep 2021 18:31:42 -0400 Subject: [PATCH 17/19] Ignore too-many-public-methods pylint error --- forte/pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/forte/pipeline.py b/forte/pipeline.py index 8e005ecd5..4ab000bff 100644 --- a/forte/pipeline.py +++ b/forte/pipeline.py @@ -113,6 +113,7 @@ def __next__(self) -> ProcessJob: class Pipeline(Generic[PackType]): + # pylint: disable=too-many-public-methods r"""This controls the main inference flow of the system. A pipeline is consisted of a set of Components (readers and processors). The data flows in the pipeline as data packs, and each component will use or add From c4b0ac01c945238f223acf9e63319bf106afd43d Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 16 Sep 2021 19:02:02 -0400 Subject: [PATCH 18/19] Fix mypy error --- forte/data/selector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index e59c47770..40354e2fb 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -87,7 +87,7 @@ class NameMatchSelector(SinglePackSelector): ) """ - def __init__(self, select_name: str = None): + def __init__(self, select_name: Optional[str] = None): super().__init__() self.select_name = select_name @@ -136,7 +136,7 @@ class RegexNameMatchSelector(SinglePackSelector): ) """ - def __init__(self, select_name: str = None): + def __init__(self, select_name: Optional[str] = None): super().__init__() self.select_name = select_name @@ -145,7 +145,7 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: raise ValueError("Multi-pack is empty") else: for name, pack in m_pack.iter_packs(): - if re.match(self.select_name, name): + if re.match(self.select_name, name): # type: ignore yield pack def initialize( From b3de9990fd9ca3024d12aed72e37d4b153f5468e Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 16 Sep 2021 19:12:14 -0400 Subject: [PATCH 19/19] Fix black issue --- forte/data/selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/forte/data/selector.py b/forte/data/selector.py index 40354e2fb..4e76005a1 100644 --- a/forte/data/selector.py +++ b/forte/data/selector.py @@ -145,7 +145,7 @@ def select(self, m_pack: MultiPack) -> Iterator[DataPack]: raise ValueError("Multi-pack is empty") else: for name, pack in m_pack.iter_packs(): - if re.match(self.select_name, name): # type: ignore + if re.match(self.select_name, name): # type: ignore yield pack def initialize(