Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Take the number of runs in the test for distributed metrics #5127

Merged
merged 2 commits
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

## [v2.3.0](https://github.com/allenai/allennlp/releases/tag/v2.3.0) - 2021-04-14
- The test for distributed metrics now takes a parameter specifying how often you want to run it.


## [v2.3.0](https://github.com/allenai/allennlp/releases/tag/v2.3.0) - 2021-04-14

### Added

- Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.
Expand Down
4 changes: 3 additions & 1 deletion allennlp/common/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ def global_distributed_metric(
metric_kwargs: Dict[str, List[Any]],
desired_values: Dict[str, Any],
exact: Union[bool, Tuple[float, float]] = True,
number_of_runs: int = 1,
):
kwargs = {}

# Use the arguments meant for the process with rank `global_rank`.
for argname in metric_kwargs:
kwargs[argname] = metric_kwargs[argname][global_rank]

metric(**kwargs)
for _ in range(number_of_runs):
metric(**kwargs)

metrics = metric.get_metric(False)
if not isinstance(metrics, Dict) and not isinstance(desired_values, Dict):
Expand Down
26 changes: 2 additions & 24 deletions tests/training/metrics/categorical_accuracy_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any, Dict, List, Tuple, Union

import pytest
import torch
from torch.testing import assert_allclose
Expand Down Expand Up @@ -196,30 +194,10 @@ def test_multiple_distributed_runs(self):
desired_accuracy = 0.5
run_distributed_test(
[-1, -1],
multiple_runs,
global_distributed_metric,
CategoricalAccuracy(),
metric_kwargs,
desired_accuracy,
exact=True,
number_of_runs=200,
)


def multiple_runs(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: CategoricalAccuracy,
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
):
    """Per-process distributed worker: update ``metric`` 200 times with this
    rank's inputs, then assert the aggregated result equals ``desired_values``.

    ``metric_kwargs`` maps each argument name to a list indexed by rank, so
    every process selects only its own slice before updating the metric.
    """
    # Pick out the argument values intended for this process's rank.
    rank_kwargs = {name: values[global_rank] for name, values in metric_kwargs.items()}

    # Feed the identical batch repeatedly so the metric accumulates 200 updates.
    for _ in range(200):
        metric(**rank_kwargs)

    assert desired_values == metric.get_metric()