Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Adding a metadata field to the basic classifier (#5104)
Browse files Browse the repository at this point in the history
* Adding metadata parameter to BasicClassifier

* Fix

* Updating the changelog

* reformatting

* updating parameter type

* fixing import

Co-authored-by: Dirk Groeneveld <[email protected]>
  • Loading branch information
2 people authored and epwalsh committed Apr 22, 2021
1 parent 01e5083 commit bd953d9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed module `allennlp.data.tokenizers.token` to `allennlp.data.tokenizers.token_class` to avoid
[this bug](https:/allenai/allennlp/issues/4819).
- `transformers` dependency updated to version 4.0.1.
- `BasicClassifier`'s forward method now takes a metadata field.

### Fixed

Expand Down
6 changes: 5 additions & 1 deletion allennlp/models/basic_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.data.fields import MetadataField
from allennlp.models.model import Model
from allennlp.modules import FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.nn import InitializerApplicator, util
Expand Down Expand Up @@ -89,7 +90,10 @@ def __init__(
initializer(self)

def forward( # type: ignore
self, tokens: TextFieldTensors, label: torch.IntTensor = None
self,
tokens: TextFieldTensors,
label: torch.IntTensor = None,
metadata: MetadataField = None,
) -> Dict[str, torch.Tensor]:

"""
Expand Down

0 comments on commit bd953d9

Please sign in to comment.