This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Test fixes #282

Merged · 4 commits · Jun 17, 2021
15 changes: 13 additions & 2 deletions tests/rc/models/bidaf_test.py
@@ -20,8 +20,15 @@ class BidirectionalAttentionFlowTest(ModelTestCase):
     def setup_method(self):
         super().setup_method()
         self.set_up_model(
-            FIXTURES_ROOT / "rc" / "bidaf" / "experiment.json", FIXTURES_ROOT / "rc" / "squad.json"
+            FIXTURES_ROOT / "rc" / "bidaf" / "experiment.json",
+            FIXTURES_ROOT / "rc" / "squad.json",
+            seed=27,
         )
+        torch.use_deterministic_algorithms(True)
+
+    def teardown_method(self):
+        super().teardown_method()
+        torch.use_deterministic_algorithms(False)
 
     @flaky
     def test_forward_pass_runs_correctly(self):
@@ -53,7 +60,11 @@ def test_forward_pass_runs_correctly(self):
     # `masked_softmax`...) have made this _very_ flaky...
     @flaky(max_runs=5)
     def test_model_can_train_save_and_load(self):
-        self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-4)
+        self.ensure_model_can_train_save_and_load(
+            self.param_file,
+            tolerance=1e-4,
+            gradients_to_ignore={"_span_start_predictor._module.bias"},
+        )
 
     @flaky
     def test_batch_predictions_are_consistent(self):
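Note for readers skimming the diff: the two hunks above pin the fixture seed (seed=27), force PyTorch's deterministic kernels for the duration of each test, and restore the global flag in teardown_method. A minimal standalone sketch of that pattern (a hypothetical test class, not part of this PR; it assumes pytest-style setup/teardown hooks and PyTorch >= 1.8, where torch.use_deterministic_algorithms is available):

import random

import numpy
import torch


class DeterministicModelTest:
    def setup_method(self):
        # Pin every RNG that can influence weight init or batch order.
        # (The bidaf test above routes its seed through set_up_model instead.)
        random.seed(27)
        numpy.random.seed(27)
        torch.manual_seed(27)
        # Ask PyTorch to error out rather than silently pick a
        # nondeterministic kernel.
        torch.use_deterministic_algorithms(True)

    def teardown_method(self):
        # The flag is process-global state, so reset it here; otherwise it
        # leaks into every other test collected in the same pytest run.
        torch.use_deterministic_algorithms(False)

The same reasoning explains why both files in this PR add a teardown_method: use_deterministic_algorithms toggles an interpreter-wide switch, not a per-test one.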
14 changes: 11 additions & 3 deletions tests/rc/models/dialog_qa_test.py
@@ -1,9 +1,9 @@
 from allennlp.common.testing import ModelTestCase
 from allennlp.data import Batch
 
-from tests import FIXTURES_ROOT
+import torch
 
 import allennlp_models.rc
+from tests import FIXTURES_ROOT
 
 
 class DialogQATest(ModelTestCase):
@@ -12,9 +12,15 @@ def setup_method(self):
         self.set_up_model(
             FIXTURES_ROOT / "rc" / "dialog_qa" / "experiment.json",
             FIXTURES_ROOT / "rc" / "dialog_qa" / "quac_sample.json",
+            seed=42,
         )
         self.batch = Batch(self.instances)
         self.batch.index_instances(self.vocab)
+        torch.use_deterministic_algorithms(True)
+
+    def teardown_method(self):
+        super().teardown_method()
+        torch.use_deterministic_algorithms(False)
 
     def test_forward_pass_runs_correctly(self):
         training_tensors = self.batch.as_tensor_dict()
@@ -23,7 +29,9 @@ def test_forward_pass_runs_correctly(self):
         assert "followup" in output_dict and "yesno" in output_dict
 
     def test_model_can_train_save_and_load(self):
-        self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-4)
+        self.ensure_model_can_train_save_and_load(
+            self.param_file, tolerance=1e-4, gradients_to_ignore={"_matrix_attention._bias"}
+        )
 
     def test_batch_predictions_are_consistent(self):
         self.ensure_batch_predictions_are_consistent()
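One note on the gradients_to_ignore argument used in both files: as I understand AllenNLP's ModelTestCase, ensure_model_can_train_save_and_load also asserts that every trainable parameter receives a gradient during the short fixture training run, and the named parameters are exempted because the tiny fixture data never pushes gradient through them. A rough sketch of that check (a hypothetical helper, not AllenNLP's actual implementation):

from typing import Optional, Set

import torch


def assert_parameters_received_gradients(
    model: torch.nn.Module, gradients_to_ignore: Optional[Set[str]] = None
) -> None:
    # After loss.backward(), every trainable parameter should carry a .grad,
    # except explicit exemptions such as "_matrix_attention._bias" above.
    ignore = gradients_to_ignore or set()
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad or name in ignore:
            continue
        assert parameter.grad is not None, f"{name} never received a gradient"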