From 6d724f28724a45f7841b0c0971a0b26983394540 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 25 Aug 2020 14:13:17 -0400
Subject: [PATCH 1/6] distributed metrics: SquadEmAndF1

---
 allennlp_models/rc/metrics/squad_em_and_f1.py | 17 +++++++++
 tests/rc/metrics/squad_em_and_f1_test.py      | 37 +++++++++++++++++++
 2 files changed, 54 insertions(+)
 create mode 100644 tests/rc/metrics/squad_em_and_f1_test.py

diff --git a/allennlp_models/rc/metrics/squad_em_and_f1.py b/allennlp_models/rc/metrics/squad_em_and_f1.py
index fdbb14c36..6de7cbc03 100644
--- a/allennlp_models/rc/metrics/squad_em_and_f1.py
+++ b/allennlp_models/rc/metrics/squad_em_and_f1.py
@@ -1,5 +1,9 @@
 from typing import Tuple
 
+import torch
+import torch.distributed as dist
+
+from allennlp.common.util import is_distributed
 from allennlp.training.metrics.metric import Metric
 from overrides import overrides
 
@@ -33,6 +37,19 @@ def __call__(self, best_span_string, answer_strings):
         f1_score = squad.metric_max_over_ground_truths(
             squad.f1_score, best_span_string, answer_strings
         )
+        if is_distributed():
+            if dist.get_backend() == "nccl":
+                device = torch.cuda.current_device()
+            else:
+                device = torch.device("cpu")
+            # Converting bool to float here, since we want to count the number of exact matches.
+            _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
+            _f1_score = torch.tensor(f1_score).to(device)
+            dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
+            dist.all_reduce(_f1_score, op=dist.ReduceOp.SUM)
+            exact_match = _exact_match.item()
+            f1_score = _f1_score.item()
+
         self._total_em += exact_match
         self._total_f1 += f1_score
         self._count += 1
diff --git a/tests/rc/metrics/squad_em_and_f1_test.py b/tests/rc/metrics/squad_em_and_f1_test.py
new file mode 100644
index 000000000..4c70a586e
--- /dev/null
+++ b/tests/rc/metrics/squad_em_and_f1_test.py
@@ -0,0 +1,37 @@
+from allennlp.common.testing import (
+    AllenNlpTestCase,
+    global_distributed_metric,
+    run_distributed_test,
+)
+
+from allennlp_models.rc.metrics import SquadEmAndF1
+
+
+class SquadEmAndF1Test(AllenNlpTestCase):
+    def test_squad_em_and_f1(self):
+        metric = SquadEmAndF1()
+
+        metric("this is the best span", ["this is a good span", "something irrelevant"])
+        # metric("this is another span", ["this is another span", "this one is less perfect"])
+
+        exact_match, f1_score = metric.get_metric()
+        assert exact_match == 0.0
+        assert f1_score == 0.75
+
+    def test_distributed_squad_em_and_f1(self):
+        best_span_string = ["this is the best span", "this is another span"]
+        answer_strings = [
+            ["this is a good span", "something irrelevant"],
+            ["this is another span", "this one is less perfect"],
+        ]
+
+        metric_kwargs = {"best_span_string": best_span_string, "answer_strings": answer_strings}
+        desired_values = (1, 1.75)
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_metric,
+            SquadEmAndF1(),
+            metric_kwargs,
+            desired_values,
+            exact=True,
+        )
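The pattern this series applies to each metric is visible in the hunk above: wrap each locally accumulated statistic in a tensor on an appropriate device (CUDA for the nccl backend, CPU otherwise), SUM-reduce it across workers with dist.all_reduce, and read the global value back with .item(). Below is a minimal standalone sketch of that pattern with two CPU workers on the gloo backend; the worker function, address, and port are illustrative assumptions for local testing, not part of the patch.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # env:// rendezvous on localhost; both values are arbitrary for this demo.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds a local F1 total; after the SUM all-reduce, every
    # rank sees the global total (0.75 + 1.0 = 1.75).
    local_f1 = torch.tensor(0.75 if rank == 0 else 1.0)
    dist.all_reduce(local_f1, op=dist.ReduceOp.SUM)
    assert local_f1.item() == 1.75

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Note that after this first patch, exact_match and f1_score are globally summed while self._count still advances by one per call on each worker; the next patch reduces the count as well, so that get_metric() divides by the global number of examples (the test expectations change accordingly).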
From 4751f47704f4e6a16008f0dc6d883a581edf8aea Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 25 Aug 2020 16:48:32 -0400
Subject: [PATCH 2/6] distributed metrics: DropEmAndF1

---
 allennlp_models/rc/metrics/drop_em_and_f1.py  | 26 ++++++++++++-
 allennlp_models/rc/metrics/squad_em_and_f1.py |  7 +++-
 tests/rc/metrics/drop_em_and_f1_test.py       | 37 +++++++++++++++++++
 tests/rc/metrics/squad_em_and_f1_test.py      |  3 +-
 4 files changed, 69 insertions(+), 4 deletions(-)
 create mode 100644 tests/rc/metrics/drop_em_and_f1_test.py

diff --git a/allennlp_models/rc/metrics/drop_em_and_f1.py b/allennlp_models/rc/metrics/drop_em_and_f1.py
index 6f89d308f..edecb5937 100644
--- a/allennlp_models/rc/metrics/drop_em_and_f1.py
+++ b/allennlp_models/rc/metrics/drop_em_and_f1.py
@@ -1,5 +1,10 @@
 from typing import Tuple, List, Union
 
+import torch
+import torch.distributed as dist
+
+from allennlp.common.util import is_distributed
+
 from allennlp.training.metrics.metric import Metric
 from overrides import overrides
 
@@ -43,9 +48,28 @@ def __call__(self, prediction: Union[str, List], ground_truths: List):  # type:
         exact_match, f1_score = metric_max_over_ground_truths(
             drop_em_and_f1, prediction, ground_truth_answer_strings
         )
+        count = 1
+
+        if is_distributed():
+            print(exact_match, f1_score)
+            if dist.get_backend() == "nccl":
+                device = torch.cuda.current_device()
+            else:
+                device = torch.device("cpu")
+            # Converting bool to float here, since we want to count the number of exact matches.
+            _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
+            _f1_score = torch.tensor(f1_score).to(device)
+            _count = torch.tensor(count).to(device)
+            dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
+            dist.all_reduce(_f1_score, op=dist.ReduceOp.SUM)
+            dist.all_reduce(_count, op=dist.ReduceOp.SUM)
+            exact_match = _exact_match.item()
+            f1_score = _f1_score.item()
+            count = _count.item()
+
         self._total_em += exact_match
         self._total_f1 += f1_score
-        self._count += 1
+        self._count += count
 
     @overrides
     def get_metric(self, reset: bool = False) -> Tuple[float, float]:
diff --git a/allennlp_models/rc/metrics/squad_em_and_f1.py b/allennlp_models/rc/metrics/squad_em_and_f1.py
index 6de7cbc03..4491d770f 100644
--- a/allennlp_models/rc/metrics/squad_em_and_f1.py
+++ b/allennlp_models/rc/metrics/squad_em_and_f1.py
@@ -37,6 +37,8 @@ def __call__(self, best_span_string, answer_strings):
         f1_score = squad.metric_max_over_ground_truths(
             squad.f1_score, best_span_string, answer_strings
         )
+
+        count = 1
         if is_distributed():
             if dist.get_backend() == "nccl":
                 device = torch.cuda.current_device()
@@ -45,14 +47,17 @@ def __call__(self, best_span_string, answer_strings):
             # Converting bool to float here, since we want to count the number of exact matches.
             _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
             _f1_score = torch.tensor(f1_score).to(device)
+            _count = torch.tensor(count).to(device)
             dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
             dist.all_reduce(_f1_score, op=dist.ReduceOp.SUM)
+            dist.all_reduce(_count, op=dist.ReduceOp.SUM)
             exact_match = _exact_match.item()
             f1_score = _f1_score.item()
+            count = _count.item()
 
         self._total_em += exact_match
         self._total_f1 += f1_score
-        self._count += 1
+        self._count += count
 
     @overrides
     def get_metric(self, reset: bool = False) -> Tuple[float, float]:
diff --git a/tests/rc/metrics/drop_em_and_f1_test.py b/tests/rc/metrics/drop_em_and_f1_test.py
new file mode 100644
index 000000000..29b995b5c
--- /dev/null
+++ b/tests/rc/metrics/drop_em_and_f1_test.py
@@ -0,0 +1,37 @@
+from allennlp.common.testing import (
+    AllenNlpTestCase,
+    global_distributed_metric,
+    run_distributed_test,
+)
+
+from allennlp_models.rc.metrics import DropEmAndF1
+
+
+class DropEmAndF1Test(AllenNlpTestCase):
+    def test_squad_em_and_f1(self):
+        metric = DropEmAndF1()
+
+        metric(
+            "this is the best span", [{"spans": ["this is a good span", "something irrelevant"]}]
+        )
+        exact_match, f1_score = metric.get_metric()
+        assert exact_match == 0.0
+        assert f1_score == 0.38
+
+    def test_distributed_squad_em_and_f1(self):
+        prediction = ["this is the best span", "this is another span"]
+        ground_truths = [
+            [{"spans": ["this is a good span", "something irrelevant"]}],
+            [{"spans": ["this is another span"]}],
+        ]
+
+        metric_kwargs = {"prediction": prediction, "ground_truths": ground_truths}
+        desired_values = (1 / 2, 1.38 / 2)
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_metric,
+            DropEmAndF1(),
+            metric_kwargs,
+            desired_values,
+            exact=True,
+        )
diff --git a/tests/rc/metrics/squad_em_and_f1_test.py b/tests/rc/metrics/squad_em_and_f1_test.py
index 4c70a586e..e77773364 100644
--- a/tests/rc/metrics/squad_em_and_f1_test.py
+++ b/tests/rc/metrics/squad_em_and_f1_test.py
@@ -12,7 +12,6 @@ def test_squad_em_and_f1(self):
         metric = SquadEmAndF1()
 
         metric("this is the best span", ["this is a good span", "something irrelevant"])
-        # metric("this is another span", ["this is another span", "this one is less perfect"])
 
         exact_match, f1_score = metric.get_metric()
         assert exact_match == 0.0
@@ -26,7 +25,7 @@ def test_distributed_squad_em_and_f1(self):
         ]
 
         metric_kwargs = {"best_span_string": best_span_string, "answer_strings": answer_strings}
-        desired_values = (1, 1.75)
+        desired_values = (1 / 2, 1.75 / 2)
         run_distributed_test(
             [-1, -1],
             global_distributed_metric,
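get_metric() divides the accumulated totals by self._count, so once exact_match and f1_score are summed across workers the count has to be summed as well; that is why desired_values in the distributed tests become averages over the global example count, e.g. (1 / 2, 1.75 / 2). A worked check of that arithmetic for the two-worker squad test above, with per-example scores taken from the single-process test:

# Rank 0 scores "this is the best span": EM 0.0, F1 0.75
# Rank 1 scores "this is another span":  EM 1.0, F1 1.0 (an exact match)
total_em = 0.0 + 1.0   # summed across workers by dist.all_reduce
total_f1 = 0.75 + 1.0
count = 1 + 1          # also summed across workers as of this patch
assert (total_em / count, total_f1 / count) == (1 / 2, 1.75 / 2)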
From 27813d74daf6b6330d729312ec5017b074313a06 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 25 Aug 2020 16:50:26 -0400
Subject: [PATCH 3/6] update changelog

---
 CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7bf8004df..8cf6ce5c6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## Unreleased
 
+### Fixed
+
+- Fixed evaluation of metrics in the distributed setting.
+
 ## [v1.1.0rc4](https://github.com/allenai/allennlp-models/releases/tag/v1.1.0rc4) - 2020-08-21
 
 ### Added
From f8a0f5ea8919365dfb13e3d671ac551bc82a8109 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 25 Aug 2020 17:59:03 -0400
Subject: [PATCH 4/6] distributed metrics: MentionRecall

---
 .../coref/metrics/mention_recall.py           | 23 ++++++++++-
 allennlp_models/rc/metrics/drop_em_and_f1.py  |  1 -
 tests/coref/metrics/mention_recall_test.py    | 39 +++++++++++++++++++
 tests/rc/metrics/drop_em_and_f1_test.py       |  4 +-
 4 files changed, 62 insertions(+), 5 deletions(-)
 create mode 100644 tests/coref/metrics/mention_recall_test.py

diff --git a/allennlp_models/coref/metrics/mention_recall.py b/allennlp_models/coref/metrics/mention_recall.py
index 01b72c8fc..e52c8dec6 100644
--- a/allennlp_models/coref/metrics/mention_recall.py
+++ b/allennlp_models/coref/metrics/mention_recall.py
@@ -2,6 +2,9 @@
 
 from overrides import overrides
 import torch
+import torch.distributed as dist
+
+from allennlp.common.util import is_distributed
 
 from allennlp.training.metrics.metric import Metric
 
@@ -18,14 +21,30 @@ def __call__(
         batched_top_spans: torch.Tensor,
         batched_metadata: List[Dict[str, Any]],
     ):
+        num_gold_mentions = 0
+        num_recalled_mentions = 0
         for top_spans, metadata in zip(batched_top_spans.tolist(), batched_metadata):
             gold_mentions: Set[Tuple[int, int]] = {
                 mention for cluster in metadata["clusters"] for mention in cluster
             }
             predicted_spans: Set[Tuple[int, int]] = {(span[0], span[1]) for span in top_spans}
-            self._num_gold_mentions += len(gold_mentions)
-            self._num_recalled_mentions += len(gold_mentions & predicted_spans)
+
+            num_gold_mentions += len(gold_mentions)
+            num_recalled_mentions += len(gold_mentions & predicted_spans)
+
+        if is_distributed():
+            device = batched_top_spans.device
+            # Converting bool to float here, since we want to count the number of exact matches.
+            _num_gold_mentions = torch.tensor(num_gold_mentions).to(device)
+            _num_recalled_mentions = torch.tensor(num_recalled_mentions).to(device)
+            dist.all_reduce(_num_gold_mentions, op=dist.ReduceOp.SUM)
+            dist.all_reduce(_num_recalled_mentions, op=dist.ReduceOp.SUM)
+            num_gold_mentions = _num_gold_mentions.item()
+            num_recalled_mentions = _num_recalled_mentions.item()
+
+        self._num_gold_mentions += num_gold_mentions
+        self._num_recalled_mentions += num_recalled_mentions
 
     @overrides
     def get_metric(self, reset: bool = False) -> float:
diff --git a/allennlp_models/rc/metrics/drop_em_and_f1.py b/allennlp_models/rc/metrics/drop_em_and_f1.py
index edecb5937..b210c7d32 100644
--- a/allennlp_models/rc/metrics/drop_em_and_f1.py
+++ b/allennlp_models/rc/metrics/drop_em_and_f1.py
@@ -51,7 +51,6 @@ def __call__(self, prediction: Union[str, List], ground_truths: List):  # type:
         count = 1
 
         if is_distributed():
-            print(exact_match, f1_score)
             if dist.get_backend() == "nccl":
                 device = torch.cuda.current_device()
             else:
diff --git a/tests/coref/metrics/mention_recall_test.py b/tests/coref/metrics/mention_recall_test.py
new file mode 100644
index 000000000..a8907ff4e
--- /dev/null
+++ b/tests/coref/metrics/mention_recall_test.py
@@ -0,0 +1,39 @@
+import torch
+
+from allennlp.common.testing import (
+    AllenNlpTestCase,
+    global_distributed_metric,
+    run_distributed_test,
+)
+
+from allennlp_models.coref.metrics.mention_recall import MentionRecall
+
+
+class MentionRecallTest(AllenNlpTestCase):
+    def test_mention_recall(self):
+        metric = MentionRecall()
+
+        batched_top_spans = torch.tensor([[[2, 4], [1, 3]], [[5, 6], [7, 8]]])
+        batched_metadata = [{"clusters": [[(2, 4), (3, 5)]]}, {"clusters": [[(5, 6), (7, 8)]]}]
+
+        metric(batched_top_spans, batched_metadata)
+        recall = metric.get_metric()
+        assert recall == 0.75
+
+    def test_distributed_mention_recall(self):
+        batched_top_spans = [torch.tensor([[[2, 4], [1, 3]]]), torch.tensor([[[5, 6], [7, 8]]])]
+        batched_metadata = [[{"clusters": [[(2, 4), (3, 5)]]}], [{"clusters": [[(5, 6), (7, 8)]]}]]
+
+        metric_kwargs = {
+            "batched_top_spans": batched_top_spans,
+            "batched_metadata": batched_metadata,
+        }
+        desired_values = 0.75
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_metric,
+            MentionRecall(),
+            metric_kwargs,
+            desired_values,
+            exact=True,
+        )
diff --git a/tests/rc/metrics/drop_em_and_f1_test.py b/tests/rc/metrics/drop_em_and_f1_test.py
index 29b995b5c..807415096 100644
--- a/tests/rc/metrics/drop_em_and_f1_test.py
+++ b/tests/rc/metrics/drop_em_and_f1_test.py
@@ -8,7 +8,7 @@
 
 
 class DropEmAndF1Test(AllenNlpTestCase):
-    def test_squad_em_and_f1(self):
+    def test_drop_em_and_f1(self):
         metric = DropEmAndF1()
 
         metric(
@@ -18,7 +18,7 @@ def test_squad_em_and_f1(self):
         assert exact_match == 0.0
         assert f1_score == 0.38
 
-    def test_distributed_squad_em_and_f1(self):
+    def test_distributed_drop_em_and_f1(self):
         prediction = ["this is the best span", "this is another span"]
         ground_truths = [
             [{"spans": ["this is a good span", "something irrelevant"]}],
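MentionRecall follows the same reduction pattern with two differences visible above: the target device is taken directly from batched_top_spans.device rather than from a backend check, and the reduced quantities are raw counts of gold and recalled mentions, so the ratio computed in get_metric() stays exact regardless of how examples are sharded. A worked check of the two-worker mention-recall test above, with per-rank counts read off the test fixtures:

# Rank 0: gold {(2, 4), (3, 5)}, predicted {(2, 4), (1, 3)} -> 1 of 2 recalled
# Rank 1: gold {(5, 6), (7, 8)}, predicted {(5, 6), (7, 8)} -> 2 of 2 recalled
recalled = 1 + 2   # summed across workers by dist.all_reduce
gold = 2 + 2
assert recalled / gold == 0.75

Summing counts rather than averaging per-worker recalls is what keeps the metric correct when workers see different numbers of gold mentions.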
From 36b82c67e712553c43d4c52b8a1d2bb6bda233a4 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 25 Aug 2020 18:24:07 -0400
Subject: [PATCH 5/6] removing comment

---
 allennlp_models/coref/metrics/mention_recall.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/allennlp_models/coref/metrics/mention_recall.py b/allennlp_models/coref/metrics/mention_recall.py
index e52c8dec6..2e10413bc 100644
--- a/allennlp_models/coref/metrics/mention_recall.py
+++ b/allennlp_models/coref/metrics/mention_recall.py
@@ -35,7 +35,6 @@ def __call__(
 
         if is_distributed():
             device = batched_top_spans.device
-            # Converting bool to float here, since we want to count the number of exact matches.
             _num_gold_mentions = torch.tensor(num_gold_mentions).to(device)
             _num_recalled_mentions = torch.tensor(num_recalled_mentions).to(device)
             dist.all_reduce(_num_gold_mentions, op=dist.ReduceOp.SUM)
From 6dccb36534b92419ca7774019200537a6d1f94e6 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Wed, 26 Aug 2020 11:57:04 -0400
Subject: [PATCH 6/6] changing float to int

---
 allennlp_models/rc/metrics/drop_em_and_f1.py  | 4 ++--
 allennlp_models/rc/metrics/squad_em_and_f1.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/allennlp_models/rc/metrics/drop_em_and_f1.py b/allennlp_models/rc/metrics/drop_em_and_f1.py
index b210c7d32..862ecf9c3 100644
--- a/allennlp_models/rc/metrics/drop_em_and_f1.py
+++ b/allennlp_models/rc/metrics/drop_em_and_f1.py
@@ -55,8 +55,8 @@ def __call__(self, prediction: Union[str, List], ground_truths: List):  # type:
                 device = torch.cuda.current_device()
             else:
                 device = torch.device("cpu")
-            # Converting bool to float here, since we want to count the number of exact matches.
-            _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
+            # Converting bool to int here, since we want to count the number of exact matches.
+            _exact_match = torch.tensor(exact_match, dtype=torch.int).to(device)
             _f1_score = torch.tensor(f1_score).to(device)
             _count = torch.tensor(count).to(device)
             dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
diff --git a/allennlp_models/rc/metrics/squad_em_and_f1.py b/allennlp_models/rc/metrics/squad_em_and_f1.py
index 4491d770f..631ac1dd3 100644
--- a/allennlp_models/rc/metrics/squad_em_and_f1.py
+++ b/allennlp_models/rc/metrics/squad_em_and_f1.py
@@ -44,8 +44,8 @@ def __call__(self, best_span_string, answer_strings):
                 device = torch.cuda.current_device()
             else:
                 device = torch.device("cpu")
-            # Converting bool to float here, since we want to count the number of exact matches.
-            _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
+            # Converting bool to int here, since we want to count the number of exact matches.
+            _exact_match = torch.tensor(exact_match, dtype=torch.int).to(device)
             _f1_score = torch.tensor(f1_score).to(device)
             _count = torch.tensor(count).to(device)
             dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
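The final dtype change is about intent: exact_match arrives as a Python bool, and counting matches is integral, so the reduced tensor is created with dtype=torch.int, while f1_score stays in torch's default floating dtype. A quick standalone check of the conversions (no process group required; torch.int is an alias for torch.int32):

import torch

_exact_match = torch.tensor(True, dtype=torch.int)  # bool -> one exact match
assert _exact_match.dtype == torch.int32
assert _exact_match.item() == 1

_f1_score = torch.tensor(0.75)  # floats keep the default float32 dtype
assert _f1_score.dtype == torch.float32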