From ac1d33a4d6c8dfab222c7b5331e4c9c566da1534 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Mon, 24 Jan 2022 17:32:34 -0500 Subject: [PATCH] Remove dtype float32 dtype assumption Correct various type annotations. --- CHANGELOG.md | 2 +- examples/seq2seq/train.py | 2 +- examples/sst2/models.py | 17 ++- examples/sst2/models_test.py | 2 +- flax/linen/activation.py | 7 +- flax/linen/attention.py | 148 ++++++++++++++--------- flax/linen/linear.py | 163 ++++++++++++++++--------- flax/linen/normalization.py | 103 +++++++++------- flax/linen/recurrent.py | 170 ++++++++++++++++----------- setup.py | 1 + tests/linen/linen_module_test.py | 14 +-- tests/linen/linen_test.py | 30 ++--- tests/linen/linen_transforms_test.py | 12 +- 13 files changed, 405 insertions(+), 266 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78246840b0..2ae2be5321 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,6 @@ Changelog vNext ------ -(Add your change to a random empty line to avoid merge conflicts) - - - @@ -27,6 +26,7 @@ vNext - - - +- 0.4.0 ------ diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py index 8526c66649..5baef3a8c8 100644 --- a/examples/seq2seq/train.py +++ b/examples/seq2seq/train.py @@ -187,7 +187,7 @@ def select_carried_state(new_state, old_state): @staticmethod def initialize_carry(batch_size, hidden_size): # use dummy key since default state init fn is just zeros. - return nn.LSTMCell.initialize_carry( + return nn.LSTMCell().initialize_carry( jax.random.PRNGKey(0), (batch_size,), hidden_size) diff --git a/examples/sst2/models.py b/examples/sst2/models.py index 8029515f81..16620f4590 100644 --- a/examples/sst2/models.py +++ b/examples/sst2/models.py @@ -162,6 +162,9 @@ def __call__(self, inputs: Array, class SimpleLSTM(nn.Module): """A simple unidirectional LSTM.""" + def setup(self): + self.lstm = nn.OptimizedLSTMCell() + @functools.partial( nn.transforms.scan, variable_broadcast='params', @@ -169,12 +172,12 @@ class SimpleLSTM(nn.Module): split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): - return nn.OptimizedLSTMCell()(carry, x) + return self.lstm(carry, x) - @staticmethod - def initialize_carry(batch_dims, hidden_size): + @nn.module.wrap_method_once + def initialize_carry(self, batch_dims, hidden_size): # Use fixed random key since default state init fn is just zeros. - return nn.OptimizedLSTMCell.initialize_carry( + return self.lstm.initialize_carry( jax.random.PRNGKey(0), batch_dims, hidden_size) @@ -190,12 +193,14 @@ def __call__(self, embedded_inputs, lengths): batch_size = embedded_inputs.shape[0] # Forward LSTM. - initial_state = SimpleLSTM.initialize_carry((batch_size,), self.hidden_size) + initial_state = self.forward_lstm.initialize_carry((batch_size,), + self.hidden_size) _, forward_outputs = self.forward_lstm(initial_state, embedded_inputs) # Backward LSTM. 
reversed_inputs = flip_sequences(embedded_inputs, lengths) - initial_state = SimpleLSTM.initialize_carry((batch_size,), self.hidden_size) + initial_state = self.backward_lstm.initialize_carry((batch_size,), + self.hidden_size) _, backward_outputs = self.backward_lstm(initial_state, reversed_inputs) backward_outputs = flip_sequences(backward_outputs, lengths) diff --git a/examples/sst2/models_test.py b/examples/sst2/models_test.py index 0f3199b243..71d57b1ba3 100644 --- a/examples/sst2/models_test.py +++ b/examples/sst2/models_test.py @@ -51,7 +51,7 @@ def test_lstm_returns_correct_output_shape(self): rng = jax.random.PRNGKey(0) inputs = np.random.RandomState(0).normal( size=[batch_size, seq_len, embedding_size]) - initial_state = models.SimpleLSTM.initialize_carry((batch_size,), hidden_size) + initial_state = model.initialize_carry((batch_size,), hidden_size) (_, output), _ = model.init_with_output(rng, initial_state, inputs) self.assertEqual((batch_size, seq_len, hidden_size), output.shape) diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 268123eccd..2514f2edaa 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -67,8 +67,11 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + dtype = inputs.dtype + assert jnp.issubdtype(dtype, jnp.floating) negative_slope = self.param( 'negative_slope', - lambda k: jnp.asarray(self.negative_slope_init, jnp.float32) + lambda k: jnp.asarray(self.negative_slope_init, dtype) ) - return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs) + assert negative_slope.shape == () + return jnp.where(inputs >= 0, inputs, negative_slope * inputs) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index d7e3f80448..2d1d309816 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,23 +15,42 @@ """Attention core modules for Flax.""" from functools import partial -from typing import (Any, Callable, Tuple, Optional) +from typing import Any, Callable, Tuple, Type, Optional import jax from jax import lax from jax import random import jax.numpy as jnp import numpy as np +from typing_extensions import Protocol -from flax.linen.linear import default_kernel_init +from flax.linen.initializers import zeros from flax.linen.linear import DenseGeneral +from flax.linen.linear import _canonicalize_dtypes +from flax.linen.linear import default_kernel_init from flax.linen.module import Module, compact, merge_param -from flax.linen.initializers import zeros PRNGKey = Any -Shape = Tuple[int] -Dtype = Any +Shape = Tuple[int, ...] +InexactDType = Type[np.inexact] Array = Any +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] + + +class AttentionFunction(Protocol): + @staticmethod + def __call__(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: InexactDType = jnp.float32, + precision: Optional[lax.Precision] = None) -> Array: + ... def dot_product_attention_weights(query: Array, @@ -42,7 +61,7 @@ def dot_product_attention_weights(query: Array, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, - dtype: Dtype = jnp.float32, + dtype: InexactDType = jnp.float32, precision: Optional[lax.Precision] = None): """Computes dot-product attention weights given query and key. 
@@ -109,7 +128,7 @@ def dot_product_attention_weights(query: Array, keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(attn_weights.dtype) / + multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)) attn_weights = attn_weights * multiplier @@ -125,8 +144,8 @@ def dot_product_attention(query: Array, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, - dtype: Dtype = jnp.float32, - precision: Optional[lax.Precision] = None): + dtype: InexactDType = jnp.float32, + precision: Optional[lax.Precision] = None) -> Array: """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -179,52 +198,27 @@ def dot_product_attention(query: Array, precision=precision) -class MultiHeadDotProductAttention(Module): - """Multi-head dot-product attention. - - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation (default: float32) - param_dtype: the dtype passed to parameter initializers (default: float32). - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rate: dropout rate - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - use_bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. - """ +class _BaseMultiHeadDotProductAttention(Module): num_heads: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None qkv_features: Optional[int] = None out_features: Optional[int] = None broadcast_dropout: bool = True dropout_rate: float = 0. deterministic: Optional[bool] = None precision: Any = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros use_bias: bool = True - attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention + attention_fn: AttentionFunction = dot_product_attention decode: bool = False - @compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def _apply(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -246,6 +240,10 @@ def __call__(self, Returns: output of shape `[batch_sizes..., length, features]`. 
""" + param_dtype, dtype = _canonicalize_dtypes(jnp.result_type(inputs_q, + inputs_kv), + self.param_dtype, + self.dtype) if self.dropout_rate > 0.: # Require `deterministic` only if using dropout. deterministic = merge_param('deterministic', self.deterministic, deterministic) features = self.out_features or inputs_q.shape[-1] @@ -256,8 +254,8 @@ def __call__(self, dense = partial(DenseGeneral, axis=-1, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, features=(self.num_heads, head_dim), kernel_init=self.kernel_init, bias_init=self.bias_init, @@ -320,7 +318,7 @@ def __call__(self, dropout_rate=self.dropout_rate, broadcast_dropout=self.broadcast_dropout, deterministic=deterministic, - dtype=self.dtype, + dtype=dtype, precision=self.precision) # pytype: disable=wrong-keyword-args # back to the original inputs dimensions out = DenseGeneral(features=features, @@ -328,30 +326,64 @@ def __call__(self, kernel_init=self.kernel_init, bias_init=self.bias_init, use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, precision=self.precision, name='out')(x) return out -class SelfAttention(MultiHeadDotProductAttention): - """Self-attention special case of multi-head dot-product attention.""" +class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention): + """Multi-head dot-product attention. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: the dtype of the computation (default: float32) + param_dtype: the dtype passed to parameter initializers (default: float32). + qkv_features: dimension of the key, query, and value. + out_features: dimension of the last projection + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rate: dropout rate + deterministic: if false, the attention weight is masked randomly + using dropout, whereas if true, the attention weights + are deterministic. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the kernel of the Dense layers. + bias_init: initializer for the bias of the Dense layers. + use_bias: bool: whether pointwise QKVO dense transforms use bias. + attention_fn: dot_product_attention or compatible function. Accepts + query, key, value, and returns output of shape + `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. 
+ """ + @compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None): + return self._apply(inputs_q, inputs_kv, mask, deterministic=deterministic) + +class SelfAttention(_BaseMultiHeadDotProductAttention): + """Self-attention special case of multi-head dot-product attention.""" @compact def __call__(self, inputs_q: Array, mask: Optional[Array] = None, deterministic: Optional[bool] = None): - return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic) + return self._apply(inputs_q, inputs_q, mask, deterministic=deterministic) # mask-making utility functions -def make_attention_mask(query_input: Array, - key_input: Array, - pairwise_fn: Callable[..., Any] = jnp.multiply, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32): +def make_attention_mask( + query_input: Array, + key_input: Array, + pairwise_fn: Callable[[Array, Array], Array] = jnp.multiply, + extra_batch_dims: int = 0, + dtype: InexactDType = jnp.float32): """Mask-making helper for attention weights. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 81b97a15d3..818dd4aeb3 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -16,7 +16,8 @@ from dataclasses import field -from typing import (Any, Callable, Iterable, Optional, Tuple, Union) +from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, + Type, Union) from flax.linen.module import Module, compact from flax.linen.initializers import lecun_normal, variance_scaling, zeros @@ -27,9 +28,12 @@ PRNGKey = Any -Shape = Iterable[int] -Dtype = Any # this could be a real type? +Shape = Tuple[int, ...] +InexactDType = Type[np.inexact] +NumericDType = Type[np.number] +GenericDType = Type[np.generic] Array = Any +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] default_kernel_init = lecun_normal() @@ -47,6 +51,38 @@ def _canonicalize_tuple(x): return (x,) +def _canonicalize_dtypes( + input_dtype: InexactDType, + param_dtype: Optional[InexactDType], + computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, + InexactDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert np.issubdtype(input_dtype, np.inexact) + if np.issubdtype(input_dtype, np.complexfloating): + assert np.issubdtype(returned_param_dtype, np.complexfloating) + assert np.issubdtype(dtype, np.complexfloating) + return returned_param_dtype, dtype + + +def _canonicalize_numeric_dtypes( + input_dtype: NumericDType, + param_dtype: Optional[NumericDType], + computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType, + NumericDType]: + returned_param_dtype = input_dtype if param_dtype is None else param_dtype + dtype = (jnp.result_type(input_dtype, returned_param_dtype) + if computation_dtype is None else computation_dtype) + + assert np.issubdtype(input_dtype, np.number) + if np.issubdtype(input_dtype, np.complexfloating): + assert np.issubdtype(returned_param_dtype, np.complexfloating) + assert np.issubdtype(dtype, np.complexfloating) + return returned_param_dtype, dtype + + class DenseGeneral(Module): """A linear transformation with flexible axes. @@ -56,8 +92,8 @@ class DenseGeneral(Module): (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. 
use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` @@ -67,10 +103,10 @@ class DenseGeneral(Module): axis: Union[int, Iterable[int]] = -1 batch_dims: Iterable[int] = () use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros precision: Any = None @compact @@ -83,6 +119,9 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) batch_dims = _canonicalize_tuple(self.batch_dims) @@ -92,15 +131,13 @@ def __call__(self, inputs: Array) -> Array: raise ValueError('batch_dims %s must be consecutive leading ' 'dimensions starting from 0.' % str(batch_dims)) - inputs = jnp.asarray(inputs, self.dtype) - ndim = inputs.ndim n_batch_dims = len(batch_dims) axis = _normalize_axes(axis, ndim) batch_dims = _normalize_axes(batch_dims, ndim) n_axis, n_features = len(axis), len(features) - def kernel_init_wrap(rng, shape, dtype=jnp.float32): + def kernel_init_wrap(rng, shape, dtype): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), np.prod(shape[-n_features:]),) @@ -115,8 +152,8 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): for ax in range(inputs.ndim) if ax not in axis) kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape, - self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + param_dtype) + kernel = jnp.asarray(kernel, dtype) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) @@ -126,7 +163,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32): precision=self.precision) # dot_general output has shape [batch_dims/group_dims] + [feature_dims] if self.use_bias: - def bias_init_wrap(rng, shape, dtype=jnp.float32): + def bias_init_wrap(rng, shape, dtype): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[-n_features:]),) bias = jnp.concatenate([self.bias_init(rng, flat_shape, dtype) @@ -134,10 +171,10 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): return jnp.reshape(bias, shape) bias = self.param('bias', bias_init_wrap, batch_shape + features, - self.param_dtype) + param_dtype) + bias = jnp.asarray(bias, dtype) # expand bias shape to broadcast bias over batch dims. bias = jnp.reshape(bias, expanded_batch_shape + features) - bias = jnp.asarray(bias, self.dtype) out = out + bias return out @@ -148,8 +185,8 @@ class Dense(Module): Attributes: features: the number of output features. 
use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. @@ -157,11 +194,11 @@ class Dense(Module): """ features: int use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None precision: Any = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: @@ -173,19 +210,21 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features), self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.asarray(kernel, dtype) y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -214,7 +253,8 @@ class Conv(Module): high)` integer pairs that give the padding to apply before and after each spatial dimension. input_dilation: an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` (default: 1). + dilation factor to apply in each spatial dimension of `inputs` (default: + 1). Convolution with input dilation `d` is equivalent to transposed convolution with stride `d`. kernel_dilation: an integer or a sequence of `n` integers, giving the @@ -224,8 +264,8 @@ class Conv(Module): feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. @@ -239,16 +279,16 @@ class Conv(Module): kernel_dilation: Union[None, int, Iterable[int]] = 1 feature_group_count: int = 1 use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[NumericDType] = None + param_dtype: Optional[NumericDType] = None precision: Any = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: """Applies a convolution to the inputs. 
- + Args: inputs: input data with dimensions (batch, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution @@ -259,8 +299,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. """ - - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) if isinstance(self.kernel_size, int): raise TypeError('The kernel size must be specified as a' @@ -290,8 +332,8 @@ def maybe_broadcast(x): assert in_features % self.feature_group_count == 0 kernel_shape = kernel_size + ( in_features // self.feature_group_count, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype) + kernel = jnp.asarray(kernel, dtype) if self.padding == 'CIRCULAR': kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)] @@ -316,8 +358,8 @@ def maybe_broadcast(x): if is_single_input: y = jnp.squeeze(y, axis=0) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = self.param('bias', self.bias_init, (self.features,), param_dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -340,8 +382,8 @@ class ConvTranspose(Module): kernel. Convolution with kernel dilation is also known as 'atrous convolution'. use_bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. @@ -353,11 +395,11 @@ class ConvTranspose(Module): padding: Union[str, Iterable[Tuple[int, int]]] = 'SAME' kernel_dilation: Optional[Iterable[int]] = None use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[NumericDType] = None + param_dtype: Optional[NumericDType] = None precision: Any = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Array: @@ -374,7 +416,10 @@ def __call__(self, inputs: Array) -> Array: Returns: The convolved data. 
""" - inputs = jnp.asarray(inputs, self.dtype) + param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype, + self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) if isinstance(self.kernel_size, int): kernel_size = (self.kernel_size,) @@ -390,8 +435,8 @@ def __call__(self, inputs: Array) -> Array: in_features = inputs.shape[-1] kernel_shape = kernel_size + (in_features, self.features) - kernel = self.param('kernel', self.kernel_init, kernel_shape, self.param_dtype) - kernel = jnp.asarray(kernel, self.dtype) + kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype) + kernel = jnp.asarray(kernel, dtype) if self.padding == 'CIRCULAR': padding_lax = 'VALID' @@ -439,8 +484,8 @@ def __call__(self, inputs: Array) -> Array: if is_single_input: y = jnp.squeeze(y, axis=0) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype) - bias = jnp.asarray(bias, self.dtype) + bias = self.param('bias', self.bias_init, (self.features,), param_dtype) + bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -456,15 +501,15 @@ class Embed(Module): Attributes: num_embeddings: number of embeddings. features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: float32). + dtype: the dtype of the embedding vectors (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. """ num_embeddings: int features: int - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 - embedding_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_embed_init + dtype: Optional[GenericDType] = None + param_dtype: GenericDType = jnp.float32 + embedding_init: Initializer = default_embed_init embedding: Array = field(init=False) @@ -484,10 +529,11 @@ def __call__(self, inputs): Output which is embedded input data. The output shape follows the input, with an additional `features` dimension appended. """ + dtype = self.param_dtype if self.dtype is None else self.dtype if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError('Input type must be an integer or unsigned integer.') # Use take because fancy indexing numpy arrays with JAX indices does not work correctly. - embedding = jnp.asarray(self.embedding, self.dtype) + embedding = jnp.asarray(self.embedding, dtype) return jnp.take(embedding, inputs, axis=0) def attend(self, query): @@ -502,6 +548,7 @@ def attend(self, query): Commonly used for weight-sharing between embeddings and logit transform in NLP models. 
""" - query = jnp.asarray(query, self.dtype) - embedding = jnp.asarray(self.embedding, self.dtype) + dtype = self.param_dtype if self.dtype is None else self.dtype + query = jnp.asarray(query, dtype) + embedding = jnp.asarray(self.embedding, dtype) return jnp.dot(query, embedding.T) diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index fb52569700..e213cbcda7 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,21 +14,24 @@ """Normalization modules for Flax.""" -from typing import (Any, Callable, Optional, Tuple, Iterable, Union) +from typing import Any, Callable, Optional, Tuple, Type, Iterable, Union from jax import lax from jax.nn import initializers import jax.numpy as jnp +import numpy as np from flax.linen.module import Module, compact, merge_param +from flax.linen.linear import _canonicalize_dtypes PRNGKey = Any Array = Any -Shape = Tuple[int] -Dtype = Any # this could be a real type? +Shape = Tuple[int, ...] +InexactDType = Type[np.inexact] +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] -Axes = Union[int, Iterable[int]] +Axes = Union[int, Tuple[int, ...]] def _canonicalize_axes(rank: int, axes: Axes) -> Iterable[int]: @@ -47,7 +50,7 @@ def _abs_sq(x): def _compute_stats(x: Array, axes: Axes, axis_name: Optional[str] = None, - axis_index_groups: Any = None): + axis_index_groups: Any = None) -> Tuple[Array, Array]: """Computes mean and variance statistics. This implementation takes care of a few important details: @@ -79,15 +82,18 @@ def _compute_stats(x: Array, axes: Axes, def _normalize(mdl: Module, x: Array, mean: Array, var: Array, reduction_axes: Axes, feature_axes: Axes, - dtype: Dtype, param_dtype: Dtype, + dtype: InexactDType, param_dtype: InexactDType, epsilon: float, use_bias: bool, use_scale: bool, bias_init: Callable[[PRNGKey, Shape, Dtype], Array], scale_init: Callable[[PRNGKey, Shape, Dtype], Array]): - """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. + """"Normalizes the input of a normalization layer and optionally applies a + learned scale and bias. A seperate bias and scale is learned for each feature as specified by feature_axes. """ + input_dtype = jnp.result_type(x, mean, var) + param_dtype, dtype = _canonicalize_dtypes(input_dtype, param_dtype, dtype) reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) @@ -100,16 +106,21 @@ def _normalize(mdl: Module, x: Array, mean: Array, var: Array, for ax in feature_axes: feature_shape[ax] = x.shape[ax] reduced_feature_shape.append(x.shape[ax]) + x = jnp.asarray(x, dtype) + mean = jnp.asarray(mean, dtype) + var = jnp.asarray(var, dtype) y = x - mean mul = lax.rsqrt(var + epsilon) if use_scale: scale = mdl.param('scale', scale_init, reduced_feature_shape, param_dtype).reshape(feature_shape) + scale = jnp.asarray(scale, dtype) mul *= scale y *= mul if use_bias: bias = mdl.param('bias', bias_init, reduced_feature_shape, param_dtype).reshape(feature_shape) + bias = jnp.asarray(bias, dtype) y += bias return jnp.asarray(y, dtype) @@ -151,8 +162,8 @@ class BatchNorm(Module): momentum: decay rate for the exponential moving average of the batch statistics. epsilon: a small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). 
+ param_dtype: the dtype passed to parameter initializers (default: None). use_bias: if True, bias (beta) is added. use_scale: if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled @@ -171,17 +182,19 @@ class BatchNorm(Module): axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones axis_name: Optional[str] = None axis_index_groups: Any = None @compact - def __call__(self, x, use_running_average: Optional[bool] = None): + def __call__(self, + x: Array, + use_running_average: Optional[bool] = None) -> Array: """Normalizes the input using batch statistics. NOTE: @@ -199,6 +212,9 @@ def __call__(self, x, use_running_average: Optional[bool] = None): Returns: Normalized inputs (the same shape as inputs). """ + param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) use_running_average = merge_param( 'use_running_average', self.use_running_average, use_running_average) @@ -210,10 +226,10 @@ def __call__(self, x, use_running_average: Optional[bool] = None): initializing = self.is_mutable_collection('params') ra_mean = self.variable('batch_stats', 'mean', - lambda s: jnp.zeros(s, jnp.float32), + lambda s: jnp.zeros(s, dtype), feature_shape) ra_var = self.variable('batch_stats', 'var', - lambda s: jnp.ones(s, jnp.float32), + lambda s: jnp.ones(s, dtype), feature_shape) if use_running_average: @@ -225,14 +241,14 @@ def __call__(self, x, use_running_average: Optional[bool] = None): axis_index_groups=self.axis_index_groups) if not initializing: - ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean + ra_mean.value = (self.momentum * ra_mean.value + (1 - self.momentum) * + mean) ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, x, mean, var, reduction_axes, feature_axes, dtype, param_dtype, + self.epsilon, self.use_bias, self.use_scale, self.bias_init, + self.scale_init) class LayerNorm(Module): @@ -246,8 +262,8 @@ class LayerNorm(Module): Attributes: epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: None). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done @@ -256,12 +272,12 @@ class LayerNorm(Module): scale_init: Initializer for scale, by default, one. 
""" epsilon: float = 1e-6 - dtype: Any = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones @compact def __call__(self, x): @@ -273,6 +289,9 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). """ + param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) reduction_axes = (-1,) feature_axes = (-1,) @@ -280,10 +299,9 @@ def __call__(self, x): mean, var = _compute_stats(x, reduction_axes, None, None) return _normalize( - self, x, mean, var, reduction_axes, feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, - self.bias_init, self.scale_init) + self, x, mean, var, reduction_axes, feature_axes, dtype, param_dtype, + self.epsilon, self.use_bias, self.use_scale, self.bias_init, + self.scale_init) class GroupNorm(Module): @@ -301,8 +319,9 @@ class GroupNorm(Module): proposed by the original group normalization paper. group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: + None). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done @@ -313,12 +332,12 @@ class GroupNorm(Module): num_groups: Optional[int] = 32 group_size: Optional[int] = None epsilon: float = 1e-6 - dtype: Any = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None use_bias: bool = True use_scale: bool = True - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros - scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + bias_init: Initializer = initializers.zeros + scale_init: Initializer = initializers.ones @compact def __call__(self, x): @@ -332,7 +351,10 @@ def __call__(self, x): Returns: Normalized inputs (the same shape as inputs). 
""" - reduction_axes = list(range(1, x.ndim - 1)) + [-1] + param_dtype, dtype = _canonicalize_dtypes(x.dtype, self.param_dtype, + self.dtype) + x = jnp.asarray(x, dtype) + reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) feature_axes = (-1,) if ((self.num_groups is None and self.group_size is None) or @@ -365,7 +387,6 @@ def broadcast_stat(stat): var = broadcast_stat(var) return _normalize( - self, x, mean, var, reduction_axes[:-1], feature_axes, - self.dtype, self.param_dtype, self.epsilon, - self.use_bias, self.use_scale, + self, x, mean, var, reduction_axes[:-1], feature_axes, dtype, + param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init) diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 275da5343c..207be9324d 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -27,10 +27,11 @@ from functools import partial from typing import (Any, Callable, Iterable, Optional, Tuple, Union) -from flax.linen.module import Module, compact from flax.linen.activation import sigmoid, tanh from flax.linen.initializers import orthogonal, zeros from flax.linen.linear import Conv, Dense, default_kernel_init +from flax.linen.linear import _canonicalize_dtypes +from flax.linen.module import Module, compact from jax import numpy as jnp from jax import lax @@ -39,18 +40,23 @@ import numpy as np PRNGKey = Any -Shape = Tuple[int] -Dtype = Any # this could be a real type? +Shape = Tuple[int, ...] +InexactDType = Type[np.inexact] Array = Any +ScalarFunction = Callable[[Array], Array] +Initializer = Callable[[PRNGKey, Shape, InexactDType], Array] +LSTMCarry = Tuple[Array, Array] class RNNCellBase(Module): """RNN cell base class.""" - - @staticmethod @abc.abstractmethod - def initialize_carry(rng, batch_dims, size, init_fn=zeros): - """initialize the RNN cell carry. + def initialize_carry(self, + rng: PRNGKey, + batch_dims: Tuple[int, ...], + size: int, + init_fn: Initializer = zeros) -> LSTMCarry: + """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. @@ -87,16 +93,16 @@ class LSTMCell(RNNCellBase): recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: orthogonal). bias_init: initializer for the bias parameters (default: zeros) - dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + dtype: the dtype of the computation and carry (default: float32). + param_dtype: the dtype passed to parameter initializers (default: None). """ - gate_fn: Callable[..., Any] = sigmoid - activation_fn: Callable[..., Any] = tanh - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal() - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + gate_fn: ScalarFunction = sigmoid + activation_fn: ScalarFunction = tanh + kernel_init: Initializer = default_kernel_init + recurrent_kernel_init: Initializer = orthogonal() + bias_init: Initializer = zeros + dtype: InexactDType = jnp.float32 + param_dtype: Optional[InexactDType] = None @compact def __call__(self, carry, inputs): @@ -111,6 +117,9 @@ def __call__(self, carry, inputs): Returns: A tuple with the new carry and the output. 
""" + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) c, h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. @@ -135,8 +144,11 @@ def __call__(self, carry, inputs): new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=zeros): + def initialize_carry(self, + rng: PRNGKey, + batch_dims: Tuple[int, ...], + size: int, + init_fn: Initializer = zeros) -> LSTMCarry: """initialize the RNN cell carry. Args: @@ -149,19 +161,19 @@ def initialize_carry(rng, batch_dims, size, init_fn=zeros): """ key1, key2 = random.split(rng) mem_shape = batch_dims + (size,) - return init_fn(key1, mem_shape), init_fn(key2, mem_shape) + return (init_fn(key1, mem_shape, self.dtype), + init_fn(key2, mem_shape, self.dtype)) class DenseParams(Module): """Dummy module for creating parameters matching `flax.nn.Dense`.""" - features: int use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: Optional[InexactDType] = None + param_dtype: Optional[InexactDType] = None precision: Any = None - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = zeros @compact def __call__(self, inputs: Array) -> Tuple[Array, Array]: @@ -201,15 +213,15 @@ class OptimizedLSTMCell(RNNCellBase): the hidden state (default: orthogonal). bias_init: initializer for the bias parameters (default: zeros). dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + param_dtype: the dtype passed to parameter initializers (default: None). """ - gate_fn: Callable = sigmoid - activation_fn: Callable = tanh - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal() - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + gate_fn: ScalarFunction = sigmoid + activation_fn: ScalarFunction = tanh + kernel_init: Initializer = default_kernel_init + recurrent_kernel_init: Initializer = orthogonal() + bias_init: Initializer = zeros + dtype: InexactDType = jnp.float32 + param_dtype: Optional[InexactDType] = None @compact def __call__(self, carry: Tuple[Array, Array], @@ -218,29 +230,33 @@ def __call__(self, carry: Tuple[Array, Array], Args: carry: the hidden state of the LSTM cell, initialized using - `LSTMCell.initialize_carry`. + `initialize_carry`. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. 
""" + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) c, h = carry + c = jnp.asarray(c, dtype) + h = jnp.asarray(h, dtype) hidden_features = h.shape[-1] - inputs = jnp.asarray(inputs, self.dtype) def _concat_dense(inputs, params, use_bias=True): """ - Concatenates the individual kernels and biases, given in params, into a - single kernel and single bias for efficiency before applying them using + Concatenates the individual kernels and biases, given in params, into a + single kernel and single bias for efficiency before applying them using dot_general. """ kernels, biases = zip(*params.values()) - kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), self.dtype) + kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), dtype) y = jnp.dot(inputs, kernel) if use_bias: - bias = jnp.asarray(jnp.concatenate(biases, axis=-1), self.dtype) + bias = jnp.asarray(jnp.concatenate(biases, axis=-1), dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) # Split the result back into individual (i, f, g, o) outputs. @@ -254,12 +270,12 @@ def _concat_dense(inputs, params, use_bias=True): for component in ['i', 'f', 'g', 'o']: dense_params_i[component] = DenseParams( features=hidden_features, use_bias=False, - param_dtype=self.param_dtype, + param_dtype=param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init, name=f'i{component}')(inputs) dense_params_h[component] = DenseParams( features=hidden_features, use_bias=True, - param_dtype=self.param_dtype, + param_dtype=param_dtype, kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, name=f'h{component}')(h) dense_h = _concat_dense(h, dense_params_h, use_bias=True) @@ -274,8 +290,11 @@ def _concat_dense(inputs, params, use_bias=True): new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=zeros): + def initialize_carry(self, + rng: PRNGKey, + batch_dims: Tuple[int, ...], + size: int, + init_fn: Initializer = zeros) -> LSTMCarry: """initialize the RNN cell carry. Args: @@ -289,7 +308,8 @@ def initialize_carry(rng, batch_dims, size, init_fn=zeros): """ key1, key2 = random.split(rng) mem_shape = batch_dims + (size,) - return init_fn(key1, mem_shape), init_fn(key2, mem_shape) + return (init_fn(key1, mem_shape, self.dtype), + init_fn(key2, mem_shape, self.dtype)) class GRUCell(RNNCellBase): @@ -315,17 +335,15 @@ class GRUCell(RNNCellBase): the hidden state (default: orthogonal). bias_init: initializer for the bias parameters (default: zeros) dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + param_dtype: the dtype passed to parameter initializers (default: None). 
""" - gate_fn: Callable[..., Any] = sigmoid - activation_fn: Callable[..., Any] = tanh - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( - default_kernel_init) - recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = ( - orthogonal()) - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + gate_fn: ScalarFunction = sigmoid + activation_fn: ScalarFunction = tanh + kernel_init: Initializer = default_kernel_init + recurrent_kernel_init: Initializer = orthogonal() + bias_init: Initializer = zeros + dtype: InexactDType = jnp.float32 + param_dtype: Optional[InexactDType] = None @compact def __call__(self, carry, inputs): @@ -340,20 +358,23 @@ def __call__(self, carry, inputs): Returns: A tuple with the new carry and the output. """ + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. dense_h = partial(Dense, features=hidden_features, use_bias=False, - dtype=self.dtype, + dtype=dtype, param_dtype=self.param_dtype, kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init) dense_i = partial(Dense, features=hidden_features, use_bias=True, - dtype=self.dtype, + dtype=dtype, param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init) @@ -365,8 +386,11 @@ def __call__(self, carry, inputs): new_h = (1. - z) * n + z * h return new_h, new_h - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=zeros): + def initialize_carry(self, + rng: PRNGKey, + batch_dims: Tuple[int, ...], + size: int, + init_fn: Initializer = zeros) -> LSTMCarry: """initialize the RNN cell carry. Args: @@ -378,7 +402,7 @@ def initialize_carry(rng, batch_dims, size, init_fn=zeros): An initialized carry for the given RNN cell. """ mem_shape = batch_dims + (size,) - return init_fn(rng, mem_shape) + return init_fn(rng, mem_shape, self.dtype) class ConvLSTM(RNNCellBase): @@ -419,16 +443,15 @@ class ConvLSTM(RNNCellBase): and after each spatial dimension. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). - param_dtype: the dtype passed to parameter initializers (default: float32). + param_dtype: the dtype passed to parameter initializers (default: None). """ - features: int kernel_size: Iterable[int] strides: Optional[Iterable[int]] = None padding: Union[str, Iterable[Tuple[int, int]]] = 'SAME' use_bias: bool = True - dtype: Dtype = jnp.float32 - param_dtype: Dtype = jnp.float32 + dtype: InexactDType = jnp.float32 + param_dtype: Optional[InexactDType] = None @compact def __call__(self, carry, inputs): @@ -441,6 +464,9 @@ def __call__(self, carry, inputs): Returns: A tuple with the new carry and the output. 
""" + param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype, + self.dtype) + inputs = jnp.asarray(inputs, dtype) c, h = carry input_to_hidden = partial(Conv, features=4*self.features, @@ -448,8 +474,8 @@ def __call__(self, carry, inputs): strides=self.strides, padding=self.padding, use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, name='ih') hidden_to_hidden = partial(Conv, @@ -458,8 +484,8 @@ def __call__(self, carry, inputs): strides=self.strides, padding=self.padding, use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, + dtype=dtype, + param_dtype=param_dtype, name='hh') gates = input_to_hidden()(inputs) + hidden_to_hidden()(h) @@ -470,9 +496,12 @@ def __call__(self, carry, inputs): new_h = sigmoid(o) * jnp.tanh(new_c) return (new_c, new_h), new_h - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=zeros): - """initialize the RNN cell carry. + def initialize_carry(self, + rng: PRNGKey, + batch_dims: Tuple[int, ...], + size: Tuple[int, ...], + init_fn: Initializer = zeros) -> LSTMCarry: + """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. @@ -484,4 +513,5 @@ def initialize_carry(rng, batch_dims, size, init_fn=zeros): """ key1, key2 = random.split(rng) mem_shape = batch_dims + size - return init_fn(key1, mem_shape), init_fn(key2, mem_shape) + return (init_fn(key1, mem_shape, self.dtype), + init_fn(key2, mem_shape, self.dtype)) diff --git a/setup.py b/setup.py index 4868aad8db..7f5d64896b 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "matplotlib", # only needed for tensorboard export "msgpack", "optax", + "typing_extensions>=3.10", ] tests_require = [ diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index f155dde034..44e191365f 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -310,7 +310,7 @@ def setup(self): # NOTE that keys still must be strings. This is to make a possible # future transition to automatically derived parameter names when assigned # as a dict easier (like we currently have with submodules). - # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853 + # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853 str(i): self.param(f'bias_{i}', initializers.ones, self.xshape) for i in range(4)} def __call__(self, x): @@ -657,8 +657,8 @@ def __call__(self, x): # attributes features = 3 use_bias = True - dtype = float32 - param_dtype = float32 + dtype = None + param_dtype = None precision = None kernel_init = init bias_init = zeros @@ -667,8 +667,8 @@ def __call__(self, x): # attributes features = 2 use_bias = True - dtype = float32 - param_dtype = float32 + dtype = None + param_dtype = None precision = None kernel_init = init bias_init = zeros @@ -1394,14 +1394,14 @@ def test_rng_reuse_after_rewind(self): class C(nn.Module): @nn.compact def __call__(self): - # Some module that has dropouts in it, in general, + # Some module that has dropouts in it, in general, # it does more than just dropout! return self.make_rng('dropout') class A(nn.Module): @nn.compact def __call__(self): - # Some module that has dropouts in it, in general, + # Some module that has dropouts in it, in general, # it does more than just dropout! 
return C()() diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 892974d50b..9c9d1d963d 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -240,10 +240,10 @@ def test_lstm(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) - c0, h0 = nn.LSTMCell.initialize_carry(rng, (2,), 4) + lstm = nn.LSTMCell() + c0, h0 = lstm.initialize_carry(rng, (2,), 4) self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) - lstm = nn.LSTMCell() (carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x) self.assertEqual(carry[0].shape, (2, 4)) self.assertEqual(carry[1].shape, (2, 4)) @@ -264,9 +264,9 @@ def test_gru(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) - carry0 = nn.GRUCell.initialize_carry(rng, (2,), 4) - self.assertEqual(carry0.shape, (2, 4)) gru = nn.GRUCell() + carry0 = gru.initialize_carry(rng, (2,), 4) + self.assertEqual(carry0.shape, (2, 4)) (carry, y), initial_params = gru.init_with_output(key2, carry0, x) #gru = nn.Model(nn.GRUCell, initial_params) self.assertEqual(carry.shape, (2, 4)) @@ -285,10 +285,10 @@ def test_convlstm(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 4, 4, 3)) - c0, h0 = nn.ConvLSTM.initialize_carry(rng, (2,), (4, 4, 6)) + lstm = nn.ConvLSTM(features=6, kernel_size=(3, 3)) + c0, h0 = lstm.initialize_carry(rng, (2,), (4, 4, 6)) self.assertEqual(c0.shape, (2, 4, 4, 6)) self.assertEqual(h0.shape, (2, 4, 4, 6)) - lstm = nn.ConvLSTM(features=6, kernel_size=(3, 3)) (carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x) self.assertEqual(carry[0].shape, (2, 4, 4, 6)) self.assertEqual(carry[1].shape, (2, 4, 4, 6)) @@ -298,31 +298,31 @@ def test_convlstm(self): 'hh': {'bias': (6*4,), 'kernel': (3, 3, 6, 6*4)}, 'ih': {'bias': (6*4,), 'kernel': (3, 3, 3, 6*4)}, }) - + def test_optimized_lstm_cell_matches_regular(self): # Create regular LSTMCell. rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) - c0, h0 = nn.LSTMCell.initialize_carry(rng, (2,), 4) + lstm = nn.LSTMCell() + c0, h0 = lstm.initialize_carry(rng, (2,), 4) self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) - lstm = nn.LSTMCell() - (_, y), lstm_params = lstm.init_with_output(key2, (c0, h0), x) - + (_, y), lstm_params = lstm.init_with_output(key2, (c0, h0), x) + # Create OptimizedLSTMCell. 
rng = random.PRNGKey(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) - c0, h0 = nn.OptimizedLSTMCell.initialize_carry(rng, (2,), 4) self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) lstm_opt = nn.OptimizedLSTMCell() - (_, y_opt), lstm_opt_params = lstm_opt.init_with_output(key2, (c0, h0), x) - + c0, h0 = lstm_opt.initialize_carry(rng, (2,), 4) + (_, y_opt), lstm_opt_params = lstm_opt.init_with_output(key2, (c0, h0), x) + np.testing.assert_allclose(y, y_opt, rtol=1e-6) - jtu.check_eq(lstm_params, lstm_opt_params) + jtu.check_eq(lstm_params, lstm_opt_params) if __name__ == '__main__': diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 58f7b2d3b0..145035d1b0 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -213,9 +213,9 @@ def __call__(self, c, xs): key1, key2 = random.split(random.PRNGKey(0), 2) xs = random.uniform(key1, (5, 3, 2)) dummy_rng = random.PRNGKey(0) - init_carry = nn.LSTMCell.initialize_carry(dummy_rng, - xs.shape[1:-1], - xs.shape[-1]) + init_carry = nn.LSTMCell().initialize_carry(dummy_rng, + xs.shape[1:-1], + xs.shape[-1]) model = SimpleScan() init_variables = model.init(key2, init_carry, xs) # simulate scan in python for comparison: @@ -247,9 +247,9 @@ def __call__(self, c, b, xs): xs = random.uniform(key1, (4, 3, 2)) b = jnp.ones((4,)) dummy_rng = random.PRNGKey(0) - init_carry = nn.LSTMCell.initialize_carry(dummy_rng, - xs.shape[1:-1], - xs.shape[-1]) + init_carry = nn.LSTMCell().initialize_carry(dummy_rng, + xs.shape[1:-1], + xs.shape[-1]) model = SimpleScan() init_variables = model.init(key2, init_carry, b, xs) # simulate scan in python for comparison:
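Usage sketch (illustrative, not part of the diff above): the snippet below shows the two user-facing changes this patch makes, namely `initialize_carry` becoming an instance method and `dtype`/`param_dtype` defaulting to `None` so that parameter and computation dtypes follow the inputs. It is a minimal sketch; the bfloat16 promotion behaviour noted in the comments is inferred from `_canonicalize_dtypes` in this diff, and the shapes, feature sizes, and PRNG keys are arbitrary.

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 3), dtype=jnp.bfloat16)

# With dtype/param_dtype left at None, Dense no longer assumes float32:
# param_dtype falls back to the input dtype (bfloat16 here) and the
# computation dtype is the promotion of the input and parameter dtypes.
dense = nn.Dense(features=4)
dense_variables = dense.init(jax.random.PRNGKey(0), x)
y = dense.apply(dense_variables, x)  # expected dtype: bfloat16

# initialize_carry is now called on a constructed cell rather than on the
# class, matching the updated examples and tests in this patch.
cell = nn.LSTMCell()
carry = cell.initialize_carry(jax.random.PRNGKey(0), (2,), 4)
(carry, out), lstm_variables = cell.init_with_output(
    jax.random.PRNGKey(1), carry, x)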