Skip to content

Commit

Permalink
Rename to ForAudioClassification
Browse files Browse the repository at this point in the history
  • Loading branch information
Niels Rogge authored and Niels Rogge committed Nov 21, 2022
1 parent 4ead058 commit 474ba04
Show file tree
Hide file tree
Showing 10 changed files with 15 additions and 38 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/audio-spectrogram-transformer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ The original code can be found [here](https:/YuanGongND/ast).
[[autodoc]] ASTModel
- forward

## ASTForSequenceClassification
## ASTForAudioClassification

[[autodoc]] ASTForSequenceClassification
[[autodoc]] ASTForAudioClassification
- forward
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"ASTModel",
"ASTPreTrainedModel",
"ASTForSequenceClassification",
"ASTForAudioClassification",
]
)
_import_structure["models.auto"].extend(
Expand Down Expand Up @@ -3958,7 +3958,7 @@
)
from .models.audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ASTForSequenceClassification,
ASTForAudioClassification,
ASTModel,
ASTPreTrainedModel,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
else:
_import_structure["modeling_audio_spectrogram_transformer"] = [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"ASTForSequenceClassification",
"ASTForAudioClassification",
"ASTModel",
"ASTPreTrainedModel",
]
Expand All @@ -62,7 +62,7 @@
else:
from .modeling_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ASTForSequenceClassification,
ASTForAudioClassification,
ASTModel,
ASTPreTrainedModel,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from datasets import load_dataset

from huggingface_hub import hf_hub_download
from transformers import ASTConfig, ASTFeatureExtractor, ASTForSequenceClassification
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
from transformers.utils import logging


Expand Down Expand Up @@ -193,7 +193,7 @@ def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_fo
new_state_dict = convert_state_dict(state_dict, config)

# load 🤗 model
model = ASTForSequenceClassification(config)
model = ASTForAudioClassification(config)
model.eval()

model.load_state_dict(new_state_dict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def forward(self, hidden_state):
""",
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTForSequenceClassification(ASTPreTrainedModel):
class ASTForAudioClassification(ASTPreTrainedModel):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)

Expand Down
21 changes: 0 additions & 21 deletions src/transformers/models/audio_spectrogram_transformer/test.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("audio-spectrogram-transformer", "ASTForSequenceClassification"),
("audio-spectrogram-transformer", "ASTForAudioClassification"),
("data2vec-audio", "Data2VecAudioForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
("sew", "SEWForSequenceClassification"),
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/utils/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def docstring_decorator(fn):
expected_loss=expected_loss,
)

if "SequenceClassification" in model_class and modality == "audio":
if ["SequenceClassification" in model_class or "AudioClassification" in model_class] and modality == "audio":
code_sample = sample_docstrings["AudioClassification"]
elif "SequenceClassification" in model_class:
code_sample = sample_docstrings["SequenceClassification"]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def load_tf_weights_in_albert(*args, **kwargs):
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class ASTForSequenceClassification(metaclass=DummyObject):
class ASTForAudioClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch
from torch import nn

from transformers import ASTForSequenceClassification, ASTModel
from transformers import ASTForAudioClassification, ASTModel
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
Expand Down Expand Up @@ -148,7 +148,7 @@ class ASTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
ASTModel,
ASTForSequenceClassification,
ASTForAudioClassification,
)
if is_torch_available()
else ()
Expand Down Expand Up @@ -227,9 +227,7 @@ def default_feature_extractor(self):
def test_inference_audio_classification(self):

feature_extractor = self.default_feature_extractor
model = ASTForSequenceClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(
torch_device
)
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(torch_device)

feature_extractor = self.default_feature_extractor
audio, sampling_rate = prepare_audio()
Expand Down

0 comments on commit 474ba04

Please sign in to comment.