From 4cb739489d8b562fda4d288b18cc6736c3a0a37e Mon Sep 17 00:00:00 2001 From: Lucas Gautheron Date: Mon, 7 Jun 2021 13:23:22 +0200 Subject: [PATCH] importation code factorization --- ChildProject/annotations.py | 26 ++++---------------------- ChildProject/converters.py | 10 +++++----- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/ChildProject/annotations.py b/ChildProject/annotations.py index 3872b65b6..2486b2762 100644 --- a/ChildProject/annotations.py +++ b/ChildProject/annotations.py @@ -31,7 +31,7 @@ class AnnotationManager: IndexColumn(name = 'range_onset', description = 'covered range start time in milliseconds, measured since `time_seek`', regex = r"([0-9]+)", required = True), IndexColumn(name = 'range_offset', description = 'covered range end time in milliseconds, measured since `time_seek`', regex = r"([0-9]+)", required = True), IndexColumn(name = 'raw_filename', description = 'annotation input filename location, relative to `annotations//raw`', filename = True, required = True), - IndexColumn(name = 'format', description = 'input annotation format', choices = ['TextGrid', 'eaf', 'vtc_rttm', 'vcm_rttm', 'alice', 'its', 'cha', 'NA'], required = False), + IndexColumn(name = 'format', description = 'input annotation format', choices = converters.keys() + ['NA'], required = False), IndexColumn(name = 'filter', description = 'source file to filter in (for rttm and alice only)', required = False), IndexColumn(name = 'annotation_filename', description = 'output formatted annotation location, relative to `annotations//converted (automatic column, don\'t specify)', filename = True, required = False, generated = True), IndexColumn(name = 'imported_at', description = 'importation date (automatic column, don\'t specify)', datetime = "%Y-%m-%d %H:%M:%S", required = False, generated = True), @@ -195,27 +195,9 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann try: if callable(import_function): df = import_function(path) - elif annotation_format == 'TextGrid': - from .converters import TextGridConverter - df = TextGridConverter.convert(path) - elif annotation_format == 'eaf': - from .converters import EafConverter - df = EafConverter.convert(path) - elif annotation_format == 'vtc_rttm': - from .converters import VtcConverter - df = VtcConverter.convert(path, source_file = filter) - elif annotation_format == 'vcm_rttm': - from .converters import VcmConverter - df = VcmConverter.convert(path, source_file = filter) - elif annotation_format == 'its': - from .converters import ItsConverter - df = ItsConverter.convert(path, recording_num = filter) - elif annotation_format == 'alice': - from .converters import AliceConverter - df = AliceConverter.convert(path, source_file = filter) - elif annotation_format == 'cha': - from .converters import ChatConverter - df = ChatConverter.convert(path) + elif annotation_format in converters: + converter = converters[annotation_format] + df = converter.convert(path, filter) else: raise ValueError("file format '{}' unknown for '{}'".format(annotation_format, path)) except: diff --git a/ChildProject/converters.py b/ChildProject/converters.py index 2830f9737..c6ad30991 100644 --- a/ChildProject/converters.py +++ b/ChildProject/converters.py @@ -128,7 +128,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame: class AliceConverter(AnnotationConverter): FORMAT = 'alice' - + @staticmethod def convert(filename: str, source_file: str = '') -> pd.DataFrame: df = pd.read_csv( @@ -287,10 +287,10 @@ def extract_from_regex(pattern, subject): return df class TextGridConverter(AnnotationConverter): - FORMAT = 'textgrid' + FORMAT = 'TextGrid' @staticmethod - def convert(filename: str) -> pd.DataFrame: + def convert(filename: str, filter = None) -> pd.DataFrame: import pympi textgrid = pympi.Praat.TextGrid(filename) @@ -330,7 +330,7 @@ class EafConverter(AnnotationConverter): FORMAT = 'eaf' @staticmethod - def convert(filename: str) -> pd.DataFrame: + def convert(filename: str, filter = None) -> pd.DataFrame: import pympi eaf = pympi.Elan.Eaf(filename) @@ -463,7 +463,7 @@ def role_to_addressee(role): return ChatConverter.ADDRESSEE_TABLE[ChatConverter.SPEAKER_ROLE_TO_TYPE[role]] @staticmethod - def convert(filename: str) -> pd.DataFrame: + def convert(filename: str, filter = None) -> pd.DataFrame: import pylangacq