Skip to content

Commit

Permalink
Fix text vectorization serialization for custom callables.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604658738
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Feb 6, 2024
1 parent 4bec551 commit a724b9a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tf_keras/layers/preprocessing/text_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tf_keras.layers.preprocessing import string_lookup
from tf_keras.saving.legacy.saved_model import layer_serialization
from tf_keras.saving.serialization_lib import deserialize_keras_object
from tf_keras.saving.serialization_lib import serialize_keras_object
from tf_keras.utils import layer_utils
from tf_keras.utils import tf_utils

Expand Down Expand Up @@ -500,8 +501,8 @@ def vocabulary_size(self):
def get_config(self):
config = {
"max_tokens": self._lookup_layer.max_tokens,
"standardize": self._standardize,
"split": self._split,
"standardize": serialize_keras_object(self._standardize),
"split": serialize_keras_object(self._split),
"ngrams": self._ngrams_arg,
"output_mode": self._output_mode,
"output_sequence_length": self._output_sequence_length,
Expand Down
21 changes: 21 additions & 0 deletions tf_keras/layers/preprocessing/text_vectorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2408,6 +2408,27 @@ def test_serialization_with_custom_callables(self):
new_output_dataset = new_model.predict(input_array)
self.assertAllEqual(expected_output, new_output_dataset)

def test_cloning_with_custom_callable(self):
    """Cloning a model whose TextVectorization layer uses a registered
    custom ``split`` callable must round-trip through the layer config
    without raising."""

    @register_keras_serializable(package="Test")
    def pipe_split_fn(raw):
        # Tokens are separated by "|" rather than whitespace.
        return tf.strings.split(raw, sep="|")

    samples = [
        "this|is|some|pipe-delimited|text",
        "some|more|pipe-delimited|text",
        "yet|more|pipe-delimited|text",
    ]
    corpus = tf.data.Dataset.from_tensor_slices(samples)

    layer = text_vectorization.TextVectorization(
        max_tokens=10, standardize=None, split=pipe_split_fn
    )
    layer.adapt(corpus)

    inputs = keras.Input(shape=(), dtype=tf.string)
    model = keras.Model(inputs=inputs, outputs=layer(inputs))
    # clone_model serializes and deserializes the layer config; the
    # registered custom callable should survive that round trip.
    _ = keras.models.clone_model(model)

@test_utils.run_v2_only()
def test_saving_v3(self):
vocab_data = ["earth", "wind", "and", "fire"]
Expand Down

0 comments on commit a724b9a

Please sign in to comment.