This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added AdversarialBiasMitigator tests and model (#281)
* added AdversarialBiasMitigator tests * added training config for adversarial bias mitigation * fixed error in training config * Added model. * updated model list in README * improved model description * updated snli modelcards with NLI bias metrics * fixed unitary results Co-authored-by: Arjun Subramonian <[email protected]> Co-authored-by: Arjun Subramonian <[email protected]>
- Loading branch information
1 parent
8d2d84f
commit bdf82a1
Showing
10 changed files
with
356 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
...modelcards/pair-classification-adversarial-binary-gender-bias-mitigated-roberta-snli.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
{ | ||
"id": "pair-classification-adversarial-binary-gender-bias-mitigated-roberta-snli", | ||
"registered_model_name": "adversarial_bias_mitigator", | ||
"registered_predictor_name": "textual_entailment", | ||
"display_name": "Adversarial Binary Gender Bias-Mitigated RoBERTa SNLI", | ||
"task_id": "textual_entailment", | ||
"model_details": { | ||
"description": "This `Model` implements a basic text classifier and feedforward regression adversary with an adversarial bias mitigator wrapper. The text is embedded into a text field using a RoBERTa-large model. The resulting sequence is pooled using a cls_pooler `Seq2VecEncoder` and then passed to a linear classification layer, which projects into the label space. Subsequently, a `FeedForwardRegressionAdversary` attempts to recover the coefficient of the static text embedding in the binary gender bias subspace. While the adversary's parameter updates are computed normally, the predictor's parameters are updated such that the predictor will not aid the adversary and will make it more difficult for the adversary to recover protected variables.", | ||
"short_description": "RoBERTa finetuned on SNLI with adversarial binary gender bias mitigation.", | ||
"developed_by": "Zhang at al", | ||
"contributed_by": "Arjun Subramonian", | ||
"date": "2021-06-17", | ||
"version": "1", | ||
"model_type": "RoBERTa", | ||
"paper": { | ||
"citation": "\n@article{Zhang2018MitigatingUB,\ntitle={Mitigating Unwanted Biases with Adversarial Learning},\nauthor={B. H. Zhang and B. Lemoine and Margaret Mitchell},\njournal={Proceedings of the 2018 AAAI/ACM Conference on AI, Ethics, and Society},\nyear={2018}\n}", | ||
"title": "Mitigating Unwanted Biases with Adversarial Learning", | ||
"url": "https://api.semanticscholar.org/CorpusID:9424845" | ||
}, | ||
"license": null, | ||
"contact": "[email protected]" | ||
}, | ||
"intended_use": { | ||
"primary_uses": null, | ||
"primary_users": null, | ||
"out_of_scope_use_cases": null | ||
}, | ||
"factors": { | ||
"relevant_factors": null, | ||
"evaluation_factors": null | ||
}, | ||
"metrics": { | ||
"model_performance_measures": "Accuracy, Net Neutral, Fraction Neutral, Threshold:tau", | ||
"decision_thresholds": null, | ||
"variation_approaches": null | ||
}, | ||
"evaluation_data": { | ||
"dataset": { | ||
"name": "On Measuring and Mitigating Biased Gender-Occupation Inferences SNLI Dataset", | ||
"url": "https:/sunipa/On-Measuring-and-Mitigating-Biased-Inferences-of-Word-Embeddings", | ||
"processed_url": "https://storage.googleapis.com/allennlp-public-models/binary-gender-bias-mitigated-snli-dataset.jsonl" | ||
}, | ||
"motivation": null, | ||
"preprocessing": null | ||
}, | ||
"training_data": { | ||
"dataset": { | ||
"name": "Stanford Natural Language Inference (SNLI) train set", | ||
"url": "https://nlp.stanford.edu/projects/snli/", | ||
"processed_url": "https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_train.jsonl" | ||
}, | ||
"motivation": null, | ||
"preprocessing": null | ||
}, | ||
"quantitative_analyses": { | ||
"unitary_results": "Net Neutral: 0.613096454815352, Fraction Neutral: 0.6704967487937075, Threshold:0.5: 0.6637061892722586, Threshold:0.7: 0.49490217463150243", | ||
"intersectional_results": null | ||
}, | ||
"model_caveats_and_recommendations": { | ||
"caveats_and_recommendations": null | ||
}, | ||
"model_ethical_considerations": { | ||
"ethical_considerations": "Adversarial binary gender bias mitigation has been applied to this model. Nonetheless, the model will contain residual biases and bias mitigation does not guarantee entirely bias-free inferences." | ||
}, | ||
"model_usage": { | ||
"archive_file": "adversarial-binary-gender-bias-mitigated-snli-roberta.2021-06-17.tar.gz", | ||
"training_config": "pair_classification/adversarial_binary_gender_bias_mitigated_snli_roberta.jsonnet", | ||
"install_instructions": "pip install allennlp allennlp-models" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
test_fixtures/pair_classification/bias_mitigation/adversarial_experiment.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
{ | ||
"dataset_reader": { | ||
"type": "snli", | ||
"tokenizer": { | ||
"type": "pretrained_transformer", | ||
"model_name": "epwalsh/bert-xsmall-dummy", | ||
"add_special_tokens": false | ||
}, | ||
"token_indexers": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": "epwalsh/bert-xsmall-dummy", | ||
"max_length": 512 | ||
} | ||
} | ||
}, | ||
"train_data_path": "test_fixtures/pair_classification/bias_mitigation/snli_train.jsonl", | ||
"validation_data_path": "test_fixtures/pair_classification/bias_mitigation/snli_dev.jsonl", | ||
"test_data_path": "test_fixtures/pair_classification/bias_mitigation/snli_test.jsonl", | ||
"model": { | ||
"type": "adversarial_bias_mitigator", | ||
"predictor": { | ||
"type": "basic_classifier", | ||
"text_field_embedder": { | ||
"token_embedders": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": "epwalsh/bert-xsmall-dummy", | ||
"max_length": 512 | ||
} | ||
} | ||
}, | ||
"seq2vec_encoder": { | ||
"type": "cls_pooler", | ||
"embedding_dim": 20 | ||
}, | ||
"feedforward": { | ||
"input_dim": 20, | ||
"num_layers": 1, | ||
"hidden_dims": 20, | ||
"activations": "tanh" | ||
}, | ||
"dropout": 0.1, | ||
"namespace": "tags" | ||
}, | ||
"adversary": { | ||
"type": "feedforward_regression_adversary", | ||
"feedforward": { | ||
"input_dim": 3, | ||
"num_layers": 1, | ||
"hidden_dims": 1, | ||
"activations": "linear" | ||
} | ||
}, | ||
"bias_direction": { | ||
"type": "two_means", | ||
"seed_word_pairs_file": "https://raw.githubusercontent.com/tolga-b/debiaswe/4c3fa843ffff45115c43fe112d4283c91d225c09/data/definitional_pairs.json", | ||
"tokenizer": { | ||
"type": "pretrained_transformer", | ||
"model_name": "epwalsh/bert-xsmall-dummy", | ||
"max_length": 512 | ||
} | ||
}, | ||
"predictor_output_key": "probs" | ||
}, | ||
"data_loader": { | ||
"batch_sampler": { | ||
"type": "bucket", | ||
"sorting_keys": [ | ||
"tokens" | ||
], | ||
"padding_noise": 0.0, | ||
"batch_size": 80 | ||
} | ||
}, | ||
"trainer": { | ||
"num_epochs": 5, | ||
"grad_norm": 1.0, | ||
"patience": 500, | ||
"cuda_device": -1, | ||
"callbacks": [ | ||
"adversarial_bias_mitigator_backward" | ||
], | ||
"optimizer": { | ||
"type": "multi", | ||
"optimizers": { | ||
"predictor": { | ||
"type": "adam", | ||
"lr": 1e-5 | ||
}, | ||
"adversary": { | ||
"type": "adam", | ||
"lr": 1e-5 | ||
}, | ||
"default": { | ||
"type": "adam", | ||
"lr": 1e-5 | ||
} | ||
}, | ||
"parameter_groups": [ | ||
[ | ||
[ | ||
"^predictor" | ||
], | ||
{ | ||
"optimizer_name": "predictor" | ||
} | ||
], | ||
[ | ||
[ | ||
"^adversary" | ||
], | ||
{ | ||
"optimizer_name": "adversary" | ||
} | ||
] | ||
] | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
tests/pair_classification/models/adversarial_bias_mitigator_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from flaky import flaky | ||
import numpy | ||
|
||
from tests import FIXTURES_ROOT | ||
from allennlp.common.testing import ModelTestCase | ||
from allennlp_models.pair_classification import SnliReader | ||
from allennlp.fairness import ( | ||
AdversarialBiasMitigator, | ||
FeedForwardRegressionAdversary, | ||
AdversarialBiasMitigatorBackwardCallback, | ||
) | ||
|
||
|
||
class AdversarialBiasMitigatorTest(ModelTestCase): | ||
def setup_method(self): | ||
super().setup_method() | ||
self.set_up_model( | ||
FIXTURES_ROOT | ||
/ "pair_classification" | ||
/ "bias_mitigation" | ||
/ "adversarial_experiment.json", | ||
FIXTURES_ROOT / "pair_classification" / "bias_mitigation" / "snli_train.jsonl", | ||
) | ||
|
||
def test_adversarial_bias_mitigator_can_train_save_and_load(self): | ||
# BertModel pooler output is discarded so grads not computed | ||
self.ensure_model_can_train_save_and_load( | ||
self.param_file, | ||
gradients_to_ignore=set( | ||
[ | ||
"predictor._text_field_embedder.token_embedder_tokens.transformer_model.pooler.dense.weight", | ||
"predictor._text_field_embedder.token_embedder_tokens.transformer_model.pooler.dense.bias", | ||
] | ||
), | ||
which_loss="adversary_loss", | ||
) | ||
|
||
@flaky | ||
def test_batch_predictions_are_consistent(self): | ||
self.ensure_batch_predictions_are_consistent() | ||
|
||
def test_forward_pass_runs_correctly(self): | ||
training_tensors = self.dataset.as_tensor_dict() | ||
output_dict = self.model(**training_tensors) | ||
output_dict = self.model.make_output_human_readable(output_dict) | ||
assert "label" in output_dict.keys() | ||
probs = output_dict["probs"][0].data.numpy() | ||
numpy.testing.assert_almost_equal(numpy.sum(probs, -1), numpy.array([1])) | ||
|
||
def test_forward_on_instances_ignores_loss_key_when_batched(self): | ||
batch_outputs = self.model.forward_on_instances(self.dataset.instances) | ||
for output in batch_outputs: | ||
assert "loss" not in output.keys() | ||
|
||
# It should be in the single batch case, because we special case it. | ||
single_output = self.model.forward_on_instance(self.dataset.instances[0]) | ||
assert "loss" in single_output.keys() |
Oops, something went wrong.