Skip to content

Commit

Permalink
Add support for autocast keyword argument in Layer.add_weight.
Browse files Browse the repository at this point in the history
This feature was already supported with the `experimental_autocast` argument. This change simply adds an alias for the same argument to have the same API in Keras 2 and Keras 3.

PiperOrigin-RevId: 662965658
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Aug 14, 2024
1 parent f93f30e commit 4f72d68
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
14 changes: 10 additions & 4 deletions tf_keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,8 @@ def add_weight(
Accepted values are constants defined in the class
`tf.VariableAggregation`.
**kwargs: Additional keyword arguments. Accepted values are `getter`,
`collections`, `experimental_autocast` and `caching_device`.
`collections`, `autocast`, `experimental_autocast` and
`caching_device`.
Returns:
The variable created.
Expand All @@ -594,6 +595,7 @@ def add_weight(
# Validate optional keyword arguments.
for kwarg in kwargs:
if kwarg not in [
"autocast",
"collections",
"experimental_autocast",
"caching_device",
Expand All @@ -603,9 +605,13 @@ def add_weight(
]:
raise TypeError("Unknown keyword argument:", kwarg)
collections_arg = kwargs.pop("collections", None)
# 'experimental_autocast' can be set to False by the caller to indicate
# an AutoCastVariable should never be created.
autocast = kwargs.pop("experimental_autocast", True)
# 'autocast' or 'experimental_autocast' can be set to False by the
# caller to indicate an AutoCastVariable should never be created.
autocast = kwargs.pop("autocast", None)
if autocast is None:
autocast = kwargs.pop("experimental_autocast", None)
if autocast is None:
autocast = True
# See the docstring for tf.Variable about the details for
# caching_device.
caching_device = kwargs.pop("caching_device", None)
Expand Down
14 changes: 10 additions & 4 deletions tf_keras/engine/base_layer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ def add_weight(
Accepted values are constants defined in the class
`tf.VariableAggregation`.
**kwargs: Additional keyword arguments. Accepted values are `getter`,
`collections`, `experimental_autocast` and `caching_device`.
`collections`, `autocast`, `experimental_autocast` and
`caching_device`.
Returns:
The created variable. Usually either a `Variable` or
Expand All @@ -371,6 +372,7 @@ def add_weight(
# Validate optional keyword arguments.
for kwarg in kwargs:
if kwarg not in [
"autocast",
"getter",
"collections",
"experimental_autocast",
Expand All @@ -380,9 +382,13 @@ def add_weight(
has_custom_getter = "getter" in kwargs
getter = kwargs.pop("getter", base_layer_utils.make_variable)
collections_arg = kwargs.pop("collections", None)
# 'experimental_autocast' can be set to False by the caller to indicate
# an AutoCastVariable should never be created.
autocast = kwargs.pop("experimental_autocast", True)
# 'autocast' or 'experimental_autocast' can be set to False by the
# caller to indicate an AutoCastVariable should never be created.
autocast = kwargs.pop("autocast", None)
if autocast is None:
autocast = kwargs.pop("experimental_autocast", None)
if autocast is None:
autocast = True
# See the docstring for tf.Variable about the details for
# caching_device.
caching_device = kwargs.pop("caching_device", None)
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/mixed_precision/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def build(self, _):
(),
initializer="ones",
dtype=dtype,
experimental_autocast=False,
autocast=False,
regularizer=self._regularizer,
)
self.built = True
Expand Down

0 comments on commit 4f72d68

Please sign in to comment.