TF: XLA stable softmax (#16892)
Co-authored-by: Sylvain Gugger <[email protected]>
gante and sgugger authored Apr 25, 2022
1 parent 8246caf commit e03966e
Showing 49 changed files with 210 additions and 142 deletions.
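At the core of the diff is a new helper in src/transformers/tf_utils.py that every file below switches to. A minimal sketch of that helper, assuming it is just a thin wrapper around tf.nn.softmax that adds a tiny offset to the logits so the op compiles reliably with XLA on CPU (the exact constant and signature here are assumptions):

from typing import Optional

import tensorflow as tf


def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
    # Numerically this matches tf.nn.softmax for all practical purposes; the tiny
    # additive constant only works around an XLA-on-CPU issue with the raw op.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)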
3 changes: 2 additions & 1 deletion src/transformers/generation_tf_logits_process.py
@@ -19,6 +19,7 @@
import numpy as np
import tensorflow as tf

+ from .tf_utils import stable_softmax
from .utils import add_start_docstrings
from .utils.logging import get_logger

@@ -166,7 +167,7 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

mask_scores = tf.fill(scores.shape, self.filter_value)
- cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
+ cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
score_mask = cumulative_probs < self.top_p

# Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
4 changes: 2 additions & 2 deletions src/transformers/generation_tf_utils.py
@@ -34,7 +34,7 @@
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
- from .tf_utils import shape_list
+ from .tf_utils import shape_list, stable_softmax
from .utils import ModelOutput, logging


@@ -3060,7 +3060,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
logits, sorted_indices, axis=-1, batch_dims=1
) # expects logits to be of dim (batch_size, vocab_size)

- cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
+ cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)

# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
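As a quick illustration of the filtering step that both generation files above touch, here is a hypothetical standalone run (invented logits, stable_softmax assumed to behave like tf.nn.softmax):

import tensorflow as tf

top_p = 0.9
sorted_logits = tf.constant([[3.0, 2.0, 1.0, 0.0]])  # already sorted, most likely token first

cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
# -> roughly [[0.644, 0.881, 0.968, 1.0]]

sorted_indices_to_remove = cumulative_probs > top_p
# -> [[False, False, True, True]]: only the two most likely tokens survive top-p filtering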
4 changes: 2 additions & 2 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -44,7 +44,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
@@ -259,7 +259,7 @@ def call(
attention_scores = tf.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
- attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
4 changes: 2 additions & 2 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -40,7 +40,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -244,7 +244,7 @@ def call(
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

- attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+ attn_weights = stable_softmax(attn_weights, axis=-1)

if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
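The Bart-style attention above (repeated almost verbatim in Blenderbot, BlenderbotSmall, Hubert, and others below) reshapes the flattened weights so the mask can be broadcast, then flattens them again before the softmax. A shape-only sketch with made-up dimensions:

import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = tf.random.normal((bsz * num_heads, tgt_len, src_len))
attention_mask = tf.zeros((bsz, 1, tgt_len, src_len))  # zeros keep everything; large negatives would mask

attn_weights = tf.reshape(attn_weights, (bsz, num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)  # stable_softmax after this commit
print(attn_weights.shape)  # (8, 5, 5)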
4 changes: 2 additions & 2 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -49,7 +49,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -322,7 +322,7 @@ def call(
attention_scores = tf.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
- attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
4 changes: 2 additions & 2 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -40,7 +40,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -245,7 +245,7 @@ def call(
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

- attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+ attn_weights = stable_softmax(attn_weights, axis=-1)

if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
4 changes: 2 additions & 2 deletions src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -39,7 +39,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -245,7 +245,7 @@ def call(
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

- attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+ attn_weights = stable_softmax(attn_weights, axis=-1)

if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
4 changes: 2 additions & 2 deletions src/transformers/models/clip/modeling_tf_clip.py
@@ -34,7 +34,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -333,7 +333,7 @@ def call(
attention_scores = tf.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
- _attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+ _attention_probs = stable_softmax(logits=attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
6 changes: 3 additions & 3 deletions src/transformers/models/convbert/modeling_tf_convbert.py
@@ -42,7 +42,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
@@ -228,7 +228,7 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai

conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
conv_kernel_layer = tf.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
- conv_kernel_layer = tf.nn.softmax(conv_kernel_layer, axis=1)
+ conv_kernel_layer = stable_softmax(conv_kernel_layer, axis=1)

paddings = tf.constant(
[
@@ -270,7 +270,7 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
- attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+ attention_probs = stable_softmax(attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
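ConvBERT is the one spot where the softmax runs along axis=1 rather than the last axis, normalizing each span-convolution kernel of length conv_kernel_size. A tiny hypothetical illustration of that normalization:

import tensorflow as tf

conv_kernel_size = 3
# One kernel per (token, head) pair, flattened into the leading dimension.
conv_kernel_layer = tf.random.normal((8, conv_kernel_size, 1))
conv_kernel_layer = tf.nn.softmax(conv_kernel_layer, axis=1)  # stable_softmax in the model

# Every length-3 kernel now sums to 1.
print(tf.reduce_sum(conv_kernel_layer, axis=1))  # all ones, shape (8, 1)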
4 changes: 2 additions & 2 deletions src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -31,7 +31,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_ctrl import CTRLConfig

@@ -79,7 +79,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
scaled_attention_logits = scaled_attention_logits + attention_mask

- attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
+ attention_weights = stable_softmax(scaled_attention_logits, axis=-1)

# Mask heads if we want to
if head_mask is not None:
4 changes: 2 additions & 2 deletions src/transformers/models/deberta/modeling_tf_deberta.py
@@ -39,7 +39,7 @@
get_initializer,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_deberta import DebertaConfig

@@ -96,7 +96,7 @@ def call(self, inputs: tf.Tensor, mask: tf.Tensor):

rmask = tf.logical_not(tf.cast(mask, tf.bool))
output = tf.where(rmask, float("-inf"), inputs)
- output = tf.nn.softmax(output, self.axis)
+ output = stable_softmax(output, self.axis)
output = tf.where(rmask, 0.0, output)
return output

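The DeBERTa XSoftmax layer above sends masked positions to -inf before the softmax and zeroes them afterwards (DeBERTa-v2 below does the same). A small hypothetical check of the pattern, with tf.nn.softmax standing in for stable_softmax and invented values:

import tensorflow as tf

inputs = tf.constant([[1.0, 2.0, 3.0, 4.0]])
mask = tf.constant([[1, 1, 0, 0]])  # 1 = attend, 0 = mask out

rmask = tf.logical_not(tf.cast(mask, tf.bool))
output = tf.where(rmask, float("-inf"), inputs)
output = tf.nn.softmax(output, axis=-1)
output = tf.where(rmask, 0.0, output)
print(output)  # ~[[0.269, 0.731, 0.0, 0.0]]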
4 changes: 2 additions & 2 deletions src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
@@ -38,7 +38,7 @@
get_initializer,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_deberta_v2 import DebertaV2Config

@@ -97,7 +97,7 @@ def call(self, inputs: tf.Tensor, mask: tf.Tensor):

rmask = tf.logical_not(tf.cast(mask, tf.bool))
output = tf.where(rmask, float("-inf"), inputs)
- output = tf.nn.softmax(output, self.axis)
+ output = stable_softmax(output, self.axis)
output = tf.where(rmask, 0.0, output)
return output

4 changes: 2 additions & 2 deletions src/transformers/models/distilbert/modeling_tf_distilbert.py
@@ -43,7 +43,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
@@ -194,7 +194,7 @@ def unshape(x):

mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
- weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
+ weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)

# Mask heads if we want to
4 changes: 2 additions & 2 deletions src/transformers/models/electra/modeling_tf_electra.py
@@ -44,7 +44,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -171,7 +171,7 @@ def call(
attention_scores = tf.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
- attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
4 changes: 2 additions & 2 deletions src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -34,7 +34,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -361,7 +361,7 @@ def unshape(x):
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
- weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
+ weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)

# Mask heads if we want to
4 changes: 2 additions & 2 deletions src/transformers/models/funnel/modeling_tf_funnel.py
@@ -42,7 +42,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
@@ -530,7 +530,7 @@ def call(self, query, key, value, attention_inputs, output_attentions=False, tra
attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))

# attention probability
- attn_prob = tf.nn.softmax(attn_score, axis=-1)
+ attn_prob = stable_softmax(attn_score, axis=-1)
attn_prob = self.attention_dropout(attn_prob, training=training)

# attention output, shape batch_size x seq_len x n_head x d_head
4 changes: 2 additions & 2 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -40,7 +40,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
ModelOutput,
@@ -129,7 +129,7 @@ def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
w = w + attention_mask

- w = tf.nn.softmax(w, axis=-1)
+ w = stable_softmax(w, axis=-1)
w = self.attn_dropout(w, training=training)

# Mask heads if we want to
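Since the whole point of the commit is XLA compatibility, here is a hedged sanity check one might run on the GPT-2-style attention path, assuming stable_softmax is importable from transformers.tf_utils once this commit is in:

import tensorflow as tf

from transformers.tf_utils import stable_softmax  # assumed import path after this commit


@tf.function(jit_compile=True)  # force XLA compilation
def masked_attention_probs(w, attention_mask):
    # Mirrors the _attn logic above: add the mask, then take the softmax over keys.
    return stable_softmax(w + attention_mask, axis=-1)


w = tf.random.normal((1, 4, 4))
attention_mask = tf.constant([[[0.0, 0.0, -1e9, -1e9]]])  # large negatives disable the last two positions
print(masked_attention_probs(w, attention_mask).shape)  # (1, 4, 4)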
4 changes: 2 additions & 2 deletions src/transformers/models/gptj/modeling_tf_gptj.py
@@ -43,7 +43,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_gptj import GPTJConfig

@@ -191,7 +191,7 @@ def _attn(
# Apply the attention mask
attn_weights = attn_weights + attention_mask

- attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+ attn_weights = stable_softmax(attn_weights, axis=-1)
attn_weights = tf.cast(attn_weights, value.dtype)
attn_weights = self.attn_dropout(attn_weights)

4 changes: 2 additions & 2 deletions src/transformers/models/hubert/modeling_tf_hubert.py
@@ -23,7 +23,7 @@
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...tokenization_utils_base import BatchEncoding
from ...utils import (
ModelOutput,
@@ -826,7 +826,7 @@ def call(
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

- attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+ attn_weights = stable_softmax(attn_weights, axis=-1)

if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
4 changes: 2 additions & 2 deletions src/transformers/models/layoutlm/modeling_tf_layoutlm.py
@@ -39,7 +39,7 @@
keras_serializable,
unpack_inputs,
)
- from ...tf_utils import shape_list
+ from ...tf_utils import shape_list, stable_softmax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlm import LayoutLMConfig

@@ -280,7 +280,7 @@ def call(
attention_scores = tf.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
- attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.