This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* distributed metrics: SquadEmAndF1 * distributed metrics: DropEmAndF1 * update changelog * distributed metrics: MentionRecall * removing comment * changing float to int
- Loading branch information
Showing
7 changed files
with
183 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
|
||
from allennlp.common.testing import ( | ||
AllenNlpTestCase, | ||
global_distributed_metric, | ||
run_distributed_test, | ||
) | ||
|
||
from allennlp_models.coref.metrics.mention_recall import MentionRecall | ||
|
||
|
||
class MentionRecallTest(AllenNlpTestCase): | ||
def test_mention_recall(self): | ||
metric = MentionRecall() | ||
|
||
batched_top_spans = torch.tensor([[[2, 4], [1, 3]], [[5, 6], [7, 8]]]) | ||
batched_metadata = [{"clusters": [[(2, 4), (3, 5)]]}, {"clusters": [[(5, 6), (7, 8)]]}] | ||
|
||
metric(batched_top_spans, batched_metadata) | ||
recall = metric.get_metric() | ||
assert recall == 0.75 | ||
|
||
def test_distributed_mention_recall(self): | ||
batched_top_spans = [torch.tensor([[[2, 4], [1, 3]]]), torch.tensor([[[5, 6], [7, 8]]])] | ||
batched_metadata = [[{"clusters": [[(2, 4), (3, 5)]]}], [{"clusters": [[(5, 6), (7, 8)]]}]] | ||
|
||
metric_kwargs = { | ||
"batched_top_spans": batched_top_spans, | ||
"batched_metadata": batched_metadata, | ||
} | ||
desired_values = 0.75 | ||
run_distributed_test( | ||
[-1, -1], | ||
global_distributed_metric, | ||
MentionRecall(), | ||
metric_kwargs, | ||
desired_values, | ||
exact=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from allennlp.common.testing import ( | ||
AllenNlpTestCase, | ||
global_distributed_metric, | ||
run_distributed_test, | ||
) | ||
|
||
from allennlp_models.rc.metrics import DropEmAndF1 | ||
|
||
|
||
class DropEmAndF1Test(AllenNlpTestCase): | ||
def test_drop_em_and_f1(self): | ||
metric = DropEmAndF1() | ||
|
||
metric( | ||
"this is the best span", [{"spans": ["this is a good span", "something irrelevant"]}] | ||
) | ||
exact_match, f1_score = metric.get_metric() | ||
assert exact_match == 0.0 | ||
assert f1_score == 0.38 | ||
|
||
def test_distributed_drop_em_and_f1(self): | ||
prediction = ["this is the best span", "this is another span"] | ||
ground_truths = [ | ||
[{"spans": ["this is a good span", "something irrelevant"]}], | ||
[{"spans": ["this is another span"]}], | ||
] | ||
|
||
metric_kwargs = {"prediction": prediction, "ground_truths": ground_truths} | ||
desired_values = (1 / 2, 1.38 / 2) | ||
run_distributed_test( | ||
[-1, -1], | ||
global_distributed_metric, | ||
DropEmAndF1(), | ||
metric_kwargs, | ||
desired_values, | ||
exact=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from allennlp.common.testing import ( | ||
AllenNlpTestCase, | ||
global_distributed_metric, | ||
run_distributed_test, | ||
) | ||
|
||
from allennlp_models.rc.metrics import SquadEmAndF1 | ||
|
||
|
||
class SquadEmAndF1Test(AllenNlpTestCase): | ||
def test_squad_em_and_f1(self): | ||
metric = SquadEmAndF1() | ||
|
||
metric("this is the best span", ["this is a good span", "something irrelevant"]) | ||
|
||
exact_match, f1_score = metric.get_metric() | ||
assert exact_match == 0.0 | ||
assert f1_score == 0.75 | ||
|
||
def test_distributed_squad_em_and_f1(self): | ||
best_span_string = ["this is the best span", "this is another span"] | ||
answer_strings = [ | ||
["this is a good span", "something irrelevant"], | ||
["this is another span", "this one is less perfect"], | ||
] | ||
|
||
metric_kwargs = {"best_span_string": best_span_string, "answer_strings": answer_strings} | ||
desired_values = (1 / 2, 1.75 / 2) | ||
run_distributed_test( | ||
[-1, -1], | ||
global_distributed_metric, | ||
SquadEmAndF1(), | ||
metric_kwargs, | ||
desired_values, | ||
exact=True, | ||
) |