diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e1811c5f73d044..c347f6ed9501d8 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -672,6 +672,7 @@ name for name in dir(dummy_speech_objects) if not name.startswith("_") ] else: + _import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor") _import_structure["models.mctct"].append("MCTCTFeatureExtractor") _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor") @@ -857,7 +858,6 @@ "ASTForSequenceClassification", "ASTModel", "ASTPreTrainedModel", - "ASTFeatureExtractor", ] ) _import_structure["models.albert"].extend( @@ -3760,6 +3760,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_speech_objects import * else: + from .models.audio_spectrogram_transformer import ASTFeatureExtractor from .models.mctct import MCTCTFeatureExtractor from .models.speech_to_text import Speech2TextFeatureExtractor @@ -3917,7 +3918,6 @@ # PyTorch model imports from .models.audio_spectrogram_transformer import ( AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, - ASTFeatureExtractor, ASTForSequenceClassification, ASTModel, ASTPreTrainedModel, diff --git a/src/transformers/models/audio_spectrogram_transformer/__init__.py b/src/transformers/models/audio_spectrogram_transformer/__init__.py index 255dd1dfe9ed8e..008b43aea90f79 100644 --- a/src/transformers/models/audio_spectrogram_transformer/__init__.py +++ b/src/transformers/models/audio_spectrogram_transformer/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available _import_structure = { @@ -33,7 +33,6 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"] _import_structure["modeling_audio_spectrogram_transformer"] = [ "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", "ASTForSequenceClassification", @@ -41,6 +40,14 @@ "ASTPreTrainedModel", ] +try: + if not is_speech_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"] + if TYPE_CHECKING: from .configuration_audio_spectrogram_transformer import ( AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -53,7 +60,6 @@ except OptionalDependencyNotAvailable: pass else: - from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor from .modeling_audio_spectrogram_transformer import ( AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ASTForSequenceClassification, @@ -61,6 +67,14 @@ ASTPreTrainedModel, ) + try: + if not is_speech_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor + else: import sys diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ff66ca874df47d..b622056e44566d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -353,13 +353,6 @@ def load_tf_weights_in_albert(*args, **kwargs): AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None -class ASTFeatureExtractor(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class ASTForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_speech_objects.py b/src/transformers/utils/dummy_speech_objects.py index ae5589292a4cf9..d1929dd2853b1b 100644 --- a/src/transformers/utils/dummy_speech_objects.py +++ b/src/transformers/utils/dummy_speech_objects.py @@ -3,6 +3,13 @@ from ..utils import DummyObject, requires_backends +class ASTFeatureExtractor(metaclass=DummyObject): + _backends = ["speech"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["speech"]) + + class MCTCTFeatureExtractor(metaclass=DummyObject): _backends = ["speech"]