diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py
index 0bb0a421c4..22aef1201a 100644
--- a/keras_nlp/layers/modeling/masked_lm_head.py
+++ b/keras_nlp/layers/modeling/masked_lm_head.py
@@ -153,9 +153,11 @@ def build(self, inputs_shape, masked_positions_shape=None):
             activation=self.intermediate_activation,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
+            dtype=self._dtype_policy,
         )
         self._layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self._dtype_policy,
         )
         if masked_positions_shape:
             gather_length = masked_positions_shape[1]
diff --git a/keras_nlp/layers/modeling/masked_lm_head_test.py b/keras_nlp/layers/modeling/masked_lm_head_test.py
index f5c3b9d07c..c6a701d906 100644
--- a/keras_nlp/layers/modeling/masked_lm_head_test.py
+++ b/keras_nlp/layers/modeling/masked_lm_head_test.py
@@ -15,6 +15,9 @@
 import os
 
+import tensorflow as tf
+from absl.testing import parameterized
+
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.layers.modeling import masked_lm_head
 
@@ -36,6 +39,30 @@ def test_valid_call(self):
         position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
         model((token_data, position_data))
 
+    @parameterized.named_parameters(
+        ("bfloat16", tf.bfloat16),
+        ("float16", tf.float16),
+        ("float32", tf.float32),
+        ("float64", tf.float64),
+    )
+    def test_valid_call_with_dtype(self, dtype):
+        head = masked_lm_head.MaskedLMHead(
+            vocabulary_size=100,
+            activation="softmax",
+            dtype=dtype,
+        )
+        encoded_tokens = keras.Input(shape=(10, 16))
+        positions = keras.Input(shape=(5,), dtype="int32")
+        outputs = head(encoded_tokens, masked_positions=positions)
+        model = keras.Model((encoded_tokens, positions), outputs)
+
+        token_data = ops.random.uniform(shape=(4, 10, 16))
+        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
+        model((token_data, position_data))
+
+        for w in head.weights:
+            self.assertEqual(w.dtype, dtype, "Wrong type: " + w.name)
+
     def test_valid_call_with_embedding_weights(self):
         embedding = keras.layers.Embedding(100, 16)
         embedding.build((4, 10))
@@ -119,6 +146,32 @@ def test_one_train_step(self):
         loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
         self.assertGreater(loss, 0)
 
+    @parameterized.named_parameters(
+        ("bfloat16", tf.bfloat16),
+        ("float16", tf.float16),
+        ("float32", tf.float32),
+        ("float64", tf.float64),
+    )
+    def test_one_train_step_with_dtype(self, dtype):
+        head = masked_lm_head.MaskedLMHead(
+            vocabulary_size=100,
+            dtype=dtype,
+        )
+        encoded_tokens = keras.Input(shape=(10, 16))
+        positions = keras.Input(shape=(5,), dtype="int32")
+        outputs = head(encoded_tokens, masked_positions=positions)
+        model = keras.Model((encoded_tokens, positions), outputs)
+
+        token_data = ops.random.uniform(shape=(4, 10, 16))
+        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
+        label_data = ops.random.randint(minval=0, maxval=2, shape=(4, 5, 1))
+
+        loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
+        optimizer = keras.optimizers.Adam()
+        model.compile(loss=loss, optimizer=optimizer)
+        loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
+        self.assertGreater(loss, 0)
+
     def test_saved_model(self):
         head = masked_lm_head.MaskedLMHead(
             vocabulary_size=100,