Fix activation lookup with Python 3.12.3 (explosion#375)
We used the metaclass `EnumMeta`/`EnumType` to override reporting of
missing enum values (to give the full set of supported activations).
However, in Python 3.12.3, the default value of the `names` parameter of
the `EnumType.__call__` method was changed from `None` to `_not_given`:

python/cpython@d771729

Even though this is a public API (which now uses a private default
value), it seems too risky to continue relying on it. So in this change,
we implement `Enum._missing_` instead for the improved error reporting.
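For illustration, here is a minimal standalone sketch of the `_missing_` hook (the `Color` enum and its members are hypothetical, not part of this repository). `_missing_` is invoked by the enum machinery whenever a value lookup fails, so raising a `ValueError` from it replaces the generic error message without overriding the metaclass:

    from enum import Enum


    class Color(Enum):
        # Hypothetical members, for illustration only.
        RED = "red"
        BLUE = "blue"

        @classmethod
        def _missing_(cls, value):
            # Called when `Color(value)` matches no member; a ValueError
            # raised here propagates in place of the default
            # "'...' is not a valid Color" error.
            supported = ", ".join(sorted(member.value for member in cls))
            raise ValueError(f"Invalid color `{value}`. Supported colors: {supported}")


    Color("red")    # Color.RED
    Color("green")  # ValueError: Invalid color `green`. Supported colors: blue, red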
danieldk committed Apr 17, 2024
1 parent b192987 commit 4b737ce
Showing 1 changed file with 10 additions and 41 deletions.
51 changes: 10 additions & 41 deletions curated_transformers/layers/activations.py
@@ -1,52 +1,13 @@
 import math
-from enum import Enum, EnumMeta
+from enum import Enum
 from typing import Type
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 
 
-class _ActivationMeta(EnumMeta):
-    """
-    ``Enum`` metaclass to override the class ``__call__`` method with a more
-    fine-grained exception for unknown activation functions.
-    """
-
-    def __call__(
-        cls,
-        value,
-        names=None,
-        *,
-        module=None,
-        qualname=None,
-        type=None,
-        start=1,
-    ):
-        # Wrap superclass __call__ to give a nicer error message when
-        # an unknown activation is used.
-        if names is None:
-            try:
-                return EnumMeta.__call__(
-                    cls,
-                    value,
-                    names,
-                    module=module,
-                    qualname=qualname,
-                    type=type,
-                    start=start,
-                )
-            except ValueError:
-                supported_activations = ", ".join(sorted(v.value for v in cls))
-                raise ValueError(
-                    f"Invalid activation function `{value}`. "
-                    f"Supported functions: {supported_activations}"
-                )
-        else:
-            return EnumMeta.__call__(cls, value, names, module, qualname, type, start)
-
-
-class Activation(Enum, metaclass=_ActivationMeta):
+class Activation(Enum):
     """
     Activation functions.
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
     #: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
     SiLU = "silu"
 
+    @classmethod
+    def _missing_(cls, value):
+        supported_activations = ", ".join(sorted(v.value for v in cls))
+        raise ValueError(
+            f"Invalid activation function `{value}`. "
+            f"Supported functions: {supported_activations}"
+        )
+
     @property
     def module(self) -> Type[torch.nn.Module]:
         """
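With this change, lookups of valid values behave as before, while unknown values raise the enriched error. A hypothetical session (the full set of supported names includes members elided from the diff above):

    from curated_transformers.layers.activations import Activation

    Activation("silu")     # Activation.SiLU, unchanged behaviour
    Activation("unknown")  # ValueError: Invalid activation function `unknown`.
                           # Supported functions: ..., silu, ...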
