Factor out dtypes.py
NeilGirdhar committed Mar 7, 2022
1 parent 4f2927d commit f451757
Showing 8 changed files with 118 additions and 143 deletions.
30 changes: 16 additions & 14 deletions flax/linen/__init__.py
@@ -17,23 +17,25 @@
 
 # pylint: disable=g-multiple-import
 # re-export commonly used modules and functions
-from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
-                         log_softmax, relu, sigmoid, soft_sign, softmax,
-                         softplus, swish, silu, tanh, PReLU)
+from .activation import (PReLU, celu, elu, gelu, glu, leaky_relu, log_sigmoid,
+                         log_softmax, relu, sigmoid, silu, soft_sign, softmax,
+                         softplus, swish, tanh)
 from .attention import (MultiHeadDotProductAttention, SelfAttention,
-                        dot_product_attention, dot_product_attention_weights,
-                        make_attention_mask, make_causal_mask, combine_masks)
-from ..core import broadcast, DenyList, FrozenDict
-from .linear import (Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed,
-                     canonicalize_inexact_dtypes, canonicalize_numeric_dtypes)
-from .module import (Module, compact, nowrap, enable_named_call,
-                     disable_named_call, override_named_call, Variable, init,
-                     init_with_output, apply, merge_param)
+                        combine_masks, dot_product_attention,
+                        dot_product_attention_weights, make_attention_mask,
+                        make_causal_mask)
+from ..core import DenyList, FrozenDict, broadcast
+from .dtypes import canonicalize_inexact_dtypes, canonicalize_numeric_dtypes
+from .initializers import ones, zeros
+from .linear import Conv, ConvLocal, ConvTranspose, Dense, DenseGeneral, Embed
+from .module import (Module, Variable, apply, compact, disable_named_call,
+                     enable_named_call, init, init_with_output, merge_param,
+                     nowrap, override_named_call)
 from .normalization import BatchNorm, GroupNorm, LayerNorm
 from .pooling import avg_pool, max_pool, pool
-from .recurrent import GRUCell, LSTMCell, ConvLSTM, OptimizedLSTMCell
+from .recurrent import ConvLSTM, GRUCell, LSTMCell, OptimizedLSTMCell
 from .stochastic import Dropout
-from .transforms import jit, named_call, checkpoint, remat, remat_scan, scan, vmap, map_variables, vjp, jvp, custom_vjp
-from .initializers import zeros, ones
+from .transforms import (checkpoint, custom_vjp, jit, jvp, map_variables,
+                         named_call, remat, remat_scan, scan, vjp, vmap)
 
 # pylint: enable=g-multiple-import
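
The package root keeps re-exporting the canonicalization helpers, now sourced from the new .dtypes module instead of .linear, so existing user imports continue to work. A minimal sketch of the import surface this preserves, assuming only the re-exports shown in the hunk above:

# Both paths resolve to the same functions after this commit;
# flax.linen.linear no longer defines them.
from flax.linen import canonicalize_inexact_dtypes
from flax.linen.dtypes import canonicalize_numeric_dtypes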
39 changes: 8 additions & 31 deletions flax/linen/activation.py
@@ -15,41 +15,18 @@
 """Activation functions.
 """
 
-import jax.numpy as jnp
 # pylint: disable=unused-import
-# re-export activation functions from jax.nn
-from jax.nn import celu
-from jax.nn import elu
-from jax.nn import gelu
-from jax.nn import glu
-from jax.nn import leaky_relu
-from jax.nn import log_sigmoid
-from jax.nn import log_softmax
-from jax.nn import normalize
-from jax.nn import relu
-from jax.nn import sigmoid
-from jax.nn import soft_sign
-from jax.nn import softmax
-from jax.nn import softplus
-from jax.nn import swish
-from jax.nn import silu
-from jax.nn import selu
-from jax.nn import hard_tanh
-from jax.nn import relu6
-from jax.nn import hard_sigmoid
-from jax.nn import hard_swish
-
+# re-export activation functions from jax.nn and jax.numpy
+from jax.nn import (celu, elu, gelu, glu, hard_sigmoid, hard_swish, hard_tanh,
+                    leaky_relu, log_sigmoid, log_softmax, normalize, relu,
+                    relu6, selu, sigmoid, silu, soft_sign, softmax, softplus,
+                    swish)
 from jax.numpy import tanh
 # pylint: enable=unused-import
 
-from typing import Any
-
-from flax.linen.linear import canonicalize_inexact_dtypes
-from flax.linen.module import Module, compact
+import jax.numpy as jnp
 
-
-FloatingDType = Type[jnp.floating]
-Array = Any
+from .dtypes import Array, FloatingDType, canonicalize_inexact_dtypes
+from .module import Module, compact
 
 
 class PReLU(Module):
22 changes: 7 additions & 15 deletions flax/linen/attention.py
@@ -15,26 +15,18 @@
 """Attention core modules for Flax."""
 
 from functools import partial
-from typing import Any, Callable, Tuple, Type, Optional
+from typing import Callable, Optional
 
 import jax
-from jax import lax
-from jax import random
 import jax.numpy as jnp
 import numpy as np
+from jax import lax, random
 from typing_extensions import Protocol
 
-from flax.linen.initializers import zeros
-from flax.linen.linear import DenseGeneral
-from flax.linen.linear import canonicalize_inexact_dtypes
-from flax.linen.linear import default_kernel_init
-from flax.linen.module import Module, compact, merge_param
-
-PRNGKey = Any
-Shape = Tuple[int, ...]
-InexactDType = Type[jnp.inexact]
-Array = Any
-Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
+from .dtypes import (Array, InexactDType, Initializer, PRNGKey,
+                     canonicalize_inexact_dtypes)
+from .initializers import zeros
+from .linear import DenseGeneral, default_kernel_init
+from .module import Module, compact, merge_param
 
 
 class AttentionFunction(Protocol):
57 changes: 57 additions & 0 deletions flax/linen/dtypes.py
@@ -0,0 +1,57 @@
+# Copyright 2022 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tools for working with dtypes."""
+
+from typing import Any, Callable, Optional, Tuple, Type
+
+import jax.numpy as jnp
+import numpy as np
+
+
+Array = Any  # pylint: disable=invalid-name
+PRNGKey = Any  # pylint: disable=invalid-name
+Shape = Tuple[int, ...]
+FloatingDType = Type[jnp.floating]
+GenericDType = Type[np.generic]
+InexactDType = Type[jnp.inexact]
+NumericDType = Type[jnp.number]
+Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
+
+
+def canonicalize_inexact_dtypes(
+    input_dtype: InexactDType,
+    param_dtype: Optional[InexactDType],
+    computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType,
+                                                        InexactDType]:
+  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
+  dtype = (jnp.result_type(input_dtype, returned_param_dtype)
+           if computation_dtype is None else computation_dtype)
+
+  assert jnp.issubdtype(input_dtype, jnp.inexact)
+  return returned_param_dtype, dtype
+
+
+def canonicalize_numeric_dtypes(
+    input_dtype: NumericDType,
+    param_dtype: Optional[NumericDType],
+    computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType,
+                                                        NumericDType]:
+  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
+  dtype = (jnp.result_type(input_dtype, returned_param_dtype)
+           if computation_dtype is None else computation_dtype)
+
+  assert jnp.issubdtype(input_dtype, jnp.number)
+  return returned_param_dtype, dtype
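
Both helpers apply the same defaulting rule: the parameter dtype falls back to the input dtype, and the computation dtype falls back to jnp.result_type of the input and parameter dtypes. A minimal usage sketch of the new module (the float16/float32 values are illustrative, not taken from the diff):

import jax.numpy as jnp

from flax.linen.dtypes import canonicalize_inexact_dtypes

# No overrides: parameter and computation dtypes both default to the input.
param_dtype, dtype = canonicalize_inexact_dtypes(jnp.float32, None, None)
# param_dtype == float32, dtype == float32

# Half-precision parameters with no explicit computation dtype: the
# computation dtype promotes, since jnp.result_type(float32, float16)
# is float32.
param_dtype, dtype = canonicalize_inexact_dtypes(jnp.float32, jnp.float16, None)
# param_dtype == float16, dtype == float32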

49 changes: 5 additions & 44 deletions flax/linen/linear.py
@@ -16,28 +16,15 @@
 
 import abc
 from dataclasses import field
-from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple,
-                    Type, Union)
+from typing import Iterable, List, Optional, Sequence, Tuple, Union
 
-from flax.linen.module import Module, compact
-from flax.linen.initializers import lecun_normal, variance_scaling, zeros
-
-from jax import lax
-from jax import eval_shape
-from jax import ShapedArray
 import jax.numpy as jnp
 import numpy as np
+from jax import ShapedArray, eval_shape, lax
 
-
-PRNGKey = Any
-Shape = Tuple[int, ...]
-InexactDType = Type[jnp.inexact]
-NumericDType = Type[jnp.number]
-GenericDType = Type[np.generic]
-Array = Any
-Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
-
+from .dtypes import Array, GenericDType, InexactDType, Initializer, NumericDType
+from .initializers import lecun_normal, variance_scaling, zeros
+from .module import Module, compact
 
 default_kernel_init = lecun_normal()
 
@@ -54,32 +41,6 @@ def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]:
     return (x,)
 
 
-def canonicalize_inexact_dtypes(
-    input_dtype: InexactDType,
-    param_dtype: Optional[InexactDType],
-    computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType,
-                                                        InexactDType]:
-  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
-  dtype = (jnp.result_type(input_dtype, returned_param_dtype)
-           if computation_dtype is None else computation_dtype)
-
-  assert jnp.issubdtype(input_dtype, jnp.inexact)
-  return returned_param_dtype, dtype
-
-
-def canonicalize_numeric_dtypes(
-    input_dtype: NumericDType,
-    param_dtype: Optional[NumericDType],
-    computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType,
-                                                        NumericDType]:
-  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
-  dtype = (jnp.result_type(input_dtype, returned_param_dtype)
-           if computation_dtype is None else computation_dtype)
-
-  assert jnp.issubdtype(input_dtype, jnp.number)
-  return returned_param_dtype, dtype
-
 
 class DenseGeneral(Module):
   """A linear transformation with flexible axes.
27 changes: 12 additions & 15 deletions flax/linen/module.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 """Flax Modules."""
-from contextlib import contextmanager
 import dataclasses
 import enum
 import functools
@@ -23,33 +22,30 @@
 import types
 import typing
 import weakref
-
-from typing import (Any, Callable, Sequence, Iterable, List, Optional, Tuple,
-                    Set, Type, Union, TypeVar, Generic, Dict, overload)
+from contextlib import contextmanager
+from typing import (Any, Callable, Dict, Generic, Iterable, List, Optional,
+                    Sequence, Set, Tuple, Type, TypeVar, Union, overload)
 
 import jax
+import numpy as np
 from jax import tree_util
 from jax._src.numpy.lax_numpy import isin
-import numpy as np
 
 import flax
-from flax import config
-from flax import errors
-from flax import traceback_util
-from flax import traverse_util
-from flax import serialization
-from flax import core
+from flax import (config, core, errors, serialization, traceback_util,
+                  traverse_util)
 from flax.core import Scope
-from flax.core.scope import CollectionFilter, DenyList, Variable, VariableDict, FrozenVariableDict, union_filters
 from flax.core.frozen_dict import FrozenDict, freeze
+from flax.core.scope import (CollectionFilter, DenyList, FrozenVariableDict,
+                             Variable, VariableDict, union_filters)
 from flax.struct import __dataclass_transform__
 
+from .dtypes import Array, PRNGKey
+
 # from .dotgetter import DotGetter
 traceback_util.register_exclusion(__file__)
 
-PRNGKey = Any  # pylint: disable=invalid-name
 RNGSequences = Dict[str, PRNGKey]
-Array = Any  # pylint: disable=invalid-name
 
 
 T = TypeVar('T')
@@ -603,7 +599,8 @@ def _wrap_module_methods(cls):
       wrapped_method = wrap_method_once(method)
      if key != 'setup':
        # We import named_call at runtime to avoid a circular import issue.
-        from flax.linen.transforms import named_call  # pylint: disable=g-import-not-at-top
+        from flax.linen.transforms import \
+            named_call  # pylint: disable=g-import-not-at-top
        wrapped_method = named_call(wrapped_method, force=False)
      setattr(cls, key, wrapped_method)
    return cls
17 changes: 5 additions & 12 deletions flax/linen/normalization.py
@@ -14,22 +14,15 @@
 
 """Normalization modules for Flax."""
 
-from typing import Any, Callable, Optional, Tuple, Type, Iterable, Union
+from typing import Any, Iterable, Optional, Tuple, Union
 
+import jax.numpy as jnp
 from jax import lax
 from jax.nn import initializers
-import jax.numpy as jnp
 import numpy as np
 
-from flax.linen.module import Module, compact, merge_param
-from flax.linen.linear import canonicalize_inexact_dtypes
-
-
-PRNGKey = Any
-Array = Any
-Shape = Tuple[int, ...]
-InexactDType = Type[jnp.inexact]
-Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
+from .dtypes import (Array, InexactDType, Initializer,
+                     canonicalize_inexact_dtypes)
+from .module import Module, compact, merge_param
 
 Axes = Union[int, Tuple[int, ...]]
 
20 changes: 8 additions & 12 deletions flax/linen/recurrent.py
@@ -25,24 +25,20 @@
 
 import abc
 from functools import partial
-from typing import (Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple,
-                    Type, Union)
+from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union
 
-from flax.linen.module import Module, compact
-from flax.linen.activation import sigmoid, tanh
-from flax.linen.initializers import orthogonal, zeros
-from flax.linen.linear import Conv, Dense, default_kernel_init
-
-from jax import numpy as jnp
+import numpy as np
 from jax import lax
+from jax import numpy as jnp
 from jax import random
-import numpy as np
 
+from .activation import sigmoid, tanh
+from .dtypes import Array, PRNGKey, Shape
+from .initializers import orthogonal, zeros
+from .linear import Conv, Dense, default_kernel_init
+from .module import Module, compact
 
-PRNGKey = Any
-Shape = Tuple[int, ...]
 Dtype = Any
-Array = Any
 
 
 class RNNCellBase(Module):
