This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Adversarial bias mitigation #5269

Merged
merged 12 commits on Jun 17, 2021
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -10,11 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `on_backward` training callback which allows for control over backpropagation and gradient manipulation.
- Added `AdversarialBiasMitigator`, a Model wrapper to adversarially mitigate biases in predictions produced by a pretrained model for a downstream task.
- Added `which_loss` parameter to `ensure_model_can_train_save_and_load` in `ModelTestCase` to specify which loss to test.

### Fixed

- Fixed broken link in `allennlp.fairness.fairness_metrics.Separation` docs.
- Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`.
- Fixed `IndexOutOfBoundsException` in `MultiOptimizer` when checking if optimizer received any parameters.

### Changed

- Changed behavior of `MultiOptimizer` so that while a default optimizer is still required, an error is not thrown if the default optimizer receives no parameters.


## [v2.5.0](https://github.com/allenai/allennlp/releases/tag/v2.5.0) - 2021-06-03
11 changes: 8 additions & 3 deletions allennlp/common/testing/model_test_case.py
@@ -79,6 +79,7 @@ def ensure_model_can_train_save_and_load(
metric_terminal_value: float = None,
metric_tolerance: float = 1e-4,
disable_dropout: bool = True,
which_loss: str = "loss",
seed: int = None,
):
"""
@@ -114,6 +115,9 @@ def ensure_model_can_train_save_and_load(
disable_dropout : `bool`, optional (default = `True`)
If True we will set all dropout to 0 before checking gradients. (Otherwise, with small
datasets, you may get zero gradients because of unlucky dropout.)
which_loss : `str`, optional (default = `"loss"`)
Specifies which loss to test. For example, `which_loss` may be `"adversary_loss"` for
`adversarial_bias_mitigator`.
"""
if seed is not None:
random.seed(seed)
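For reference, a hypothetical test sketch showing how the new `which_loss` parameter might be exercised from a `ModelTestCase` subclass; the class name and fixture paths below are made up for illustration and are not part of this PR:

```python
from allennlp.common.testing import ModelTestCase


class TestAdversarialBiasMitigator(ModelTestCase):
    def setup_method(self):
        super().setup_method()
        # hypothetical fixture paths
        self.set_up_model(
            "test_fixtures/fairness/adversarial_bias_mitigator/experiment.json",
            "test_fixtures/data/text_classification.jsonl",
        )

    def test_can_train_save_and_load(self):
        # test the adversary's loss instead of the default "loss" key
        self.ensure_model_can_train_save_and_load(
            self.param_file, which_loss="adversary_loss"
        )
```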
@@ -175,7 +179,7 @@ def ensure_model_can_train_save_and_load(
# Check gradients are None for non-trainable parameters and check that
# trainable parameters receive some gradient if they are trainable.
self.check_model_computes_gradients_correctly(
model, model_batch, gradients_to_ignore, disable_dropout
model, model_batch, gradients_to_ignore, disable_dropout, which_loss
)

# The datasets themselves should be identical.
@@ -206,7 +210,7 @@ def ensure_model_can_train_save_and_load(
# Check loaded model's loss exists and we can compute gradients, for continuing training.
loaded_model.train()
loaded_model_predictions = loaded_model(**loaded_batch)
loaded_model_loss = loaded_model_predictions["loss"]
loaded_model_loss = loaded_model_predictions[which_loss]
assert loaded_model_loss is not None
loaded_model_loss.backward()

@@ -306,6 +310,7 @@ def check_model_computes_gradients_correctly(
model_batch: Dict[str, Union[Any, Dict[str, Any]]],
params_to_ignore: Set[str] = None,
disable_dropout: bool = True,
which_loss: str = "loss",
):
print("Checking gradients")
for p in model.parameters():
@@ -322,7 +327,7 @@
setattr(module, "p", 0)

result = model(**model_batch)
result["loss"].backward()
result[which_loss].backward()
has_zero_or_none_grads = {}
for name, parameter in model.named_parameters():
zeros = torch.zeros(parameter.size())
5 changes: 5 additions & 0 deletions allennlp/fairness/__init__.py
@@ -40,3 +40,8 @@
TwoMeansBiasDirectionWrapper,
ClassificationNormalBiasDirectionWrapper,
)
from allennlp.fairness.adversarial_bias_mitigator import (
AdversarialBiasMitigator,
FeedForwardRegressionAdversary,
AdversarialBiasMitigatorBackwardCallback,
)
300 changes: 300 additions & 0 deletions allennlp/fairness/adversarial_bias_mitigator.py
@@ -0,0 +1,300 @@
"""
A Model wrapper to adversarially mitigate biases in
predictions produced by a pretrained model for a downstream task.

The documentation and explanations are heavily based on:
Zhang, B.H., Lemoine, B., & Mitchell, M. (2018).
[Mitigating Unwanted Biases with Adversarial Learning]
(https://api.semanticscholar.org/CorpusID:9424845).
Proceedings of the 2018 AAAI/ACM Conference on AI, Ethics, and Society.
and [Mitigating Unwanted Biases in Word Embeddings
with Adversarial Learning](https://colab.research.google.com/notebooks/
ml_fairness/adversarial_debiasing.ipynb) colab notebook.

Adversarial networks mitigate some biases based on the idea that
predicting an outcome Y given an input X should ideally be independent
of some protected variable Z. Informally, "knowing Y would not help
you predict Z any better than chance" (Zaldivar et al., 2018). This
Comment on lines +16 to +17

Contributor: Should this be the other way round? "Knowing Z would not help you predict Y"? Or is it stating that knowing the outcome shouldn't give you information about the protected variable?

Contributor Author: The latter: it's stating that knowing the outcome shouldn't give you information about the protected variable.

can be achieved using two networks in a series, where the first attempts to predict
Y using X as input, and the second attempts to use the predicted value of Y to recover Z.
Please refer to Figure 1 of [Mitigating Unwanted Biases with Adversarial Learning]
(https://api.semanticscholar.org/CorpusID:9424845). Ideally, we would
like the first network to predict Y without permitting the second network to predict
Z any better than chance.

For common NLP tasks, it's usually clear what X and Y are,
but Z is not always available. We can construct our own Z by:
1) computing a bias direction (e.g. for binary gender)
2) computing the inner product of static sentence embeddings and the bias direction

Training adversarial networks is extremely difficult. It is important to:
1) lower the step size of both the predictor and adversary to train both
models slowly to avoid parameters diverging,
2) initialize the parameters of the adversary to be small to avoid the predictor
overfitting against a sub-optimal adversary,
3) increase the adversary’s learning rate to prevent divergence if the
predictor is too good at hiding the protected variable from the adversary.
"""

from overrides import overrides
from typing import Dict, Optional
import torch

from allennlp.data import Vocabulary
from allennlp.fairness.bias_direction_wrappers import BiasDirectionWrapper
from allennlp.modules.feedforward import FeedForward
from allennlp.models import Model
from allennlp.nn import InitializerApplicator
from allennlp.nn.util import find_embedding_layer
from allennlp.training.callbacks.callback import TrainerCallback
from allennlp.training.callbacks.backward import OnBackwardException
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


class _AdversaryLabelHook:
def __init__(self, predetermined_bias_direction):
self.predetermined_bias_direction = predetermined_bias_direction

def __call__(self, module, module_in, module_out):
"""
Called as forward hook.
"""
with torch.no_grad():
# mean pooling over static word embeddings to get sentence embedding
module_out = module_out.mean(dim=1)
self.adversary_label = torch.matmul(
module_out, self.predetermined_bias_direction.to(module_out.device)
).unsqueeze(-1)
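For intuition, a standalone sketch of what the hook above computes: mean-pool static token embeddings into sentence embeddings, then project them onto a precomputed bias direction to obtain the adversary's label Z. The shapes and random tensors are illustrative:

```python
import torch

# toy batch: 4 sentences, 10 tokens each, 50-dim static embeddings
token_embeddings = torch.randn(4, 10, 50)
bias_direction = torch.randn(50)
bias_direction = bias_direction / torch.linalg.norm(bias_direction)

sentence_embeddings = token_embeddings.mean(dim=1)                      # (4, 50)
adversary_label = (sentence_embeddings @ bias_direction).unsqueeze(-1)  # (4, 1)
```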


@Model.register("adversarial_bias_mitigator")
class AdversarialBiasMitigator(Model):
"""
Wrapper class to adversarially mitigate biases in any pretrained Model.

# Parameters

vocab : `Vocabulary`
Vocabulary of predictor.
predictor : `Model`
Contributor: This is not strictly an issue. We use the term "predictor" differently elsewhere in the library; should we change the name here? If this is adhering to the paper's terminology, it's probably okay to keep it as is.

Contributor Author: Yes, I'm adhering to the paper's terminology.
Model for which to mitigate biases.
adversary : `Model`
Model that attempts to recover protected variable values from predictor's predictions.
bias_direction : `BiasDirectionWrapper`
Bias direction used by adversarial bias mitigator.
predictor_output_key : `str`
Key corresponding to output in `output_dict` of predictor that should be passed as input
to adversary.

!!! Note
The adversary must use the same vocab as the predictor, if it requires a vocab.
"""

def __init__(
self,
vocab: Vocabulary,
predictor: Model,
adversary: Model,
bias_direction: BiasDirectionWrapper,
predictor_output_key: str,
**kwargs,
):
super().__init__(vocab, **kwargs)

self.predictor = predictor
self.adversary = adversary

# want to keep adversary label hook during evaluation
embedding_layer = find_embedding_layer(self.predictor)
self.bias_direction = bias_direction
self.predetermined_bias_direction = self.bias_direction(embedding_layer)
self._adversary_label_hook = _AdversaryLabelHook(self.predetermined_bias_direction)
embedding_layer.register_forward_hook(self._adversary_label_hook)

self.vocab = self.predictor.vocab
self._regularizer = self.predictor._regularizer

self.predictor_output_key = predictor_output_key

@overrides
def train(self, mode: bool = True):
super().train(mode)
self.predictor.train(mode)
self.adversary.train(mode)
# appropriately change requires_grad
# in bias direction when train() and
# eval() are called
self.bias_direction.train(mode)

@overrides
def forward(self, *args, **kwargs):
predictor_output_dict = self.predictor.forward(*args, **kwargs)
adversary_output_dict = self.adversary.forward(
predictor_output_dict[self.predictor_output_key],
self._adversary_label_hook.adversary_label,
)
# prepend "adversary_" to every key in adversary_output_dict
# to distinguish from predictor_output_dict keys
adversary_output_dict = {("adversary_" + k): v for k, v in adversary_output_dict.items()}
output_dict = {**predictor_output_dict, **adversary_output_dict}
return output_dict

# Delegate Model function calls to the predictor.
# Currently doing this manually because it is difficult to
# dynamically forward __getattribute__ calls due to the
# behind-the-scenes usage of dunder attributes by torch.nn.Module
# and the predictor inheriting from Model.
# Assumes the Model API is relatively stable.
@overrides
def forward_on_instance(self, *args, **kwargs):
return self.predictor.forward_on_instance(*args, **kwargs)

@overrides
def forward_on_instances(self, *args, **kwargs):
return self.predictor.forward_on_instances(*args, **kwargs)

@overrides
def get_regularization_penalty(self, *args, **kwargs):
return self.predictor.get_regularization_penalty(*args, **kwargs)

@overrides
def get_parameters_for_histogram_logging(self, *args, **kwargs):
return self.predictor.get_parameters_for_histogram_logging(*args, **kwargs)

@overrides
def get_parameters_for_histogram_tensorboard_logging(self, *args, **kwargs):
return self.predictor.get_parameters_for_histogram_tensorboard_logging(*args, **kwargs)

@overrides
def make_output_human_readable(self, *args, **kwargs):
return self.predictor.make_output_human_readable(*args, **kwargs)

@overrides
def get_metrics(self, *args, **kwargs):
return self.predictor.get_metrics(*args, **kwargs)

@overrides
def _get_prediction_device(self, *args, **kwargs):
return self.predictor._get_prediction_device(*args, **kwargs)

@overrides
def _maybe_warn_for_unseparable_batches(self, *args, **kwargs):
return self.predictor._maybe_warn_for_unseparable_batches(*args, **kwargs)

@overrides
def extend_embedder_vocab(self, *args, **kwargs):
return self.predictor.extend_embedder_vocab(*args, **kwargs)


@Model.register("feedforward_regression_adversary")
class FeedForwardRegressionAdversary(Model):
"""
This `Model` implements a simple feedforward regression adversary.

Registered as a `Model` with name "feedforward_regression_adversary".

# Parameters

vocab : `Vocabulary`
feedforward : `FeedForward`
A feedforward layer.
initializer : `Optional[InitializerApplicator]`, optional (default=`InitializerApplicator()`)
If provided, will be used to initialize the model parameters.
"""

def __init__(
self,
vocab: Vocabulary,
feedforward: FeedForward,
initializer: Optional[InitializerApplicator] = InitializerApplicator(),
**kwargs,
) -> None:
super().__init__(vocab, **kwargs)

self._feedforward = feedforward
self._loss = torch.nn.MSELoss()
initializer(self) # type: ignore

def forward( # type: ignore
self, input: torch.FloatTensor, label: torch.FloatTensor
) -> Dict[str, torch.Tensor]:
"""
# Parameters

input : `torch.FloatTensor`
A tensor of size (batch_size, ...).
label : `torch.FloatTensor`
A tensor of the same size as input.

# Returns

An output dictionary consisting of:
- `loss` : `torch.FloatTensor`
A scalar loss to be optimised.
"""

pred = self._feedforward(input)
return {"loss": self._loss(pred, label)}


@TrainerCallback.register("adversarial_bias_mitigator_backward")
class AdversarialBiasMitigatorBackwardCallback(TrainerCallback):
"""
Performs backpropagation for adversarial bias mitigation.
While the adversary's gradients are computed normally,
the predictor's gradients are computed such that updates to the
predictor's parameters will not aid the adversary and will
make it more difficult for the adversary to recover protected variables.

!!! Note:
Intended to be used with `AdversarialBiasMitigator`.
trainer.model is expected to have `predictor` and `adversary` data members.
Member: You could, if you wanted to, put in a check for this condition and throw an exception if it isn't met.

# Parameters

adversary_loss_weight : `float`, optional (default = `1.0`)
Quantifies how difficult the predictor makes it for the adversary to recover protected variables.
"""

def __init__(self, serialization_dir: str, adversary_loss_weight: float = 1.0) -> None:
super().__init__(serialization_dir)
self.adversary_loss_weight = adversary_loss_weight

def on_backward(
self,
trainer: GradientDescentTrainer,
batch_outputs: Dict[str, torch.Tensor],
backward_called: bool,
**kwargs,
) -> bool:
if backward_called:
raise OnBackwardException()

trainer.optimizer.zero_grad()
# `retain_graph=True` prevents computation graph from being erased
batch_outputs["adversary_loss"].backward(retain_graph=True)
# trainer.model is expected to have `predictor` and `adversary` data members
adversary_loss_grad = {
name: param.grad.clone()
for name, param in trainer.model.predictor.named_parameters()
if param.grad is not None
}

trainer.model.predictor.zero_grad()
batch_outputs["loss"].backward()
Member: If "loss" and "adversary_loss" don't use exactly the same computation graph, does that mean that parts of the computation graph of "adversary_loss" could stick around when we don't want them to?

Contributor Author: That's a really good point about part of the computation graph not getting erased! Upon further reading, it looks like the computation graph will stay around until adversary_loss goes out of scope, so I added this in the callback to remove all references to adversary_loss in the graph and instead keep a view of the loss that is detached from the graph:

    # remove adversary_loss from computation graph
    batch_outputs["adversary_loss"] = batch_outputs["adversary_loss"].detach()

with torch.no_grad():
for name, param in trainer.model.predictor.named_parameters():
if param.grad is not None:
unit_adversary_loss_grad = adversary_loss_grad[name] / torch.linalg.norm(
adversary_loss_grad[name]
)
# prevent predictor from accidentally aiding adversary
# by removing projection of predictor loss grad onto adversary loss grad
param.grad -= (
(param.grad * unit_adversary_loss_grad) * unit_adversary_loss_grad
).sum()
# make it difficult for adversary to recover protected variables
param.grad -= self.adversary_loss_weight * adversary_loss_grad[name]

return True
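To spell out the gradient surgery performed above: the predictor's gradient has its projection onto the adversary's gradient removed, and then a scaled copy of the adversary's gradient subtracted. Below is a self-contained sketch of that update, following Zhang et al. (2018), with illustrative tensors; it is a reference formula, not the PR's exact code:

```python
import torch


def debias_gradient(
    pred_grad: torch.Tensor, adv_grad: torch.Tensor, adversary_loss_weight: float = 1.0
) -> torch.Tensor:
    """Return pred_grad minus its projection onto adv_grad, minus a scaled adv_grad."""
    unit_adv = adv_grad / torch.linalg.norm(adv_grad)
    projection = (pred_grad * unit_adv).sum() * unit_adv
    return pred_grad - projection - adversary_loss_weight * adv_grad


g_pred = torch.tensor([1.0, 2.0, 3.0])
g_adv = torch.tensor([0.0, 1.0, 0.0])
print(debias_gradient(g_pred, g_adv))  # tensor([ 1., -1.,  3.])
```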