Remove float32 dtype assumption #1803

Closed
wants to merge 8 commits into from
16 changes: 9 additions & 7 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ vNext
- Added Optax update guide and deprecated `flax.optim`.
- Added `sep` argument to `flax.traverse_util.flatten_dict()`.
-
- Remove float32 dtype assumption throughout linen (except for recurrent
  modules).
-
- Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`.
-
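With the float32 assumption removed, a layer whose `dtype`/`param_dtype` are left unset would follow the dtype of its inputs. A minimal sketch (not part of this diff), assuming `flax.linen.Dense` behaves this way after the change:

```python
# Minimal sketch; assumes Dense infers its dtype from the inputs once the
# hard-coded float32 default is removed. Not part of this PR's diff.
import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.Dense(features=4)
x = jnp.ones((2, 3), dtype=jnp.bfloat16)
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)
# Expected under this change: y.dtype == jnp.bfloat16 rather than float32.
```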
@@ -48,7 +50,7 @@ Breaking changes:
- You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv`.
  Instead, a TypeError is raised stating that a tuple/list should be provided.
  Stride and dilation arguments do support broadcasting a single int value now
  because this is not ambiguous when the kernel rank is known.
- `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call` that provided a context manager to locally disable/enable named_call.
- NamedTuples are no longer converted to tuples on assignment to a `linen.Module`.

@@ -64,15 +66,15 @@ Bugfixes:
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https:/google/flax/issues/1429)).
- Mixed precision training with float16 now works correctly with the attention layers.
- auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes.



0.3.4
------

Possibly breaking changes:
- When calling `init` the 'intermediates' collection is no longer mutable.
  Therefore, intermediates will no longer be returned from initialization by default.
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`.

@@ -105,9 +107,9 @@ Possible breaking changes:
latest checkpoint already saved.
- MultiOptimizer now rejects the case where multiple sub optimizers update the
same parameter.

Other changes:
- Added custom error classes to many Linen errors. See:
https://flax.readthedocs.io/en/latest/flax.errors.html
- Adds `Module.bind` for binding variables and RNGs to an interactive Module.
- Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument.
@@ -127,7 +129,7 @@ NOTE: You must now explicitly import `flax.nn` if you want to use the old
0.3.1
------

Many improvements to Linen, and the old `flax.nn` is officially deprecated!

Notably, there's a clean API for extracting intermediates from modules
defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules
@@ -141,7 +143,7 @@ Possible breaking changes:
is enforced by raising a TypeError in `__setattr__` after `setup`.
- Pytrees of dicts and lists are transformed into FrozenDict and tuples during
attribute assignment.
  This avoids undetected submodules and inner state.
- Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple
containing the output and a frozen empty
collection when `mutable` is specified as an empty list.
7 changes: 5 additions & 2 deletions flax/linen/activation.py
@@ -67,8 +67,11 @@ def __call__(self, inputs: Array) -> Array:
Returns:
The transformed input.
"""
dtype = inputs.dtype
assert jnp.issubdtype(dtype, jnp.floating)
negative_slope = self.param(
'negative_slope',
lambda k: jnp.asarray(self.negative_slope_init, jnp.float32)
lambda k: jnp.asarray(self.negative_slope_init, dtype)
)
return jnp.where(inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs)
assert negative_slope.shape == ()
return jnp.where(inputs >= 0, inputs, negative_slope * inputs)
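The module touched above appears to be `flax.linen.PReLU`, since it owns `negative_slope_init`. Under that assumption, a short sketch (not part of this diff) of what the change implies:

```python
# Sketch assuming the module above is flax.linen.PReLU; with this change the
# learnable 'negative_slope' parameter adopts the dtype of the inputs instead
# of a hard-coded float32. Illustrative only.
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.linspace(-1.0, 1.0, 8, dtype=jnp.float16)
prelu = nn.PReLU()
variables = prelu.init(jax.random.PRNGKey(0), x)
y = prelu.apply(variables, x)
# Expected: variables['params']['negative_slope'].dtype == jnp.float16.
```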
152 changes: 92 additions & 60 deletions flax/linen/attention.py
@@ -15,23 +15,42 @@
"""Attention core modules for Flax."""

from functools import partial
from typing import (Any, Callable, Tuple, Optional)
from typing import Any, Callable, Tuple, Type, Optional

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np
from typing_extensions import Protocol

from flax.linen.linear import default_kernel_init
from flax.linen.initializers import zeros
from flax.linen.linear import DenseGeneral
from flax.linen.linear import _canonicalize_dtypes
from flax.linen.linear import default_kernel_init
from flax.linen.module import Module, compact, merge_param
from flax.linen.initializers import zeros

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
Shape = Tuple[int, ...]
InexactDType = Type[jnp.inexact]
Array = Any
Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
Contributor: Not for this PR, but it might be useful also for downstream packages to put the definitions of those types in a flax.typing or flax.types submodule?

Contributor Author: Yeah, Initializer definitely belongs somewhere better.

FYI, I suggested that Shape be added to numpy ages ago, and they rightly turned down the suggestion. I think Stephen's reasoning applies here that the others are probably too short and simple?

Contributor Author: I added jax-ml/jax#9596, which if accepted, would put Initializer in jax.nn where I think it belongs.

Contributor Author: By the way, I just found out that MyPy has a crazy bug if you name a file typing or types: python/mypy#1876
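A hypothetical sketch of the suggested submodule; flax had no such module at the time of this PR, so the module name and contents below are illustrative and simply mirror the aliases defined in this diff:

```python
# Hypothetical flax/typing.py, per the suggestion above (illustrative only).
from typing import Any, Callable, Tuple, Type

import jax.numpy as jnp

PRNGKey = Any                      # opaque JAX PRNG key
Shape = Tuple[int, ...]            # array shape
InexactDType = Type[jnp.inexact]   # floating/complex dtype
Array = Any                        # any JAX array-like value
Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
```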



class AttentionFunction(Protocol):
@staticmethod
def __call__(query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
broadcast_dropout: bool = True,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None) -> Array:
...


def dot_product_attention_weights(query: Array,
@@ -42,7 +61,7 @@ def dot_product_attention_weights(query: Array,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: Dtype = jnp.float32,
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None):
"""Computes dot-product attention weights given query and key.

@@ -109,7 +128,7 @@ def dot_product_attention_weights(query: Array,
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
else:
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) /
multiplier = (keep.astype(dtype) /
jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier

@@ -125,8 +144,8 @@ def dot_product_attention(query: Array,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: Dtype = jnp.float32,
precision: Optional[lax.Precision] = None):
dtype: InexactDType = jnp.float32,
precision: Optional[lax.Precision] = None) -> Array:
"""Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on
@@ -179,52 +198,27 @@ def dot_product_attention(query: Array,
precision=precision)


class MultiHeadDotProductAttention(Module):
"""Multi-head dot-product attention.

Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: float32)
param_dtype: the dtype passed to parameter initializers (default: float32).
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weight is masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
bias_init: initializer for the bias of the Dense layers.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts
query, key, value, and returns output of shape
`[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]``
decode: whether to prepare and use an autoregressive cache.
"""
class _BaseMultiHeadDotProductAttention(Module):
num_heads: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
dtype: Optional[InexactDType] = None
param_dtype: Optional[InexactDType] = None
qkv_features: Optional[int] = None
out_features: Optional[int] = None
broadcast_dropout: bool = True
dropout_rate: float = 0.
deterministic: Optional[bool] = None
precision: Optional[lax.Precision] = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = zeros
use_bias: bool = True
attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention
attention_fn: AttentionFunction = dot_product_attention
decode: bool = False

@compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
def _apply(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
"""Applies multi-head dot product attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors,
@@ -246,6 +240,10 @@ def __call__(self,
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
param_dtype, dtype = _canonicalize_dtypes(jnp.result_type(inputs_q,
inputs_kv),
self.param_dtype,
self.dtype)
features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
assert qkv_features % self.num_heads == 0, (
@@ -254,8 +252,8 @@ def __call__(self,

dense = partial(DenseGeneral,
axis=-1,
dtype=self.dtype,
param_dtype=self.param_dtype,
dtype=dtype,
param_dtype=param_dtype,
features=(self.num_heads, head_dim),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
@@ -323,38 +321,72 @@ def __call__(self,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=m_deterministic,
dtype=self.dtype,
dtype=dtype,
precision=self.precision) # pytype: disable=wrong-keyword-args
# back to the original inputs dimensions
out = DenseGeneral(features=features,
axis=(-2, -1),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
dtype=dtype,
param_dtype=param_dtype,
precision=self.precision,
name='out')(x)
return out
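
`_canonicalize_dtypes`, imported from `flax.linen.linear` and called at the top of `_apply`, is not shown in this diff. A hypothetical sketch of the dtype-resolution logic it might implement; the fallback order is an assumption:

```python
# Hypothetical sketch of _canonicalize_dtypes; the real helper lives in
# flax/linen/linear.py and is not shown in this diff.
from typing import Optional, Tuple


def _canonicalize_dtypes(input_dtype,
                         param_dtype: Optional[type],
                         dtype: Optional[type]) -> Tuple[type, type]:
  # Assumption: parameters default to the dtype inferred from the inputs.
  resolved_param_dtype = input_dtype if param_dtype is None else param_dtype
  # Assumption: the computation dtype defaults to the parameter dtype.
  resolved_dtype = resolved_param_dtype if dtype is None else dtype
  return resolved_param_dtype, resolved_dtype
```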


class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""
class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention):
"""Multi-head dot-product attention.

Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
dtype: the dtype of the computation (default: inferred from the inputs)
param_dtype: the dtype passed to parameter initializers (default: inferred from the inputs).
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rate: dropout rate
deterministic: if false, the attention weights are masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer for the kernel of the Dense layers.
bias_init: initializer for the bias of the Dense layers.
use_bias: bool: whether pointwise QKVO dense transforms use bias.
attention_fn: dot_product_attention or compatible function. Accepts
query, key, value, and returns output of shape
`[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`
decode: whether to prepare and use an autoregressive cache.
"""
@compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
return self._apply(inputs_q, inputs_kv, mask, deterministic=deterministic)
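
A short usage sketch (not part of this diff) of `MultiHeadDotProductAttention` after this change; with `dtype`/`param_dtype` left unset, the computation dtype would follow the inputs. Shapes and values are illustrative:

```python
# Usage sketch; illustrative only.
import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=16)
q = jnp.ones((1, 5, 16), dtype=jnp.bfloat16)    # [batch, q_length, features]
kv = jnp.ones((1, 7, 16), dtype=jnp.bfloat16)   # [batch, kv_length, features]
variables = attn.init(jax.random.PRNGKey(0), q, kv)
out = attn.apply(variables, q, kv)
# out.shape == (1, 5, 16); under this change out.dtype would follow q/kv.
```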


class SelfAttention(_BaseMultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""
@compact
def __call__(self, inputs_q: Array, mask: Optional[Array] = None,
deterministic: Optional[bool] = None):
return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic)
return self._apply(inputs_q, inputs_q, mask, deterministic=deterministic)


# mask-making utility functions


def make_attention_mask(query_input: Array,
key_input: Array,
pairwise_fn: Callable[..., Any] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: Dtype = jnp.float32):
def make_attention_mask(
query_input: Array,
key_input: Array,
pairwise_fn: Callable[[Array, Array], Array] = jnp.multiply,
extra_batch_dims: int = 0,
dtype: InexactDType = jnp.float32):
"""Mask-making helper for attention weights.

In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
@@ -381,7 +413,7 @@ def make_attention_mask(query_input: Array,

def make_causal_mask(x: Array,
extra_batch_dims: int = 0,
dtype: Dtype = jnp.float32) -> Array:
dtype: InexactDType = jnp.float32) -> Array:
"""Make a causal mask for self-attention.

In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights
@@ -403,7 +435,7 @@ def make_causal_mask(x: Array,


def combine_masks(*masks: Optional[Array],
dtype: Dtype = jnp.float32) -> Array:
dtype: InexactDType = jnp.float32) -> Array:
"""Combine attention masks.

Args:
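A usage sketch (not part of this diff) for the mask helpers above, assuming the shapes described in their docstrings:

```python
# Sketch of the mask helpers; values are illustrative.
import jax.numpy as jnp
import flax.linen as nn

# 1 marks valid positions, 0 marks padding.
query_valid = jnp.array([[1, 1, 1, 0]], dtype=jnp.float32)   # [batch, len_q]
key_valid = jnp.array([[1, 1, 0, 0]], dtype=jnp.float32)     # [batch, len_kv]

pad_mask = nn.make_attention_mask(query_valid, key_valid)    # [batch, 1, len_q, len_kv]
causal_mask = nn.make_causal_mask(jnp.ones((1, 4)))          # [batch, 1, len, len]
mask = nn.combine_masks(pad_mask, causal_mask)               # elementwise AND of both
```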