This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Distributed metrics #123

Merged
merged 6 commits on Aug 26, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Fixed evaluation of metrics when using a distributed setting.

## [v1.1.0rc4](https://github.com/allenai/allennlp-models/releases/tag/v1.1.0rc4) - 2020-08-21

### Added
22 changes: 20 additions & 2 deletions allennlp_models/coref/metrics/mention_recall.py
@@ -2,6 +2,9 @@
from overrides import overrides

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed

from allennlp.training.metrics.metric import Metric

@@ -18,14 +21,29 @@ def __call__(
        batched_top_spans: torch.Tensor,
        batched_metadata: List[Dict[str, Any]],
    ):
        num_gold_mentions = 0
        num_recalled_mentions = 0
        for top_spans, metadata in zip(batched_top_spans.tolist(), batched_metadata):

            gold_mentions: Set[Tuple[int, int]] = {
                mention for cluster in metadata["clusters"] for mention in cluster
            }
            predicted_spans: Set[Tuple[int, int]] = {(span[0], span[1]) for span in top_spans}
            self._num_gold_mentions += len(gold_mentions)
            self._num_recalled_mentions += len(gold_mentions & predicted_spans)

            num_gold_mentions += len(gold_mentions)
            num_recalled_mentions += len(gold_mentions & predicted_spans)

        if is_distributed():
            device = batched_top_spans.device
            _num_gold_mentions = torch.tensor(num_gold_mentions).to(device)
            _num_recalled_mentions = torch.tensor(num_recalled_mentions).to(device)
            dist.all_reduce(_num_gold_mentions, op=dist.ReduceOp.SUM)
            dist.all_reduce(_num_recalled_mentions, op=dist.ReduceOp.SUM)
            num_gold_mentions = _num_gold_mentions.item()
            num_recalled_mentions = _num_recalled_mentions.item()

        self._num_gold_mentions += num_gold_mentions
        self._num_recalled_mentions += num_recalled_mentions

    @overrides
    def get_metric(self, reset: bool = False) -> float:
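For readers new to torch.distributed, here is a small standalone sketch (not part of this PR) of the accumulate-then-all_reduce pattern used above. It assumes two CPU worker processes over the gloo backend; the worker function, address, and port are illustrative only.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    # Pretend worker 0 recalled 3 of 5 gold mentions and worker 1 recalled 1 of 3.
    gold = torch.tensor(5.0 if rank == 0 else 3.0)
    recalled = torch.tensor(3.0 if rank == 0 else 1.0)
    dist.all_reduce(gold, op=dist.ReduceOp.SUM)
    dist.all_reduce(recalled, op=dist.ReduceOp.SUM)
    # Every worker now holds the global counts: 4 recalled out of 8 gold mentions.
    print(rank, (recalled / gold).item())  # prints 0.5 on both ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

After the reduction every worker holds the same global counts, so get_metric() returns the same value on all ranks.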
25 changes: 24 additions & 1 deletion allennlp_models/rc/metrics/drop_em_and_f1.py
@@ -1,5 +1,10 @@
from typing import Tuple, List, Union

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed

from allennlp.training.metrics.metric import Metric
from overrides import overrides

@@ -43,9 +48,27 @@ def __call__(self, prediction: Union[str, List], ground_truths: List): # type:
        exact_match, f1_score = metric_max_over_ground_truths(
            drop_em_and_f1, prediction, ground_truth_answer_strings
        )
        count = 1

        if is_distributed():
            if dist.get_backend() == "nccl":
                device = torch.cuda.current_device()
            else:
                device = torch.device("cpu")
            # Converting bool to float here, since we want to count the number of exact matches.
            _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
            _f1_score = torch.tensor(f1_score).to(device)
            _count = torch.tensor(count).to(device)
            dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
            dist.all_reduce(_f1_score, op=dist.ReduceOp.SUM)
            dist.all_reduce(_count, op=dist.ReduceOp.SUM)
            exact_match = _exact_match.item()
            f1_score = _f1_score.item()
            count = _count.item()

        self._total_em += exact_match
        self._total_f1 += f1_score
        self._count += 1
        self._count += count

    @overrides
    def get_metric(self, reset: bool = False) -> Tuple[float, float]:
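Note the change from self._count += 1 to self._count += count: after all_reduce the EM and F1 totals are global sums, so the denominator has to be global as well. Illustrative arithmetic only (not from the PR):

# Two workers, each scoring one prediction that is an exact match.
world_size = 2
global_em_total = 1.0 * world_size      # after all_reduce(SUM): 2.0 on every worker

local_count = 1                         # old behaviour: denominator stays per-worker
global_count = 1 * world_size           # new behaviour: count is all-reduced too

print(global_em_total / local_count)    # 2.0 -- an impossible exact-match rate
print(global_em_total / global_count)   # 1.0 -- the correct rate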
24 changes: 23 additions & 1 deletion allennlp_models/rc/metrics/squad_em_and_f1.py
@@ -1,5 +1,9 @@
from typing import Tuple

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.training.metrics.metric import Metric
from overrides import overrides

@@ -33,9 +37,27 @@ def __call__(self, best_span_string, answer_strings):
        f1_score = squad.metric_max_over_ground_truths(
            squad.f1_score, best_span_string, answer_strings
        )

        count = 1
        if is_distributed():
            if dist.get_backend() == "nccl":
                device = torch.cuda.current_device()
            else:
                device = torch.device("cpu")
            # Converting bool to float here, since we want to count the number of exact matches.
            _exact_match = torch.tensor(exact_match, dtype=torch.float).to(device)
            _f1_score = torch.tensor(f1_score).to(device)
            _count = torch.tensor(count).to(device)
            dist.all_reduce(_exact_match, op=dist.ReduceOp.SUM)
            dist.all_reduce(_f1_score, op=dist.ReduceOp.SUM)
            dist.all_reduce(_count, op=dist.ReduceOp.SUM)
            exact_match = _exact_match.item()
            f1_score = _f1_score.item()
            count = _count.item()

        self._total_em += exact_match
        self._total_f1 += f1_score
        self._count += 1
        self._count += count

    @overrides
    def get_metric(self, reset: bool = False) -> Tuple[float, float]:
39 changes: 39 additions & 0 deletions tests/coref/metrics/mention_recall_test.py
@@ -0,0 +1,39 @@
import torch

from allennlp.common.testing import (
    AllenNlpTestCase,
    global_distributed_metric,
    run_distributed_test,
)

from allennlp_models.coref.metrics.mention_recall import MentionRecall


class MentionRecallTest(AllenNlpTestCase):
    def test_mention_recall(self):
        metric = MentionRecall()

        batched_top_spans = torch.tensor([[[2, 4], [1, 3]], [[5, 6], [7, 8]]])
        batched_metadata = [{"clusters": [[(2, 4), (3, 5)]]}, {"clusters": [[(5, 6), (7, 8)]]}]

        metric(batched_top_spans, batched_metadata)
        recall = metric.get_metric()
        assert recall == 0.75

    def test_distributed_mention_recall(self):
        batched_top_spans = [torch.tensor([[[2, 4], [1, 3]]]), torch.tensor([[[5, 6], [7, 8]]])]
        batched_metadata = [[{"clusters": [[(2, 4), (3, 5)]]}], [{"clusters": [[(5, 6), (7, 8)]]}]]

        metric_kwargs = {
            "batched_top_spans": batched_top_spans,
            "batched_metadata": batched_metadata,
        }
        desired_values = 0.75
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            MentionRecall(),
            metric_kwargs,
            desired_values,
            exact=True,
        )
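A hand-check of desired_values, on the assumption that [-1, -1] requests two CPU workers and that the i-th entry of each kwargs list is handed to worker i by global_distributed_metric:

# Worker 0: gold mentions {(2, 4), (3, 5)}, predicted spans {(2, 4), (1, 3)} -> 1 of 2 recalled.
# Worker 1: gold mentions {(5, 6), (7, 8)}, predicted spans {(5, 6), (7, 8)} -> 2 of 2 recalled.
assert (1 + 2) / (2 + 2) == 0.75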
37 changes: 37 additions & 0 deletions tests/rc/metrics/drop_em_and_f1_test.py
@@ -0,0 +1,37 @@
from allennlp.common.testing import (
    AllenNlpTestCase,
    global_distributed_metric,
    run_distributed_test,
)

from allennlp_models.rc.metrics import DropEmAndF1


class DropEmAndF1Test(AllenNlpTestCase):
    def test_drop_em_and_f1(self):
        metric = DropEmAndF1()

        metric(
            "this is the best span", [{"spans": ["this is a good span", "something irrelevant"]}]
        )
        exact_match, f1_score = metric.get_metric()
        assert exact_match == 0.0
        assert f1_score == 0.38

    def test_distributed_drop_em_and_f1(self):
        prediction = ["this is the best span", "this is another span"]
        ground_truths = [
            [{"spans": ["this is a good span", "something irrelevant"]}],
            [{"spans": ["this is another span"]}],
        ]

        metric_kwargs = {"prediction": prediction, "ground_truths": ground_truths}
        desired_values = (1 / 2, 1.38 / 2)
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            DropEmAndF1(),
            metric_kwargs,
            desired_values,
            exact=True,
        )
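Similarly, the expected values here can be checked by hand, reusing the 0.38 F1 from the single-process test above:

# Worker 0: EM = 0.0, F1 = 0.38 (same prediction and ground truths as test_drop_em_and_f1).
# Worker 1: EM = 1.0, F1 = 1.0  (its prediction exactly matches the only ground-truth span).
# Summed totals (1.0, 1.38) over a global count of 2 give (1 / 2, 1.38 / 2).
assert (0.0 + 1.0) / 2 == 1 / 2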
36 changes: 36 additions & 0 deletions tests/rc/metrics/squad_em_and_f1_test.py
@@ -0,0 +1,36 @@
from allennlp.common.testing import (
    AllenNlpTestCase,
    global_distributed_metric,
    run_distributed_test,
)

from allennlp_models.rc.metrics import SquadEmAndF1


class SquadEmAndF1Test(AllenNlpTestCase):
    def test_squad_em_and_f1(self):
        metric = SquadEmAndF1()

        metric("this is the best span", ["this is a good span", "something irrelevant"])

        exact_match, f1_score = metric.get_metric()
        assert exact_match == 0.0
        assert f1_score == 0.75

    def test_distributed_squad_em_and_f1(self):
        best_span_string = ["this is the best span", "this is another span"]
        answer_strings = [
            ["this is a good span", "something irrelevant"],
            ["this is another span", "this one is less perfect"],
        ]

        metric_kwargs = {"best_span_string": best_span_string, "answer_strings": answer_strings}
        desired_values = (1 / 2, 1.75 / 2)
        run_distributed_test(
            [-1, -1],
            global_distributed_metric,
            SquadEmAndF1(),
            metric_kwargs,
            desired_values,
            exact=True,
        )