Remove float32 dtype assumption
Correct various type annotations.
NeilGirdhar committed Jan 28, 2022
1 parent fda878f commit ac1d33a
Showing 13 changed files with 405 additions and 266 deletions.
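
The thrust of the change is that layers stop assuming `float32` and instead derive their computation and parameter dtypes from the inputs; the attention diff below imports a `_canonicalize_dtypes` helper for this. The following is a rough sketch of that idea only: the real helper's signature and behavior are assumed from the call site in the diff, not taken from the commit.

```python
# Illustrative sketch only: how a dtype-resolution helper could replace a
# hard-coded float32 default. The real flax.linen.linear._canonicalize_dtypes
# may differ; its signature here is assumed from the call site in the diff.
from typing import Optional, Tuple

import jax.numpy as jnp
import numpy as np


def canonicalize_dtypes_sketch(
    input_dtype: np.dtype,
    param_dtype: Optional[np.dtype],
    dtype: Optional[np.dtype]) -> Tuple[np.dtype, np.dtype]:
  """Returns (param_dtype, dtype), falling back to the inputs' dtype."""
  resolved_param_dtype = input_dtype if param_dtype is None else param_dtype
  resolved_dtype = resolved_param_dtype if dtype is None else dtype
  return resolved_param_dtype, resolved_dtype


# bfloat16 inputs keep their dtype instead of being promoted to float32.
x = jnp.ones((4, 3), jnp.bfloat16)
print(canonicalize_dtypes_sketch(jnp.result_type(x), None, None))
```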
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -3,7 +3,6 @@ Changelog

vNext
------
(Add your change to a random empty line to avoid merge conflicts)
-
-
-
@@ -27,6 +26,7 @@ vNext
-
-
-
-

0.4.0
------
2 changes: 1 addition & 1 deletion examples/seq2seq/train.py
@@ -187,7 +187,7 @@ def select_carried_state(new_state, old_state):
@staticmethod
def initialize_carry(batch_size, hidden_size):
# use dummy key since default state init fn is just zeros.
return nn.LSTMCell.initialize_carry(
return nn.LSTMCell().initialize_carry(
jax.random.PRNGKey(0), (batch_size,), hidden_size)
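
Since `initialize_carry` is now called on an `LSTMCell` instance rather than on the class, the carry can be built as in the minimal sketch below, which mirrors the call in this hunk; constructor arguments may vary across Flax versions.

```python
# Minimal sketch of the instance-based carry initialization shown above.
# Assumes the Flax API at the time of this commit; newer releases may differ.
import jax
import flax.linen as nn

batch_size, hidden_size = 8, 32
carry = nn.LSTMCell().initialize_carry(
    jax.random.PRNGKey(0), (batch_size,), hidden_size)
# With the default zeros initializer, `carry` is a (c, h) pair of arrays of
# shape (batch_size, hidden_size).
```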


17 changes: 11 additions & 6 deletions examples/sst2/models.py
@@ -162,19 +162,22 @@ def __call__(self, inputs: Array,
class SimpleLSTM(nn.Module):
"""A simple unidirectional LSTM."""

def setup(self):
self.lstm = nn.OptimizedLSTMCell()

@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
return nn.OptimizedLSTMCell()(carry, x)
return self.lstm(carry, x)

@staticmethod
def initialize_carry(batch_dims, hidden_size):
@nn.module.wrap_method_once
def initialize_carry(self, batch_dims, hidden_size):
# Use fixed random key since default state init fn is just zeros.
return nn.OptimizedLSTMCell.initialize_carry(
return self.lstm.initialize_carry(
jax.random.PRNGKey(0), batch_dims, hidden_size)


@@ -190,12 +193,14 @@ def __call__(self, embedded_inputs, lengths):
batch_size = embedded_inputs.shape[0]

# Forward LSTM.
initial_state = SimpleLSTM.initialize_carry((batch_size,), self.hidden_size)
initial_state = self.forward_lstm.initialize_carry((batch_size,),
self.hidden_size)
_, forward_outputs = self.forward_lstm(initial_state, embedded_inputs)

# Backward LSTM.
reversed_inputs = flip_sequences(embedded_inputs, lengths)
initial_state = SimpleLSTM.initialize_carry((batch_size,), self.hidden_size)
initial_state = self.backward_lstm.initialize_carry((batch_size,),
self.hidden_size)
_, backward_outputs = self.backward_lstm(initial_state, reversed_inputs)
backward_outputs = flip_sequences(backward_outputs, lengths)

2 changes: 1 addition & 1 deletion examples/sst2/models_test.py
@@ -51,7 +51,7 @@ def test_lstm_returns_correct_output_shape(self):
rng = jax.random.PRNGKey(0)
inputs = np.random.RandomState(0).normal(
size=[batch_size, seq_len, embedding_size])
initial_state = models.SimpleLSTM.initialize_carry((batch_size,), hidden_size)
initial_state = model.initialize_carry((batch_size,), hidden_size)
(_, output), _ = model.init_with_output(rng, initial_state, inputs)
self.assertEqual((batch_size, seq_len, hidden_size), output.shape)

7 changes: 5 additions & 2 deletions flax/linen/activation.py
@@ -67,8 +67,11 @@ def __call__(self, inputs: Array) -> Array:
Returns:
The transformed input.
"""
dtype = inputs.dtype
assert jnp.issubdtype(dtype, jnp.floating)
negative_slope = self.param(
'negative_slope',
lambda k: jnp.asarray(self.negative_slope_init, jnp.float32)
lambda k: jnp.asarray(self.negative_slope_init, dtype)
)
return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs)
assert negative_slope.shape == ()
return jnp.where(inputs >= 0, inputs, negative_slope * inputs)
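
The same pattern in isolation: a learned parameter created with the dtype of the incoming activations rather than a hard-coded `jnp.float32`. This is a hedged sketch with made-up names (`ScaledIdentity`, `scale_init`), not code from the commit.

```python
# Illustrative module (not part of the commit): the parameter's dtype follows
# the input's dtype, so bfloat16 inputs yield a bfloat16 parameter.
import flax.linen as nn
import jax
import jax.numpy as jnp


class ScaledIdentity(nn.Module):
  scale_init: float = 1.0

  @nn.compact
  def __call__(self, x):
    scale = self.param('scale', lambda key: jnp.asarray(self.scale_init, x.dtype))
    return scale * x


variables = ScaledIdentity().init(jax.random.PRNGKey(0), jnp.ones((2,), jnp.bfloat16))
print(jax.tree_util.tree_map(lambda p: p.dtype, variables))
# {'params': {'scale': dtype(bfloat16)}}
```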
148 changes: 90 additions & 58 deletions flax/linen/attention.py
@@ -15,23 +15,42 @@
"""Attention core modules for Flax."""

from functools import partial
from typing import (Any, Callable, Tuple, Optional)
from typing import Any, Callable, Tuple, Type, Optional

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np
from typing_extensions import Protocol

from flax.linen.linear import default_kernel_init
from flax.linen.initializers import zeros
from flax.linen.linear import DenseGeneral
from flax.linen.linear import _canonicalize_dtypes
from flax.linen.linear import default_kernel_init
from flax.linen.module import Module, compact, merge_param
from flax.linen.initializers import zeros

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
Shape = Tuple[int, ...]
InexactDType = Type[np.inexact]
Array = Any
Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]


class AttentionFunction(Protocol):
@staticmethod
def __call__(query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None) -> Array:
...
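
The new `AttentionFunction` protocol documents the interface expected of `attention_fn`. A toy conforming function (purely illustrative, not from the commit) could look like the sketch below.

```python
# Purely illustrative: a callable with the AttentionFunction signature that
# averages the values over the key/value length instead of computing
# query-key similarities. Inputs follow Flax's attention convention of
# [batch..., length, num_heads, head_dim].
import jax.numpy as jnp


def mean_pool_attention(query, key, value, bias=None, mask=None,
                        broadcast_dropout=True, dropout_rng=None,
                        dropout_rate=0., deterministic=False,
                        dtype=jnp.float32, precision=None):
  pooled = jnp.mean(value.astype(dtype), axis=-3, keepdims=True)
  q_length = query.shape[-3]
  # Broadcast the pooled values back to the query length.
  return jnp.broadcast_to(
      pooled, value.shape[:-3] + (q_length,) + value.shape[-2:])
```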


def dot_product_attention_weights(query: Array,
@@ -42,7 +61,7 @@ def dot_product_attention_weights(query: Array,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: Dtype = jnp.float32,
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None):
"""Computes dot-product attention weights given query and key.
@@ -109,7 +128,7 @@ def dot_product_attention_weights(query: Array,
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
else:
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) /
multiplier = (keep.astype(dtype) /
jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier

@@ -125,8 +144,8 @@ def dot_product_attention(query: Array,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: Dtype = jnp.float32,
precision: Optional[lax.Precision] = None):
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None) -> Array:
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
@@ -179,52 +198,27 @@ def dot_product_attention(query: Array,
precision=precision)


class MultiHeadDotProductAttention(Module):
"""Multi-head dot-product attention.
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: float32)
param_dtype: the dtype passed to parameter initializers (default: float32).
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weight is masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
bias_init: initializer for the bias of the Dense layers.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts
query, key, value, and returns output of shape
`[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]``
decode: whether to prepare and use an autoregressive cache.
"""
class _BaseMultiHeadDotProductAttention(Module):
num_heads: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
dtype: Optional[InexactDType] = None
param_dtype: Optional[InexactDType] = None
qkv_features: Optional[int] = None
out_features: Optional[int] = None
broadcast_dropout: bool = True
dropout_rate: float = 0.
deterministic: Optional[bool] = None
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = zeros
use_bias: bool = True
attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention
attention_fn: AttentionFunction = dot_product_attention
decode: bool = False

@compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
def _apply(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
@@ -246,6 +240,10 @@ def __call__(self,
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
param_dtype, dtype = _canonicalize_dtypes(jnp.result_type(inputs_q,
inputs_kv),
self.param_dtype,
self.dtype)
if self.dropout_rate > 0.: # Require `deterministic` only if using dropout.
deterministic = merge_param('deterministic', self.deterministic, deterministic)
features = self.out_features or inputs_q.shape[-1]
@@ -256,8 +254,8 @@ def __call__(self,

dense = partial(DenseGeneral,
axis=-1,
dtype=self.dtype,
param_dtype=self.param_dtype,
dtype=dtype,
param_dtype=param_dtype,
features=(self.num_heads, head_dim),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
@@ -320,38 +318,72 @@ def __call__(self,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=deterministic,
dtype=self.dtype,
dtype=dtype,
precision=self.precision) # pytype: disable=wrong-keyword-args
# back to the original inputs dimensions
out = DenseGeneral(features=features,
axis=(-2, -1),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
dtype=dtype,
param_dtype=param_dtype,
precision=self.precision,
name='out')(x)
return out


class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""
class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention):
"""Multi-head dot-product attention.
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: float32)
param_dtype: the dtype passed to parameter initializers (default: float32).
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weight is masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
bias_init: initializer for the bias of the Dense layers.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts
query, key, value, and returns output of shape
`[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]``
decode: whether to prepare and use an autoregressive cache.
"""
@compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
return self._apply(inputs_q, inputs_kv, mask, deterministic=deterministic)


class SelfAttention(_BaseMultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""
@compact
def __call__(self, inputs_q: Array, mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic)
return self._apply(inputs_q, inputs_q, mask, deterministic=deterministic)
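
A usage sketch, assuming the inferred-dtype behavior this commit introduces: with `dtype` and `param_dtype` left unset, the module derives both from the inputs.

```python
# Hedged usage sketch; assumes the dtype-inference behavior added here.
import flax.linen as nn
import jax
import jax.numpy as jnp

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=32)
x = jnp.ones((2, 16, 32), jnp.bfloat16)
variables = attn.init(jax.random.PRNGKey(0), x, x)
y = attn.apply(variables, x, x)
# With dtype/param_dtype defaulting to None, both should follow the bfloat16
# inputs instead of being promoted to float32.
```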


# mask-making utility functions


def make_attention_mask(query_input: Array,
key_input: Array,
pairwise_fn: Callable[..., Any] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: Dtype = jnp.float32):
def make_attention_mask(
query_input: Array,
key_input: Array,
pairwise_fn: Callable[[Array, Array], Array] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: InexactDType = jnp.float32):
"""Mask-making helper for attention weights.
In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the