From a089a8b7559be3e5d5a49b04df54997a9803cc4d Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 7 Aug 2024 17:57:31 -0700 Subject: [PATCH 1/7] Add VGG16 backbone (#1737) * Agg Vgg16 backbone * update names * update tests * update test * add image classifier * incorporate review comments * Update test case * update backbone test * add image classifier * classifier cleanup * code reformat * add vgg16 image classifier * make vgg generic * update doc string * update docstring * add classifier test * update tests * update docstring * address review comments * code reformat * update the configs * address review comments * fix task saved model test * update init * code reformatted --- keras_nlp/api/models/__init__.py | 3 + keras_nlp/src/models/image_classifier.py | 90 ++++++++++ keras_nlp/src/models/vgg/__init__.py | 13 ++ keras_nlp/src/models/vgg/vgg_backbone.py | 159 ++++++++++++++++++ keras_nlp/src/models/vgg/vgg_backbone_test.py | 48 ++++++ .../src/models/vgg/vgg_image_classifier.py | 124 ++++++++++++++ .../models/vgg/vgg_image_classifier_test.py | 61 +++++++ keras_nlp/src/tests/test_case.py | 30 ++-- 8 files changed, 514 insertions(+), 14 deletions(-) create mode 100644 keras_nlp/src/models/image_classifier.py create mode 100644 keras_nlp/src/models/vgg/__init__.py create mode 100644 keras_nlp/src/models/vgg/vgg_backbone.py create mode 100644 keras_nlp/src/models/vgg/vgg_backbone_test.py create mode 100644 keras_nlp/src/models/vgg/vgg_image_classifier.py create mode 100644 keras_nlp/src/models/vgg/vgg_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 4fb3b3cf00..41f1a47284 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -129,6 +129,7 @@ GPTNeoXPreprocessor, ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -194,6 +195,8 @@ from keras_nlp.src.models.t5.t5_backbone import T5Backbone from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer from keras_nlp.src.models.task import Task +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/src/models/image_classifier.py b/keras_nlp/src/models/image_classifier.py new file mode 100644 index 0000000000..f0cc031dbc --- /dev/null +++ b/keras_nlp/src/models/image_classifier.py @@ -0,0 +1,90 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
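The `keras_nlp/api/models/__init__.py` hunks above expose the new classes as `keras_nlp.models.ImageClassifier`, `keras_nlp.models.VGGBackbone`, and `keras_nlp.models.VGGImageClassifier`. A quick, hypothetical sanity check of those import paths (assuming a build of this branch is installed):

```python
import keras_nlp

# Symbols newly exported by this patch; existing model exports are unchanged.
print(keras_nlp.models.ImageClassifier)
print(keras_nlp.models.VGGBackbone)
print(keras_nlp.models.VGGImageClassifier)
```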
+import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.task import Task + + +@keras_nlp_export("keras_nlp.models.ImageClassifier") +class ImageClassifier(Task): + """Base class for all image classification tasks. + + `ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + image classification. `ImageClassifier` tasks take an additional + `num_classes` argument, controlling the number of predicted output classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageClassifier` task for training. + + The `ImageClassifier` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the classification task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.SparseCategoricalCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.SparseCategoricalAccuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_nlp/src/models/vgg/__init__.py b/keras_nlp/src/models/vgg/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/vgg/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/vgg/vgg_backbone.py b/keras_nlp/src/models/vgg/vgg_backbone.py new file mode 100644 index 0000000000..497381c0fc --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone.py @@ -0,0 +1,159 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.VGGBackbone") +class VGGBackbone(Backbone): + """ + This class represents Keras Backbone of VGG model. + + This class implements a VGG backbone as described in [Very Deep + Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556)(ICLR 2015). + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for + VGG19 this is [2, 2, 4, 4, 4]. + stackwise_num_filters: list of ints, filter size for convolutional + blocks per VGG block. For both VGG16 and VGG19 this is [ + 64, 128, 256, 512, 512]. + include_rescaling: bool, whether to rescale the inputs. If set to + True, inputs will be passed through a `Rescaling(1/255.0)` layer. + input_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + pooling: bool, Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained VGG backbone. + model = keras_nlp.models.VGGBackbone.from_preset("vgg16") + model(input_data) + + # Randomly initialized VGG backbone with a custom config. 
+ model = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + stackwise_num_filters, + include_rescaling, + input_image_shape=(224, 224, 3), + pooling="avg", + **kwargs, + ): + + # === Functional Model === + img_input = keras.layers.Input(shape=input_image_shape) + x = img_input + + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + for stack_index in range(len(stackwise_num_repeats) - 1): + x = apply_vgg_block( + x=x, + num_layers=stackwise_num_repeats[stack_index], + filters=stackwise_num_filters[stack_index], + kernel_size=(3, 3), + activation="relu", + padding="same", + max_pool=True, + name=f"block{stack_index + 1}", + ) + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_repeats = stackwise_num_repeats + self.stackwise_num_filters = stackwise_num_filters + self.include_rescaling = include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_repeats": self.stackwise_num_repeats, + "stackwise_num_filters": self.stackwise_num_filters, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_vgg_block( + x, + num_layers, + filters, + kernel_size, + activation, + padding, + max_pool, + name, +): + """ + Applies VGG block + Args: + x: Tensor, input tensor to pass through network + num_layers: int, number of CNN layers in the block + filters: int, filter size of each CNN layer in block + kernel_size: int (or) tuple, kernel size for CNN layer in block + activation: str (or) callable, activation function for each CNN layer in + block + padding: str (or) callable, padding function for each CNN layer in block + max_pool: bool, whether to add MaxPooling2D layer at end of block + name: str, name of the block + + Returns: + keras.KerasTensor + """ + for num in range(1, num_layers + 1): + x = layers.Conv2D( + filters, + kernel_size, + activation=activation, + padding=padding, + name=f"{name}_conv{num}", + )(x) + if max_pool: + x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x) + return x diff --git a/keras_nlp/src/models/vgg/vgg_backbone_test.py b/keras_nlp/src/models/vgg/vgg_backbone_test.py new file mode 100644 index 0000000000..05ed33ba0f --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
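A note on what `apply_vgg_block` builds as it is called by the backbone above: each block is `num_layers` 3x3 same-padded convolutions followed by an optional 2x2 max pooling, so every pooled block halves the spatial resolution. A minimal stand-alone sketch of one such block in plain Keras (illustrative only, not part of this patch):

```python
import keras
from keras import layers

# One VGG-style block: two 3x3 same-padded convolutions, then 2x2 max pooling,
# mirroring `apply_vgg_block(x, num_layers=2, filters=8, ..., max_pool=True)`.
inputs = keras.Input(shape=(16, 16, 3))
x = layers.Conv2D(8, (3, 3), activation="relu", padding="same", name="block1_conv1")(inputs)
x = layers.Conv2D(8, (3, 3), activation="relu", padding="same", name="block1_conv2")(x)
x = layers.MaxPooling2D((2, 2), (2, 2), name="block1_pool")(x)
model = keras.Model(inputs, x)
print(model.output_shape)  # (None, 8, 8, 8): pooling halves height and width
```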
+ +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class VGGBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [2, 3, 3], + "stackwise_num_filters": [8, 64, 64], + "input_image_shape": (16, 16, 3), + "include_rescaling": False, + "pooling": "avg", + } + self.input_data = np.ones((2, 16, 16, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 64), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier.py b/keras_nlp/src/models/vgg/vgg_image_classifier.py new file mode 100644 index 0000000000..a26fbfbc30 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier.py @@ -0,0 +1,124 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone + + +@keras_nlp_export("keras_nlp.models.VGGImageClassifier") +class VGGImageClassifier(ImageClassifier): + """VGG16 image classifier task model. + + Args: + backbone: A `keras_nlp.models.VGGBackbone` instance. + num_classes: int, number of classes to predict. + pooling: str, type of pooling layer. Must be one of "avg", "max". + activation: Optional `str` or callable, defaults to "softmax". The + activation function to use on the Dense layer. Set `activation=None` + to return the output logits. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Examples: + Train from preset + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.VGGImageClassifier.from_preset( + 'vgg_16_image_classifier') + classifier.fit(x=images, y=labels, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + + # Access backbone programmatically (e.g., to change `trainable`). + classifier.backbone.trainable = False + # Fit again. 
+ classifier.fit(x=images, y=labels, batch_size=2) + ``` + Custom backbone + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + + backbone = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + classifier = keras_nlp.models.VGGImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = VGGBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py new file mode 100644 index 0000000000..4a2573e496 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_nlp.src.tests.test_case import TestCase + + +class VGGImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 4, 4, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = VGGBackbone( + stackwise_num_repeats=[2, 4, 4], + stackwise_num_filters=[2, 16, 16], + input_image_shape=(4, 4, 3), + include_rescaling=False, + pooling="max", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 7e8e0cec95..fc1ce77e1e 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -419,20 +419,22 @@ def run_backbone_test( self.assertEqual(output[key].shape, expected_output_shape[key]) else: self.assertEqual(output.shape, expected_output_shape) - - # Check we can embed tokens eagerly. - output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) - - # Check variable length sequences. - if variable_length_data is None: - # If no variable length data passed, assume the second axis of all - # inputs is our sequence axis and create it ourselves. - variable_length_data = [ - tree.map_structure(lambda x: x[:, :seq_length, ...], input_data) - for seq_length in (2, 3, 4) - ] - for batch in variable_length_data: - backbone(batch) + if backbone.token_embedding is not None: + # Check we can embed tokens eagerly. + output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) + + # Check variable length sequences. + if variable_length_data is None: + # If no variable length data passed, assume the second axis of all + # inputs is our sequence axis and create it ourselves. + variable_length_data = [ + tree.map_structure( + lambda x: x[:, :seq_length, ...], input_data + ) + for seq_length in (2, 3, 4) + ] + for batch in variable_length_data: + backbone(batch) # Check compiled predict function. 
backbone.predict(input_data) From 73b7bad007a8c37a54512092c2b8bfe435d21c10 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:09:08 +0800 Subject: [PATCH 2/7] Add `ResNetBackbone` and `ResNetImageClassifier` (#1765) * Add ResNetV1 and ResNetV2 * Address comments --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/resnet/__init__.py | 13 + .../src/models/resnet/resnet_backbone.py | 544 ++++++++++++++++++ .../src/models/resnet/resnet_backbone_test.py | 75 +++ .../models/resnet/resnet_image_classifier.py | 131 +++++ .../resnet/resnet_image_classifier_test.py | 62 ++ keras_nlp/src/tests/test_case.py | 60 ++ keras_nlp/src/utils/keras_utils.py | 13 + 8 files changed, 902 insertions(+) create mode 100644 keras_nlp/src/models/resnet/__init__.py create mode 100644 keras_nlp/src/models/resnet/resnet_backbone.py create mode 100644 keras_nlp/src/models/resnet/resnet_backbone_test.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 41f1a47284..783cfd5087 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -181,6 +181,10 @@ from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM diff --git a/keras_nlp/src/models/resnet/__init__.py b/keras_nlp/src/models/resnet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/resnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py new file mode 100644 index 0000000000..bec5ba60b5 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -0,0 +1,544 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.utils.keras_utils import standardize_data_format + + +@keras_nlp_export("keras_nlp.models.ResNetBackbone") +class ResNetBackbone(Backbone): + """ResNet and ResNetV2 core network with hyperparameters. + + This class implements a ResNet backbone as described in [Deep Residual + Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( + CVPR 2016) and [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016). + + The difference in ResNet and ResNetV2 rests in the structure of their + individual building blocks. In ResNetV2, the batch normalization and + ReLU activation precede the convolution layers, as opposed to ResNet where + the batch normalization and ReLU activation are applied after the + convolution layers. + + Args: + stackwise_num_filters: list of ints. The number of filters for each + stack. + stackwise_num_blocks: list of ints. The number of blocks for each stack. + stackwise_num_strides: list of ints. The number of strides for each + stack. + block_type: str. The block type to stack. One of `"basic_block"` or + `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. + Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. + include_rescaling: boolean. If `True`, rescale the input using + `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to + `True`. + input_image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. + pooling: `None` or str. Pooling mode for feature extraction. Defaults + to `"avg"`. + - `None` means that the output of the model will be the 4D tensor + from the last convolutional block. + - `avg` means that global average pooling will be applied to the + output of the last convolutional block, resulting in a 2D + tensor. + - `max` means that global max pooling will be applied to the + output of the last convolutional block, resulting in a 2D + tensor. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained ResNet backbone. + model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") + model(input_data) + + # Randomly initialized ResNetV2 backbone with a custom config. 
+ model = keras_nlp.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + pooling="avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_num_blocks, + stackwise_num_strides, + block_type, + use_pre_activation=False, + include_rescaling=True, + input_image_shape=(None, None, 3), + pooling="avg", + data_format=None, + dtype=None, + **kwargs, + ): + if len(stackwise_num_filters) != len(stackwise_num_blocks) or len( + stackwise_num_filters + ) != len(stackwise_num_strides): + raise ValueError( + "The length of `stackwise_num_filters`, `stackwise_num_blocks` " + "and `stackwise_num_strides` must be the same. Received: " + f"stackwise_num_filters={stackwise_num_filters}, " + f"stackwise_num_blocks={stackwise_num_blocks}, " + f"stackwise_num_strides={stackwise_num_strides}" + ) + if stackwise_num_filters[0] != 64: + raise ValueError( + "The first element of `stackwise_num_filters` must be 64. " + f"Received: stackwise_num_filters={stackwise_num_filters}" + ) + if block_type not in ("basic_block", "bottleneck_block"): + raise ValueError( + '`block_type` must be either `"basic_block"` or ' + f'`"bottleneck_block"`. Received block_type={block_type}.' + ) + version = "v1" if not use_pre_activation else "v2" + data_format = standardize_data_format(data_format) + bn_axis = -1 if data_format == "channels_last" else 1 + num_stacks = len(stackwise_num_filters) + + # === Functional Model === + image_input = layers.Input(shape=input_image_shape) + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + else: + x = image_input + + x = layers.Conv2D( + 64, + 7, + strides=2, + padding="same", + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name="conv1_conv", + )(x) + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) + + x = layers.MaxPool2D( + 3, + strides=2, + padding="same", + data_format=data_format, + dtype=dtype, + name="pool1_pool", + )(x) + + for stack_index in range(num_stacks): + x = apply_stack( + x, + filters=stackwise_num_filters[stack_index], + blocks=stackwise_num_blocks[stack_index], + stride=stackwise_num_strides[stack_index], + block_type=block_type, + use_pre_activation=use_pre_activation, + first_shortcut=( + block_type == "bottleneck_block" or stack_index > 0 + ), + data_format=data_format, + dtype=dtype, + name=f"{version}_stack{stack_index}", + ) + + if use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) + + if pooling == "avg": + feature_map_output = layers.GlobalAveragePooling2D( + data_format=data_format, dtype=dtype + )(x) + elif pooling == "max": + feature_map_output = layers.GlobalMaxPooling2D( + data_format=data_format, dtype=dtype + )(x) + else: + feature_map_output = x + + super().__init__( + inputs=image_input, + outputs=feature_map_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_num_blocks = stackwise_num_blocks + self.stackwise_num_strides = stackwise_num_strides + self.block_type = block_type + self.use_pre_activation = use_pre_activation + self.include_rescaling = 
include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_basic_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a basic residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the basic residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1.001e-5, + dtype=dtype, + name=f"{name}_use_preactivation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + )(x_preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + filters, + 1, + strides=stride, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_0_conv", + )(x_preact if x_preact is not None else x) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + )(shortcut) + else: + if not use_pre_activation or stride == 1: + shortcut = x + else: + shortcut = layers.MaxPooling2D( + 1, + strides=stride, + data_format=data_format, + dtype=dtype, + name=f"{name}_0_max_pooling", + )(x) + + x = layers.Conv2D( + filters, + kernel_size, + strides=stride if not use_pre_activation else 1, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x_preact if x_preact is not None else x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=1 if not use_pre_activation else stride, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + )(x) + x = layers.Add(dtype=dtype, 
name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_bottleneck_block( + x, + filters, + kernel_size=3, + stride=1, + conv_shortcut=False, + use_pre_activation=False, + data_format=None, + dtype=None, + name=None, +): + """Applies a bottleneck residual block. + + Args: + x: Tensor. The input tensor to pass through the block. + filters: int. The number of filters in the block. + kernel_size: int. The kernel size of the bottleneck layer. Defaults to + `3`. + stride: int. The stride length of the first layer. Defaults to `1`. + conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `False`. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet. Defaults to `False`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + The output tensor for the residual block. + """ + data_format = data_format or keras.config.image_data_format() + bn_axis = -1 if data_format == "channels_last" else 1 + + x_preact = None + if use_pre_activation: + x_preact = layers.BatchNormalization( + axis=bn_axis, + epsilon=1.001e-5, + dtype=dtype, + name=f"{name}_use_preactivation_bn", + )(x) + x_preact = layers.Activation( + "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + )(x_preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + 4 * filters, + 1, + strides=stride, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_0_conv", + )(x_preact if x_preact is not None else x) + if not use_pre_activation: + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + )(shortcut) + else: + if not use_pre_activation or stride == 1: + shortcut = x + else: + shortcut = layers.MaxPooling2D( + 1, + strides=stride, + data_format=data_format, + dtype=dtype, + name=f"{name}_0_max_pooling", + )(x) + + x = layers.Conv2D( + filters, + 1, + strides=stride if not use_pre_activation else 1, + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_1_conv", + )(x_preact if x_preact is not None else x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=1 if not use_pre_activation else stride, + padding="same", + data_format=data_format, + use_bias=False, + dtype=dtype, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + )(x) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( + 4 * filters, + 1, + data_format=data_format, + use_bias=use_pre_activation, + dtype=dtype, + name=f"{name}_3_conv", + )(x) + + if not use_pre_activation: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + )(x) + x = layers.Add(dtype=dtype, 
name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) + else: + x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x]) + return x + + +def apply_stack( + x, + filters, + blocks, + stride, + block_type, + use_pre_activation, + first_shortcut=True, + data_format=None, + dtype=None, + name=None, +): + """Applies a set of stacked residual blocks. + + Args: + x: Tensor. The input tensor to pass through the stack. + filters: int. The number of filters in a block. + blocks: int. The number of blocks in the stack. + stride: int. The stride length of the first layer in the first block. + block_type: str. The block type to stack. One of `"basic_block"` or + `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34. + Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152. + use_pre_activation: boolean. Whether to use pre-activation or not. + `True` for ResNetV2, `False` for ResNet and ResNeXt. + first_shortcut: bool. If `True`, use a convolution shortcut. If `False`, + use an identity or pooling shortcut based on the stride. Defaults to + `True`. + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the stack. + + Returns: + Output tensor for the stacked blocks. + """ + if name is None: + version = "v1" if not use_pre_activation else "v2" + name = f"{version}_stack" + + if block_type == "basic_block": + block_fn = apply_basic_block + elif block_type == "bottleneck_block": + block_fn = apply_bottleneck_block + else: + raise ValueError( + '`block_type` must be either `"basic_block"` or ' + f'`"bottleneck_block"`. Received block_type={block_type}.' + ) + x = block_fn( + x, + filters, + stride=stride if not use_pre_activation else 1, + conv_shortcut=first_shortcut, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block1", + ) + for i in range(2, blocks): + x = block_fn( + x, + filters, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block{str(i)}", + ) + x = block_fn( + x, + filters, + stride=1 if not use_pre_activation else stride, + use_pre_activation=use_pre_activation, + data_format=data_format, + dtype=dtype, + name=f"{name}_block{str(blocks)}", + ) + return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py new file mode 100644 index 0000000000..2113bcd131 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
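To make the ordering difference described above concrete: with `use_pre_activation=False` (ResNetV1) each convolution is followed by batch normalization and ReLU, and the shortcut is added before a final ReLU; with `use_pre_activation=True` (ResNetV2) the BN and ReLU precede the convolutions and the add is the block's last op. A minimal stand-alone sketch of the V1 basic block in plain Keras (illustrative only; identity shortcut, stride 1):

```python
import keras
from keras import layers

def toy_basic_block_v1(x, filters, name="block"):
    """Post-activation (ResNetV1) basic block: Conv -> BN -> ReLU, twice."""
    shortcut = x
    y = layers.Conv2D(filters, 3, padding="same", use_bias=False, name=f"{name}_1_conv")(x)
    y = layers.BatchNormalization(name=f"{name}_1_bn")(y)
    y = layers.Activation("relu", name=f"{name}_1_relu")(y)
    y = layers.Conv2D(filters, 3, padding="same", use_bias=False, name=f"{name}_2_conv")(y)
    y = layers.BatchNormalization(name=f"{name}_2_bn")(y)
    # In V1 the residual is added before the block's final ReLU.
    y = layers.Add(name=f"{name}_add")([shortcut, y])
    return layers.Activation("relu", name=f"{name}_out")(y)

inputs = keras.Input(shape=(16, 16, 64))
outputs = toy_basic_block_v1(inputs, 64)
print(keras.Model(inputs, outputs).output_shape)  # (None, 16, 16, 64)
```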
+ +import pytest +from absl.testing import parameterized +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class ResNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "input_image_shape": (None, None, 3), + "pooling": "avg", + } + self.input_size = (16, 16) + self.input_data = ops.ones((2, 16, 16, 3)) + + @parameterized.named_parameters( + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ) + def test_backbone_basics(self, use_pre_activation, block_type): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": block_type, "use_pre_activation": use_pre_activation} + ) + self.run_vision_backbone_test( + cls=ResNetBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + expected_output_shape=( + (2, 64) if block_type == "basic_block" else (2, 256) + ), + ) + + @parameterized.named_parameters( + ("v1_basic", False, "basic_block"), + ("v1_bottleneck", False, "bottleneck_block"), + ("v2_basic", True, "basic_block"), + ("v2_bottleneck", True, "bottleneck_block"), + ) + @pytest.mark.large + def test_saved_model(self, use_pre_activation, block_type): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + { + "block_type": block_type, + "use_pre_activation": use_pre_activation, + "input_image_shape": (16, 16, 3), + } + ) + self.run_model_saving_test( + cls=ResNetBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py new file mode 100644 index 0000000000..02c8c78b27 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -0,0 +1,131 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_nlp_export("keras_nlp.models.ResNetImageClassifier") +class ResNetImageClassifier(ImageClassifier): + """ResNet image classifier task model. + + Args: + backbone: A `keras_nlp.models.ResNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. 
+ + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.ResNetImageClassifier.from_preset("resnet50") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + include_rescaling=False, + pooling="avg", + ) + classifier = keras_nlp.models.ResNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = ResNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=self.backbone.dtype_policy, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py new file mode 100644 index 0000000000..bbbda72d64 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
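One detail worth spelling out from the classifier above: the `activation` passed to the output `Dense` layer also drives the `from_logits` choice made by `loss="auto"` in the `ImageClassifier.compile` override from the first patch. A small stand-alone sketch of that selection logic (illustrative only):

```python
import keras

# Mirrors the "auto" loss selection in `ImageClassifier.compile`: a softmax
# head already outputs probabilities, so the loss is not computed from logits;
# with `activation=None` the head outputs logits and `from_logits=True` is used.
for activation in ("softmax", None):
    activation_fn = keras.activations.get(activation)
    from_logits = activation_fn != keras.activations.softmax
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)
    print(activation, "-> from_logits =", from_logits)
# softmax -> from_logits = False
# None -> from_logits = True
```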
+import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class ResNetImageClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 16, 16, 3)) + self.labels = [0, 3] + self.backbone = ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + input_image_shape=(16, 16, 3), + include_rescaling=False, + pooling="avg", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=ResNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ResNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index fc1ce77e1e..72653c8b83 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -457,6 +457,66 @@ def run_backbone_test( if run_quantization_check and has_quantization_support(): self.run_quantization_test(backbone, cls, init_kwargs, input_data) + def run_vision_backbone_test( + self, + cls, + init_kwargs, + input_data, + expected_output_shape, + variable_length_data=None, + run_mixed_precision_check=True, + run_quantization_check=True, + run_data_format_check=True, + ): + """Run basic tests for a vision backbone, including compilation.""" + can_run_data_format_check = True + if ( + keras.config.backend() == "tensorflow" + and not tf.config.list_physical_devices("GPU") + ): + # Never test the "channels_first" format on tensorflow CPU. + # Tensorflow lacks support for "channels_first" convolution. + can_run_data_format_check = False + + ori_data_format = keras.config.image_data_format() + keras.config.set_image_data_format("channels_last") + self.run_backbone_test( + cls=cls, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=expected_output_shape, + variable_length_data=variable_length_data, + run_mixed_precision_check=run_mixed_precision_check, + run_quantization_check=run_quantization_check, + ) + + # Check data_format. We assume that `input_data` is in "channels_last" + # format. + if run_data_format_check and can_run_data_format_check: + keras.config.set_image_data_format("channels_first") + input_data_shape = ops.shape(input_data) + if len(input_data_shape) == 3: + input_data = ops.transpose(input_data, axes=(2, 0, 1)) + elif len(input_data_shape) == 4: + input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) + if "input_image_shape" in init_kwargs: + init_kwargs = init_kwargs.copy() + init_kwargs["input_image_shape"] = tuple( + reversed(init_kwargs["input_image_shape"]) + ) + self.run_backbone_test( + cls=cls, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=expected_output_shape, + variable_length_data=variable_length_data, + run_mixed_precision_check=run_mixed_precision_check, + run_quantization_check=run_quantization_check, + ) + + # Restore the original `image_data_format`. 
+ keras.config.set_image_data_format(ori_data_format) + def run_task_test( self, cls, diff --git a/keras_nlp/src/utils/keras_utils.py b/keras_nlp/src/utils/keras_utils.py index 0fb96ccffb..b37b74ad19 100644 --- a/keras_nlp/src/utils/keras_utils.py +++ b/keras_nlp/src/utils/keras_utils.py @@ -115,3 +115,16 @@ def assert_quantization_support(): "Quantization API requires Keras >= 3.4.0 to function " f"correctly. Received: '{keras.version()}'" ) + + +def standardize_data_format(data_format): + if data_format is None: + return keras.config.image_data_format() + data_format = str(data_format).lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "The `data_format` argument must be one of " + "{'channels_first', 'channels_last'}. " + f"Received: data_format={data_format}" + ) + return data_format From 26afc7e538927bbb8d588ab72ce50c3a6c1f89b5 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 14 Aug 2024 18:30:21 -0700 Subject: [PATCH 3/7] Add CSP DarkNet backbone and classifier (#1774) * Add CSP DarkNet * Add CSP DarkNet * snake_case function names * change use_depthwise to block_type --- keras_nlp/api/models/__init__.py | 6 + keras_nlp/src/models/csp_darknet/__init__.py | 13 + .../csp_darknet/csp_darknet_backbone.py | 410 ++++++++++++++++++ .../csp_darknet/csp_darknet_backbone_test.py | 50 +++ .../csp_darknet_image_classifier.py | 133 ++++++ .../csp_darknet_image_classifier_test.py | 65 +++ 6 files changed, 677 insertions(+) create mode 100644 keras_nlp/src/models/csp_darknet/__init__.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py create mode 100644 keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 783cfd5087..aca1e28538 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -50,6 +50,12 @@ from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.src.models.causal_lm import CausalLM from keras_nlp.src.models.classifier import Classifier +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( DebertaV3Backbone, ) diff --git a/keras_nlp/src/models/csp_darknet/__init__.py b/keras_nlp/src/models/csp_darknet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
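The `run_vision_backbone_test` helper added above re-runs each vision backbone under `"channels_first"` by transposing the input batch and reversing `input_image_shape`. A small stand-alone sketch of that conversion (illustrative only, NumPy used in place of `keras.ops`):

```python
import numpy as np

# A channels-last batch: (batch, height, width, channels).
images_nhwc = np.ones((2, 16, 16, 3), dtype="float32")

# Same permutation the test helper applies before switching Keras to
# image_data_format("channels_first"): NHWC -> NCHW.
images_nchw = np.transpose(images_nhwc, axes=(0, 3, 1, 2))
print(images_nhwc.shape, "->", images_nchw.shape)  # (2, 16, 16, 3) -> (2, 3, 16, 16)

# The static `input_image_shape` kwarg is reversed the same way.
print(tuple(reversed((16, 16, 3))))  # (3, 16, 16)
```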
diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py new file mode 100644 index 0000000000..2745f61d01 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py @@ -0,0 +1,410 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone") +class CSPDarkNetBackbone(Backbone): + """This class represents Keras Backbone of CSPDarkNet model. + + This class implements a CSPDarkNet backbone as described in + [CSPNet: A New Backbone that can Enhance Learning Capability of CNN]( + https://arxiv.org/abs/1911.11929). + + Args: + stackwise_num_filters: A list of ints, filter size for each dark + level in the model. + stackwise_depth: A list of ints, the depth for each dark level in the + model. + include_rescaling: boolean. If `True`, rescale the input using + `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to + `True`. + block_type: str. One of `"basic_block"` or `"depthwise_block"`. + Use `"depthwise_block"` for depthwise conv block + `"basic_block"` for basic conv block. + Defaults to "basic_block". + input_image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. 
+ + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.CSPDarkNetBackbone.from_preset( + "csp_darknet_tiny_imagenet" + ) + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.CSPDarkNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_filters, + stackwise_depth, + include_rescaling, + block_type="basic_block", + input_image_shape=(224, 224, 3), + **kwargs, + ): + # === Functional Model === + apply_ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "depthwise_block" + else apply_darknet_conv_block + ) + base_channels = stackwise_num_filters[0] // 2 + + image_input = layers.Input(shape=input_image_shape) + x = image_input + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + + x = apply_focus(name="stem_focus")(x) + x = apply_darknet_conv_block( + base_channels, kernel_size=3, strides=1, name="stem_conv" + )(x) + for index, (channels, depth) in enumerate( + zip(stackwise_num_filters, stackwise_depth) + ): + x = apply_ConvBlock( + channels, + kernel_size=3, + strides=2, + name=f"dark{index + 2}_conv", + )(x) + + if index == len(stackwise_depth) - 1: + x = apply_spatial_pyramid_pooling_bottleneck( + channels, + hidden_filters=channels // 2, + name=f"dark{index + 2}_spp", + )(x) + + x = apply_cross_stage_partial( + channels, + num_bottlenecks=depth, + block_type="basic_block", + residual=(index != len(stackwise_depth) - 1), + name=f"dark{index + 2}_csp", + )(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_depth = stackwise_depth + self.include_rescaling = include_rescaling + self.block_type = block_type + self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_depth": self.stackwise_depth, + "include_rescaling": self.include_rescaling, + "block_type": self.block_type, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_focus(name=None): + """A block used in CSPDarknet to focus information into channels of the + image. + + If the dimensions of a batch input is (batch_size, width, height, channels), + this layer converts the image into size (batch_size, width/2, height/2, + 4*channels). See [the original discussion on YoloV5 Focus Layer](https://github.com/ultralytics/yolov5/discussions/3181). + + Args: + name: the name for the lambda layer used in the block. + + Returns: + a function that takes an input Tensor representing a Focus layer. + """ + + def apply(x): + return layers.Concatenate(name=name)( + [ + x[..., ::2, ::2, :], + x[..., 1::2, ::2, :], + x[..., ::2, 1::2, :], + x[..., 1::2, 1::2, :], + ], + ) + + return apply + + +def apply_darknet_conv_block( + filters, kernel_size, strides, use_bias=False, activation="silu", name=None +): + """ + The basic conv block used in Darknet. Applies Conv2D followed by a + BatchNorm. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. 
Can be a single + integer to specify the same value both dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the height and width. Can be a single + integer to the same value both dimensions. + use_bias: Boolean, whether the layer uses a bias vector. + activation: the activation applied after the BatchNorm layer. One of + "silu", "relu" or "leaky_relu", defaults to "silu". + name: the prefix for the layer names used in the block. + """ + if name is None: + name = f"conv_block{keras.backend.get_uid('conv_block')}" + + def apply(inputs): + x = layers.Conv2D( + filters, + kernel_size, + strides, + padding="same", + use_bias=use_bias, + name=name + "_conv", + )(inputs) + + x = layers.BatchNormalization(name=name + "_bn")(x) + + if activation == "silu": + x = layers.Lambda(lambda x: keras.activations.silu(x))(x) + elif activation == "relu": + x = layers.ReLU()(x) + elif activation == "leaky_relu": + x = layers.LeakyReLU(0.1)(x) + + return x + + return apply + + +def apply_darknet_conv_block_depthwise( + filters, kernel_size, strides, activation="silu", name=None +): + """ + The depthwise conv block used in CSPDarknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. Can be a single + integer to specify the same value both dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the height and width. Can be a single + integer to the same value both dimensions. + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + name: the prefix for the layer names used in the block. + + """ + if name is None: + name = f"conv_block{keras.backend.get_uid('conv_block')}" + + def apply(inputs): + x = layers.DepthwiseConv2D( + kernel_size, strides, padding="same", use_bias=False + )(inputs) + x = layers.BatchNormalization()(x) + + if activation == "silu": + x = layers.Lambda(lambda x: keras.activations.swish(x))(x) + elif activation == "relu": + x = layers.ReLU()(x) + elif activation == "leaky_relu": + x = layers.LeakyReLU(0.1)(x) + + x = apply_darknet_conv_block( + filters, kernel_size=1, strides=1, activation=activation + )(x) + + return x + + return apply + + +def apply_spatial_pyramid_pooling_bottleneck( + filters, + hidden_filters=None, + kernel_sizes=(5, 9, 13), + activation="silu", + name=None, +): + """ + Spatial pyramid pooling layer used in YOLOv3-SPP + + Args: + filters: Integer, the dimensionality of the output spaces (i.e. the + number of output filters in used the blocks). + hidden_filters: Integer, the dimensionality of the intermediate + bottleneck space (i.e. the number of output filters in the + bottleneck convolution). If None, it will be equal to filters. + Defaults to None. + kernel_sizes: A list or tuple representing all the pool sizes used for + the pooling layers, defaults to (5, 9, 13). + activation: Activation for the conv layers, defaults to "silu". + name: the prefix for the layer names used in the block. + + Returns: + a function that takes an input Tensor representing an + SpatialPyramidPoolingBottleneck. 
+ """ + if name is None: + name = f"spp{keras.backend.get_uid('spp')}" + + if hidden_filters is None: + hidden_filters = filters + + def apply(x): + x = apply_darknet_conv_block( + hidden_filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(x) + x = [x] + + for kernel_size in kernel_sizes: + x.append( + layers.MaxPooling2D( + kernel_size, + strides=1, + padding="same", + name=f"{name}_maxpool_{kernel_size}", + )(x[0]) + ) + + x = layers.Concatenate(name=f"{name}_concat")(x) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(x) + + return x + + return apply + + +def apply_cross_stage_partial( + filters, + num_bottlenecks, + residual=True, + block_type="basic_block", + activation="silu", + name=None, +): + """A block used in Cross Stage Partial Darknet. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the final convolution). + num_bottlenecks: an integer representing the number of blocks added in + the layer bottleneck. + residual: a boolean representing whether the value tensor before the + bottleneck should be added to the output of the bottleneck as a + residual, defaults to True. + block_type: str. One of `"basic_block"` or `"depthwise_block"`. + Use `"depthwise_block"` for depthwise conv block + `"basic_block"` for basic conv block. + Defaults to "basic_block". + activation: the activation applied after the final layer. One of "silu", + "relu" or "leaky_relu", defaults to "silu". + """ + + if name is None: + name = f"cross_stage_partial_{keras.backend.get_uid('cross_stage_partial')}" + + def apply(inputs): + hidden_channels = filters // 2 + ConvBlock = ( + apply_darknet_conv_block_depthwise + if block_type == "basic_block" + else apply_darknet_conv_block + ) + + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv1", + )(inputs) + + x2 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv2", + )(inputs) + + for i in range(num_bottlenecks): + residual_x = x1 + x1 = apply_darknet_conv_block( + hidden_channels, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv1", + )(x1) + x1 = ConvBlock( + hidden_channels, + kernel_size=3, + strides=1, + activation=activation, + name=f"{name}_bottleneck_{i}_conv2", + )(x1) + if residual: + x1 = layers.Add(name=f"{name}_bottleneck_{i}_add")( + [residual_x, x1] + ) + + x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) + x = apply_darknet_conv_block( + filters, + kernel_size=1, + strides=1, + activation=activation, + name=f"{name}_conv3", + )(x) + + return x + + return apply diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py new file mode 100644 index 0000000000..aaad4fe515 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -0,0 +1,50 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CSPDarkNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_filters": [32, 64, 128, 256], + "stackwise_depth": [1, 3, 3, 1], + "include_rescaling": False, + "block_type": "basic_block", + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=CSPDarkNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 256), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPDarkNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py new file mode 100644 index 0000000000..6b013bdcc0 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -0,0 +1,133 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.CSPDarkNetImageClassifier") +class CSPDarkNetImageClassifier(ImageClassifier): + """CSPDarkNet image classifier task model. + + Args: + backbone: A `keras_nlp.models.CSPDarkNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. 
+ ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_nlp.models.CSPDarkNetImageClassifier.from_preset( + "csp_darknet_tiny_imagenet") + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + backbone = keras_nlp.models.CSPDarkNetBackbone( + stackwise_num_filters=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + include_rescaling=False, + block_type="basic_block", + input_image_shape = (224, 224, 3), + ) + classifier = keras_nlp.models.CSPDarkNetImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = CSPDarkNetBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py new file mode 100644 index 0000000000..a07bb017a3 --- /dev/null +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_image_classifier_test.py @@ -0,0 +1,65 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_nlp.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class CSPDarkNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = CSPDarkNetBackbone( + stackwise_num_filters=[2, 16, 16], + stackwise_depth=[1, 3, 3, 1], + include_rescaling=False, + block_type="basic_block", + input_image_shape=(16, 16, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=CSPDarkNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPDarkNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 00ab4d5c4d0350872a64e9a42ad22cf4cb3a43c2 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Fri, 16 Aug 2024 04:29:57 +0800 Subject: [PATCH 4/7] Add `FeaturePyramidBackbone` and port weights from `timm` for `ResNetBackbone` (#1769) * Add FeaturePyramidBackbone and update ResNetBackbone * Simplify the implementation * Fix CI * Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone * Add conversion implementation * Update docstrings * Address comments --- keras_nlp/api/models/__init__.py | 1 + keras_nlp/src/models/backbone.py | 3 + .../src/models/feature_pyramid_backbone.py | 73 +++++ .../src/models/resnet/resnet_backbone.py | 252 +++++++++++------- .../src/models/resnet/resnet_backbone_test.py | 25 +- .../models/resnet/resnet_image_classifier.py | 7 +- .../resnet/resnet_image_classifier_test.py | 4 + keras_nlp/src/utils/preset_utils.py | 4 + keras_nlp/src/utils/timm/__init__.py | 13 + keras_nlp/src/utils/timm/convert.py | 37 +++ keras_nlp/src/utils/timm/convert_resnet.py | 171 ++++++++++++ .../src/utils/timm/convert_resnet_test.py | 28 ++ .../utils/transformers/safetensor_utils.py | 4 +- 13 files changed, 524 insertions(+), 98 deletions(-) create mode 100644 keras_nlp/src/models/feature_pyramid_backbone.py create mode 100644 keras_nlp/src/utils/timm/__init__.py create mode 100644 keras_nlp/src/utils/timm/convert.py create mode 100644 keras_nlp/src/utils/timm/convert_resnet.py create mode 100644 keras_nlp/src/utils/timm/convert_resnet_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index aca1e28538..e079aa7c9e 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -112,6 +112,7 @@ ) from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index a58072dfce..0f41c63c81 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -30,6 +30,7 @@ from keras_nlp.src.utils.preset_utils import save_metadata from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.timm.convert import load_timm_backbone from 
keras_nlp.src.utils.transformers.convert import load_transformers_backbone @@ -204,6 +205,8 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from if format == "transformers": return load_transformers_backbone(cls, preset, load_weights) + elif format == "timm": + return load_timm_backbone(cls, preset, load_weights, **kwargs) preset_cls = check_config_class(preset) if not issubclass(preset_cls, cls): diff --git a/keras_nlp/src/models/feature_pyramid_backbone.py b/keras_nlp/src/models/feature_pyramid_backbone.py new file mode 100644 index 0000000000..989d9fbd64 --- /dev/null +++ b/keras_nlp/src/models/feature_pyramid_backbone.py @@ -0,0 +1,73 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone") +class FeaturePyramidBackbone(Backbone): + """A backbone with feature pyramid outputs. + + `FeaturePyramidBackbone` extends `Backbone` with a single `pyramid_outputs` + property for accessing the feature pyramid outputs of the model. Subclassers + should set the `pyramid_outputs` property during the model constructor. + + Example: + + ```python + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) + + # Convert to feature pyramid output format using ResNet. + backbone = ResNetBackbone.from_preset("resnet50") + model = keras.Model( + inputs=backbone.inputs, outputs=backbone.pyramid_outputs + ) + model(input_data) # A dict containing the keys ["P2", "P3", "P4", "P5"] + ``` + """ + + @property + def pyramid_outputs(self): + """A dict for feature pyramid outputs. + + The key is a string represents the name of the feature output and the + value is a `keras.KerasTensor`. A typical feature pyramid has multiple + levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale + `Pn` represents a feature map `2^n` times smaller in width and height + than the inputs. + """ + return getattr(self, "_pyramid_outputs", {}) + + @pyramid_outputs.setter + def pyramid_outputs(self, value): + if not isinstance(value, dict): + raise TypeError( + "`pyramid_outputs` must be a dictionary. " + f"Received: value={value} of type {type(value)}" + ) + for k, v in value.items(): + if not isinstance(k, str): + raise TypeError( + "The key of `pyramid_outputs` must be a string. " + f"Received: key={k} of type {type(k)}" + ) + if not isinstance(v, keras.KerasTensor): + raise TypeError( + "The value of `pyramid_outputs` must be a " + "`keras.KerasTensor`. " + f"Received: value={v} of type {type(v)}" + ) + self._pyramid_outputs = value diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index bec5ba60b5..0f4d7c139a 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -13,20 +13,23 @@ # limitations under the License. 
import keras from keras import layers +from keras import ops from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.utils.keras_utils import standardize_data_format @keras_nlp_export("keras_nlp.models.ResNetBackbone") -class ResNetBackbone(Backbone): +class ResNetBackbone(FeaturePyramidBackbone): """ResNet and ResNetV2 core network with hyperparameters. This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( - CVPR 2016) and [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016). + CVPR 2016), [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + improved training procedure in timm](https://arxiv.org/abs/2110.00476)( + NeurIPS 2021 Workshop). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -34,6 +37,9 @@ class ResNetBackbone(Backbone): the batch normalization and ReLU activation are applied after the convolution layers. + Note that `ResNetBackbone` expects the inputs to be images with a value + range of `[0, 255]` when `include_rescaling=True`. + Args: stackwise_num_filters: list of ints. The number of filters for each stack. @@ -46,8 +52,8 @@ class ResNetBackbone(Backbone): use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using - `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to - `True`. + `Rescaling` and `Normalization` layers. If `False`, do nothing. + Defaults to `True`. input_image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults @@ -70,11 +76,11 @@ class ResNetBackbone(Backbone): `~/.keras/keras.json`. If you never set it, then it will be `"channels_last"`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the models computations and weights. + to use for the model's computations and weights. Examples: ```python - input_data = np.ones((2, 224, 224, 3), dtype="float32") + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") @@ -136,34 +142,66 @@ def __init__( image_input = layers.Input(shape=input_image_shape) if include_rescaling: x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + x = layers.Normalization( + axis=bn_axis, + mean=(0.485, 0.456, 0.406), + variance=(0.229**2, 0.224**2, 0.225**2), + dtype=dtype, + name="normalization", + )(x) else: x = image_input + # The padding between torch and tensorflow/jax differs when `strides>1`. + # Therefore, we need to manually pad the tensor. 
+ x = layers.ZeroPadding2D( + 3, + data_format=data_format, + dtype=dtype, + name="conv1_pad", + )(x) x = layers.Conv2D( 64, 7, strides=2, - padding="same", data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name="conv1_conv", )(x) if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="conv1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) - x = layers.MaxPool2D( + if use_pre_activation: + # A workaround for ResNetV2: we need -inf padding to prevent zeros + # from being the max values in the following `MaxPooling2D`. + pad_width = [[1, 1], [1, 1]] + if data_format == "channels_last": + pad_width += [[0, 0]] + else: + pad_width = [[0, 0]] + pad_width + pad_width = [[0, 0]] + pad_width + x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf")) + else: + x = layers.ZeroPadding2D( + 1, data_format=data_format, dtype=dtype, name="pool1_pad" + )(x) + x = layers.MaxPooling2D( 3, strides=2, - padding="same", data_format=data_format, dtype=dtype, name="pool1_pool", )(x) + pyramid_outputs = {} for stack_index in range(num_stacks): x = apply_stack( x, @@ -179,10 +217,15 @@ def __init__( dtype=dtype, name=f"{version}_stack{stack_index}", ) + pyramid_outputs[f"P{stack_index + 2}"] = x if use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="post_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) @@ -213,18 +256,23 @@ def __init__( self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape self.pooling = pooling + self.pyramid_outputs = pyramid_outputs def get_config(self): - return { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_num_blocks": self.stackwise_num_blocks, - "stackwise_num_strides": self.stackwise_num_strides, - "block_type": self.block_type, - "use_pre_activation": self.use_pre_activation, - "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, - "pooling": self.pooling, - } + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + ) + return config def apply_basic_block( @@ -269,68 +317,81 @@ def apply_basic_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, 
name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=stride if not use_pre_activation else 1, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( filters, kernel_size, - strides=1 if not use_pre_activation else stride, + strides=1, padding="same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -381,79 +442,97 @@ def apply_bottleneck_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( 4 * filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x x = layers.Conv2D( filters, 1, - strides=stride if not use_pre_activation else 1, + strides=1, data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=1 if not 
use_pre_activation else stride, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( 4 * filters, 1, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_3_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -513,32 +592,21 @@ def apply_stack( '`block_type` must be either `"basic_block"` or ' f'`"bottleneck_block"`. Received block_type={block_type}.' ) - x = block_fn( - x, - filters, - stride=stride if not use_pre_activation else 1, - conv_shortcut=first_shortcut, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block1", - ) - for i in range(2, blocks): + for i in range(blocks): + if i == 0: + stride = stride + conv_shortcut = first_shortcut + else: + stride = 1 + conv_shortcut = False x = block_fn( x, filters, + stride=stride, + conv_shortcut=conv_shortcut, use_pre_activation=use_pre_activation, data_format=data_format, dtype=dtype, name=f"{name}_block{str(i)}", ) - x = block_fn( - x, - filters, - stride=1 if not use_pre_activation else stride, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block{str(blocks)}", - ) return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 2113bcd131..6d3f774559 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -14,6 +14,7 @@ import pytest from absl.testing import parameterized +from keras import models from keras import ops from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone @@ -29,8 +30,8 @@ def setUp(self): "input_image_shape": (None, None, 3), "pooling": "avg", } - self.input_size = (16, 16) - self.input_data = ops.ones((2, 16, 16, 3)) + self.input_size = 64 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( ("v1_basic", False, "basic_block"), @@ -52,6 +53,24 @@ def test_backbone_basics(self, use_pre_activation, block_type): ), ) + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": "basic_block", "use_pre_activation": False} + ) + backbone = ResNetBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** int(k[1:])) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + @parameterized.named_parameters( ("v1_basic", False, "basic_block"), ("v1_bottleneck", False, "bottleneck_block"), @@ -65,7 +84,7 
@@ def test_saved_model(self, use_pre_activation, block_type): { "block_type": block_type, "use_pre_activation": use_pre_activation, - "input_image_shape": (16, 16, 3), + "input_image_shape": (None, None, 3), } ) self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py index 02c8c78b27..815dc7fcca 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -28,6 +28,8 @@ class ResNetImageClassifier(ImageClassifier): activation: `None`, str or callable. The activation function to use on the `Dense` layer. Set `activation=None` to return the output logits. Defaults to `"softmax"`. + head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` where `x` is a tensor and `y` is a integer from `[0, num_classes)`. @@ -92,16 +94,19 @@ def __init__( backbone, num_classes, activation="softmax", + head_dtype=None, preprocessor=None, # adding this dummy arg for saved model test # TODO: once preprocessor flow is figured out, this needs to be updated **kwargs, ): + head_dtype = head_dtype or backbone.dtype_policy + # === Layers === self.backbone = backbone self.output_dense = keras.layers.Dense( num_classes, activation=activation, - dtype=self.backbone.dtype_policy, + dtype=head_dtype, name="predictions", ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index bbbda72d64..f3f63a14a1 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -53,6 +53,10 @@ def test_classifier_basics(self): expected_output_shape=(2, 2), ) + def test_head_dtype(self): + model = ResNetImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index f797bf9f18..9e3f51c43a 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -544,6 +544,10 @@ def check_format(preset): if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists( preset, SAFETENSOR_CONFIG_FILE ): + # Determine the format by parsing the config file. + config = load_config(preset, HF_CONFIG_FILE) + if "hf://timm" in preset or "architecture" in config: + return "timm" return "transformers" if not check_file_exists(preset, METADATA_FILE): diff --git a/keras_nlp/src/utils/timm/__init__.py b/keras_nlp/src/utils/timm/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/utils/timm/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/utils/timm/convert.py b/keras_nlp/src/utils/timm/convert.py new file mode 100644 index 0000000000..edfde3316b --- /dev/null +++ b/keras_nlp/src/utils/timm/convert.py @@ -0,0 +1,37 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert timm models to KerasNLP.""" + +from keras_nlp.src.utils.timm.convert_resnet import load_resnet_backbone + + +def load_timm_backbone(cls, preset, load_weights, **kwargs): + """Load a timm model config and weights as a KerasNLP backbone. + + Args: + cls (class): Keras model class. + preset (str): Preset configuration name. + load_weights (bool): Whether to load the weights. + + Returns: + backbone: Initialized Keras model backbone. + """ + if cls is None: + raise ValueError("Backbone class is None") + if cls.__name__ == "ResNetBackbone": + return load_resnet_backbone(cls, preset, load_weights, **kwargs) + raise ValueError( + f"{cls} has not been ported from the Hugging Face format yet. " + "Please check Hugging Face Hub for the Keras model. " + ) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py new file mode 100644 index 0000000000..de2224eb9e --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -0,0 +1,171 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
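To make the new conversion path concrete, here is a hedged usage sketch: `check_format` in `preset_utils.py` classifies `hf://timm` presets (or any config carrying an `architecture` key) as `"timm"`, and `Backbone.from_preset` then routes through `load_timm_backbone` to the ResNet converter defined below. The preset handle is the one exercised by the test added later in this patch; only `ResNetBackbone` is supported by this converter so far.

```python
import keras_nlp

# Loads the timm config and safetensors weights via
# load_timm_backbone -> load_resnet_backbone.
backbone = keras_nlp.models.ResNetBackbone.from_preset(
    "hf://timm/resnet18.a1_in1k",
    load_weights=True,
)
```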
+import numpy as np + +from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + if "resnetv2_" in timm_architecture: + use_pre_activation = True + else: + use_pre_activation = False + + if timm_architecture == "resnet18": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "basic_block" + elif timm_architecture == "resnet26": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "bottleneck_block" + elif timm_architecture == "resnet34": + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "basic_block" + elif timm_architecture in ("resnet50", "resnetv2_50"): + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet101", "resnetv2_101"): + stackwise_num_blocks = [3, 4, 23, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet152", "resnetv2_152"): + stackwise_num_blocks = [3, 8, 36, 3] + block_type = "bottleneck_block" + else: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." + ) + + return dict( + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=stackwise_num_blocks, + stackwise_num_strides=[1, 2, 2, 2], + block_type=block_type, + use_pre_activation=use_pre_activation, + ) + + +def convert_weights(backbone, loader, timm_config): + def port_conv2d(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + def port_batch_normalization(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + + version = "v1" if not backbone.use_pre_activation else "v2" + block_type = backbone.block_type + + # Stem + if version == "v1": + port_conv2d("conv1_conv", "conv1") + port_batch_normalization("conv1_bn", "bn1") + else: + port_conv2d("conv1_conv", "stem.conv") + + # Stages + num_stacks = len(backbone.stackwise_num_filters) + for stack_index in range(num_stacks): + for block_idx in range(backbone.stackwise_num_blocks[stack_index]): + if version == "v1": + keras_name = f"v1_stack{stack_index}_block{block_idx}" + hf_name = f"layer{stack_index+1}.{block_idx}" + else: + keras_name = f"v2_stack{stack_index}_block{block_idx}" + hf_name = f"stages.{stack_index}.blocks.{block_idx}" + + if version == "v1": + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.0" + ) + port_batch_normalization( + f"{keras_name}_0_bn", f"{hf_name}.downsample.1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1") + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + 
port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2") + if block_type == "bottleneck_block": + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + port_batch_normalization( + f"{keras_name}_3_bn", f"{hf_name}.bn3" + ) + else: + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.conv" + ) + port_batch_normalization( + f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization( + f"{keras_name}_1_bn", f"{hf_name}.norm2" + ) + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + if block_type == "bottleneck_block": + port_batch_normalization( + f"{keras_name}_2_bn", f"{hf_name}.norm3" + ) + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + + # Post + if version == "v2": + port_batch_normalization("post_bn", "norm") + + # Rebuild normalization layer with pretrained mean & std + mean = timm_config["pretrained_cfg"]["mean"] + std = timm_config["pretrained_cfg"]["std"] + normalization_layer = backbone.get_layer("normalization") + normalization_layer.input_mean = mean + normalization_layer.input_variance = [s**2 for s in std] + normalization_layer.build(normalization_layer._build_input_shape) + + +def load_resnet_backbone(cls, preset, load_weights, **kwargs): + timm_config = load_config(preset, HF_CONFIG_FILE) + keras_config = convert_backbone_config(timm_config) + backbone = cls(**keras_config, **kwargs) + if load_weights: + jax_memory_cleanup(backbone) + # Use prefix="" to avoid using `get_prefixed_key`. + with SafetensorLoader(preset, prefix="") as loader: + convert_weights(backbone, loader, timm_config) + return backbone diff --git a/keras_nlp/src/utils/timm/convert_resnet_test.py b/keras_nlp/src/utils/timm/convert_resnet_test.py new file mode 100644 index 0000000000..a30bee46af --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
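The `(2, 3, 1, 0)` transpose used by `port_conv2d` in `convert_resnet.py` above is the usual layout change between PyTorch and Keras convolution kernels. A small self-contained check, with illustrative shapes taken from the ResNet stem kernel:

```python
import numpy as np

# timm/PyTorch stores Conv2D kernels as (out_channels, in_channels, kH, kW);
# Keras Conv2D expects (kH, kW, in_channels, out_channels).
torch_kernel = np.zeros((64, 3, 7, 7))  # e.g. the "conv1.weight" stem kernel
keras_kernel = np.transpose(torch_kernel, (2, 3, 1, 0))
assert keras_kernel.shape == (7, 7, 3, 64)
```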
+import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class TimmResNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_resnet18_preset(self): + model = ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k") + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 512)) + + # TODO: compare numerics with timm model diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py index 40ef473ff3..2fbd7e1aba 100644 --- a/keras_nlp/src/utils/transformers/safetensor_utils.py +++ b/keras_nlp/src/utils/transformers/safetensor_utils.py @@ -26,7 +26,7 @@ class SafetensorLoader(contextlib.ExitStack): - def __init__(self, preset): + def __init__(self, preset, prefix=None): super().__init__() if safetensors is None: @@ -42,7 +42,7 @@ def __init__(self, preset): else: self.safetensor_config = None self.safetensor_files = {} - self.prefix = None + self.prefix = prefix def get_prefixed_key(self, hf_weight_key, dict_like): """ From 9860756f183cc4ad9247bc29b6c0ee55ec2db6fc Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 15 Aug 2024 17:38:39 -0700 Subject: [PATCH 5/7] Add DenseNet (#1775) * Add DenseNet * fix testcase * address comments * nit * fix lint errors * move description --- keras_nlp/api/models/__init__.py | 4 + keras_nlp/src/models/densenet/__init__.py | 13 ++ .../src/models/densenet/densenet_backbone.py | 210 ++++++++++++++++++ .../models/densenet/densenet_backbone_test.py | 48 ++++ .../densenet/densenet_image_classifier.py | 131 +++++++++++ .../densenet_image_classifier_test.py | 63 ++++++ 6 files changed, 469 insertions(+) create mode 100644 keras_nlp/src/models/densenet/__init__.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone.py create mode 100644 keras_nlp/src/models/densenet/densenet_backbone_test.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier.py create mode 100644 keras_nlp/src/models/densenet/densenet_image_classifier_test.py diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e079aa7c9e..bf5cc28060 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -74,6 +74,10 @@ from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_nlp/src/models/densenet/__init__.py b/keras_nlp/src/models/densenet/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/densenet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py new file mode 100644 index 0000000000..8456fbcee6 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -0,0 +1,210 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + +BN_AXIS = 3 +BN_EPSILON = 1.001e-5 + + +@keras_nlp_export("keras_nlp.models.DenseNetBackbone") +class DenseNetBackbone(Backbone): + """Instantiates the DenseNet architecture. + + This class implements a DenseNet backbone as described in + [Densely Connected Convolutional Networks (CVPR 2017)]( + https://arxiv.org/abs/1608.06993 + ). + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per dense block. + include_rescaling: bool, whether to rescale the inputs. If set + to `True`, inputs will be passed through a `Rescaling(1/255.0)` + layer. Defaults to `True`. + input_image_shape: optional shape tuple, defaults to (224, 224, 3). + compression_ratio: float, compression rate at transition layers, + defaults to 0.5. + growth_rate: int, number of filters added by each dense block, + defaults to 32 + + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_nlp.models.DenseNetBackbone.from_preset("densenet121_imagenet") + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_nlp.models.DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=False, + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + include_rescaling=True, + input_image_shape=(224, 224, 3), + compression_ratio=0.5, + growth_rate=32, + **kwargs, + ): + # === Functional Model === + image_input = keras.layers.Input(shape=input_image_shape) + + x = image_input + if include_rescaling: + x = keras.layers.Rescaling(1 / 255.0)(x) + + x = keras.layers.Conv2D( + 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn" + )(x) + x = keras.layers.Activation("relu", name="conv1_relu")(x) + x = keras.layers.MaxPooling2D( + 3, strides=2, padding="same", name="pool1" + )(x) + + for stack_index in range(len(stackwise_num_repeats) - 1): + index = stack_index + 2 + x = apply_dense_block( + x, + stackwise_num_repeats[stack_index], + growth_rate, + name=f"conv{index}", + ) + x = apply_transition_block( + x, compression_ratio, name=f"pool{index}" + ) + + x = apply_dense_block( + x, + stackwise_num_repeats[-1], + growth_rate, + name=f"conv{len(stackwise_num_repeats) + 1}", + ) + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" + )(x) + x = keras.layers.Activation("relu", name="relu")(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === 
+ self.stackwise_num_repeats = stackwise_num_repeats + self.include_rescaling = include_rescaling + self.compression_ratio = compression_ratio + self.growth_rate = growth_rate + self.input_image_shape = input_image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_num_repeats": self.stackwise_num_repeats, + "include_rescaling": self.include_rescaling, + "compression_ratio": self.compression_ratio, + "growth_rate": self.growth_rate, + "input_image_shape": self.input_image_shape, + } + ) + return config + + +def apply_dense_block(x, num_repeats, growth_rate, name=None): + """A dense block. + + Args: + x: input tensor. + num_repeats: int, number of repeated convolutional blocks. + growth_rate: int, number of filters added by each dense block. + name: string, block label. + """ + if name is None: + name = f"dense_block_{keras.backend.get_uid('dense_block')}" + + for i in range(num_repeats): + x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}") + return x + + +def apply_transition_block(x, compression_ratio, name=None): + """A transition block. + + Args: + x: input tensor. + compression_ratio: float, compression rate at transition layers. + name: string, block label. + """ + if name is None: + name = f"transition_block_{keras.backend.get_uid('transition_block')}" + + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_relu")(x) + x = keras.layers.Conv2D( + int(x.shape[BN_AXIS] * compression_ratio), + 1, + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + return x + + +def apply_conv_block(x, growth_rate, name=None): + """A building block for a dense block. + + Args: + x: input tensor. + growth_rate: int, number of filters added by each dense block. + name: string, block label. + """ + if name is None: + name = f"conv_block_{keras.backend.get_uid('conv_block')}" + + shortcut = x + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x) + x = keras.layers.Conv2D( + 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + )(x) + x = keras.layers.BatchNormalization( + axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn" + )(x) + x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x) + x = keras.layers.Conv2D( + growth_rate, + 3, + padding="same", + use_bias=False, + name=f"{name}_2_conv", + )(x) + x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")( + [shortcut, x] + ) + return x diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py new file mode 100644 index 0000000000..f0f8dac875 --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
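For readers checking the arithmetic implied by the backbone above, here is an illustrative sketch (not part of the patch) of how the default DenseNet-121-style configuration arrives at 1024 output channels, matching the `(2, 7, 7, 1024)` shape asserted in the test below; spatially, the stem convolution, the max pool, and the three transition blocks each halve 224 down to 7.

```python
# Channel bookkeeping for stackwise_num_repeats=[6, 12, 24, 16],
# growth_rate=32, compression_ratio=0.5.
channels = 64  # output of the 7x7 stem convolution
for i, repeats in enumerate([6, 12, 24, 16]):
    channels += repeats * 32  # each conv block concatenates `growth_rate` filters
    if i < 3:  # a transition block follows every dense block except the last
        channels = int(channels * 0.5)
print(channels)  # 1024
```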
+ +import numpy as np +import pytest + +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class DenseNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [6, 12, 24, 16], + "include_rescaling": True, + "compression_ratio": 0.5, + "growth_rate": 32, + "input_image_shape": (224, 224, 3), + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 1024), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier.py b/keras_nlp/src/models/densenet/densenet_image_classifier.py new file mode 100644 index 0000000000..395e8f754d --- /dev/null +++ b/keras_nlp/src/models/densenet/densenet_image_classifier.py @@ -0,0 +1,131 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone +from keras_nlp.src.models.image_classifier import ImageClassifier + + +@keras_nlp_export("keras_nlp.models.DenseNetImageClassifier") +class DenseNetImageClassifier(ImageClassifier): + """DenseNet image classifier task model. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Args: + backbone: A `keras_nlp.models.DenseNetBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.DenseNetImageClassifier.from_preset( + "densenet121_imagenet") + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. 
+    ```python
+    classifier = keras_nlp.models.DenseNetImageClassifier.from_preset(
+        "densenet121_imagenet")
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.ones((2, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    backbone = keras_nlp.models.DenseNetBackbone(
+        stackwise_num_repeats=[6, 12, 24, 16],
+        include_rescaling=False,
+        compression_ratio=0.5,
+        growth_rate=32,
+        input_image_shape=(224, 224, 3),
+    )
+    classifier = keras_nlp.models.DenseNetImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = DenseNetBackbone
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation="softmax",
+        preprocessor=None,  # adding this dummy arg for saved model test
+        # TODO: once preprocessor flow is figured out, this needs to be updated
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = activation
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
diff --git a/keras_nlp/src/models/densenet/densenet_image_classifier_test.py b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py
new file mode 100644
index 0000000000..60d77d489c
--- /dev/null
+++ b/keras_nlp/src/models/densenet/densenet_image_classifier_test.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+
+from keras_nlp.src.models.densenet.densenet_backbone import DenseNetBackbone
+from keras_nlp.src.models.densenet.densenet_image_classifier import (
+    DenseNetImageClassifier,
+)
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class DenseNetImageClassifierTest(TestCase):
+    def setUp(self):
+        # Setup model.
+ self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = DenseNetBackbone( + stackwise_num_repeats=[6, 12, 24, 16], + include_rescaling=True, + compression_ratio=0.5, + growth_rate=32, + input_image_shape=(224, 224, 3), + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 8b859b50d92fad7b5f9caa90d60194f821aa92ac Mon Sep 17 00:00:00 2001 From: ushareng Date: Tue, 3 Sep 2024 17:01:26 +0530 Subject: [PATCH 6/7] video_swin added --- .../models/video_swin/video_swin_backbone.py | 191 ++++ .../video_swin/video_swin_backbone_test.py | 84 ++ .../models/video_swin/video_swin_layers.py | 1015 +++++++++++++++++ .../video_swin/video_swin_layers_test.py | 96 ++ .../video_swin/video_swin_video_classifier.py | 100 ++ .../video_swin_video_classifier_test.py | 65 ++ 6 files changed, 1551 insertions(+) create mode 100644 keras_nlp/src/models/video_swin/video_swin_backbone.py create mode 100644 keras_nlp/src/models/video_swin/video_swin_backbone_test.py create mode 100644 keras_nlp/src/models/video_swin/video_swin_layers.py create mode 100644 keras_nlp/src/models/video_swin/video_swin_layers_test.py create mode 100644 keras_nlp/src/models/video_swin/video_swin_video_classifier.py create mode 100644 keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py diff --git a/keras_nlp/src/models/video_swin/video_swin_backbone.py b/keras_nlp/src/models/video_swin/video_swin_backbone.py new file mode 100644 index 0000000000..6fe3796792 --- /dev/null +++ b/keras_nlp/src/models/video_swin/video_swin_backbone.py @@ -0,0 +1,191 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import keras +import numpy as np +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.video_swin.video_swin_layers import ( + VideoSwinBasicLayer, +) +from keras_nlp.src.models.video_swin.video_swin_layers import ( + VideoSwinPatchingAndEmbedding, +) +from keras_nlp.src.models.video_swin.video_swin_layers import ( + VideoSwinPatchMerging, +) + + +@keras_nlp_export("keras_nlp.models.VideoSwinBackbone") +class VideoSwinBackbone(Backbone): + """A Video Swin Transformer backbone model. + + Args: + image_shape (tuple[int], optional): The size of the input video in + `(depth, height, width, channel)` format. + Defaults to `(32, 224, 224, 3)`. 
+        include_rescaling (bool, optional): Whether to rescale the inputs. If
+            set to `True`, inputs will be passed through a `Rescaling(1/255.0)`
+            layer and normalized with ImageNet mean=[0.485, 0.456, 0.406] and
+            std=[0.229, 0.224, 0.225]. Defaults to `False`.
+        patch_size (int | tuple[int]): The patch size for the depth, height,
+            and width dimensions respectively. Defaults to (2, 4, 4).
+        embed_dim (int): Number of linear projection output channels.
+            Defaults to 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+            Defaults to [2, 2, 6, 2].
+        num_heads (tuple[int]): Number of attention heads in each stage.
+            Defaults to [3, 6, 12, 24].
+        window_size (tuple[int]): The window size for the depth, height, and
+            width dimensions respectively. Defaults to [8, 7, 7].
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            Defaults to 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            Defaults to True.
+        qk_scale (float): Override the default qk scale of head_dim ** -0.5 if
+            set. Defaults to None.
+        drop_rate (float): Float between 0 and 1. Fraction of the input units
+            to drop. Defaults to 0.
+        attn_drop_rate (float): Float between 0 and 1. Attention dropout rate.
+            Defaults to 0.
+        drop_path_rate (float): Float between 0 and 1. Stochastic depth rate.
+            Defaults to 0.2.
+        patch_norm (bool): If True, add layer normalization after patch
+            embedding. Defaults to True.
+
+    Example:
+    ```python
+    # Build video swin backbone without top layer
+    model = keras_nlp.models.VideoSwinBackbone(
+        include_rescaling=True, image_shape=(8, 256, 256, 3),
+    )
+    videos = keras.ops.ones((1, 8, 256, 256, 3))
+    outputs = model.predict(videos)
+    ```
+
+    References:
+        - [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
+        - [Official Code](https://github.com/SwinTransformer/Video-Swin-Transformer)
+    """
+
+    def __init__(
+        self,
+        include_rescaling=False,
+        image_shape=(32, 224, 224, 3),
+        embed_dim=96,
+        patch_size=[2, 4, 4],
+        window_size=[8, 7, 7],
+        mlp_ratio=4.0,
+        patch_norm=True,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.2,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        qkv_bias=True,
+        qk_scale=None,
+        **kwargs,
+    ):
+
+        # === Functional Model ===
+
+        inputs = keras.layers.Input(shape=image_shape)
+        x = inputs
+
+        if include_rescaling:
+            x = layers.Rescaling(1.0 / 255.0)(x)
+
+            x = layers.Normalization(
+                mean=[0.485, 0.456, 0.406],
+                variance=[0.229**2, 0.224**2, 0.225**2],
+                name="normalization",
+            )(x)
+
+        norm_layer = partial(layers.LayerNormalization, epsilon=1e-05)
+
+        x = VideoSwinPatchingAndEmbedding(
+            patch_size=patch_size,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if patch_norm else None,
+            name="videoswin_patching_and_embedding",
+        )(x)
+        x = layers.Dropout(drop_rate, name="pos_drop")(x)
+
+        dpr = np.linspace(0.0, drop_path_rate, sum(depths)).tolist()
+        num_layers = len(depths)
+        for i in range(num_layers):
+            layer = VideoSwinBasicLayer(
+                input_dim=int(embed_dim * 2**i),
+                depth=depths[i],
+                num_heads=num_heads[i],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop_rate=drop_rate,
+                attn_drop_rate=attn_drop_rate,
+                drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
+                norm_layer=norm_layer,
+                downsampling_layer=(
+                    VideoSwinPatchMerging if (i < num_layers - 1) else None
+                ),
+                name=f"videoswin_basic_layer_{i + 1}",
+            )
+            x = layer(x)
+
+        x = norm_layer(axis=-1, epsilon=1e-05, name="videoswin_top_norm")(x)
+        super().__init__(inputs=inputs, outputs=x, **kwargs)
+
+        # === Config ===
+        self.include_rescaling = include_rescaling
+
self.embed_dim = embed_dim + self.patch_size = patch_size + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.norm_layer = norm_layer + self.patch_norm = patch_norm + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.num_layers = len(depths) + self.num_heads = num_heads + self.qkv_bias = qkv_bias + self.qk_scale = qk_scale + self.depths = depths + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "include_rescaling": self.include_rescaling, + "image_shape": self.image_shape, + "embed_dim": self.embed_dim, + "patch_norm": self.patch_norm, + "window_size": self.window_size, + "patch_size": self.patch_size, + "mlp_ratio": self.mlp_ratio, + "drop_rate": self.drop_rate, + "drop_path_rate": self.drop_path_rate, + "attn_drop_rate": self.attn_drop_rate, + "depths": self.depths, + "num_heads": self.num_heads, + "qkv_bias": self.qkv_bias, + "qk_scale": self.qk_scale, + } + ) + return config diff --git a/keras_nlp/src/models/video_swin/video_swin_backbone_test.py b/keras_nlp/src/models/video_swin/video_swin_backbone_test.py new file mode 100644 index 0000000000..dd224fff30 --- /dev/null +++ b/keras_nlp/src/models/video_swin/video_swin_backbone_test.py @@ -0,0 +1,84 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
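As a rough guide to the shapes asserted in the tests that follow, here is an illustrative calculation (not part of the patch): the patch embedding divides the input (depth, height, width) by `patch_size=(2, 4, 4)`, each of the three patch-merging steps between the four stages halves height and width while doubling channels, and `embed_dim=96` therefore ends at 96 * 2**3 = 768 channels.

```python
# Expected backbone output for the (8, 256, 256, 3) input used in `test_call`.
depth, height, width = 8, 256, 256
embed_dim, patch_size, num_stages = 96, (2, 4, 4), 4

depth //= patch_size[0]   # 4
height //= patch_size[1]  # 64
width //= patch_size[2]   # 64
for _ in range(num_stages - 1):  # patch merging between consecutive stages
    height //= 2
    width //= 2

print((depth, height, width, embed_dim * 2 ** (num_stages - 1)))  # (4, 8, 8, 768)
```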
+
+import os
+
+import keras
+import numpy as np
+import pytest
+from keras import ops
+
+from keras_nlp.src.models.video_swin.video_swin_backbone import (
+    VideoSwinBackbone,
+)
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class VideoSwinBackboneTest(TestCase):
+
+    @pytest.mark.large
+    def test_call(self):
+        model = VideoSwinBackbone(
+            include_rescaling=True, image_shape=(8, 256, 256, 3)
+        )
+        x = np.ones((1, 8, 256, 256, 3))
+        x_out = ops.convert_to_numpy(model(x))
+        num_parameters = sum(
+            np.prod(tuple(x.shape)) for x in model.trainable_variables
+        )
+        self.assertEqual(x_out.shape, (1, 4, 8, 8, 768))
+        self.assertEqual(num_parameters, 27_663_894)
+
+    @pytest.mark.extra_large
+    def test_save(self):
+        # saving test
+        model = VideoSwinBackbone(include_rescaling=False)
+        x = np.ones((1, 32, 224, 224, 3))
+        x_out = ops.convert_to_numpy(model(x))
+        path = os.path.join(self.get_temp_dir(), "model.keras")
+        model.save(path)
+        loaded_model = keras.saving.load_model(path)
+        x_out_loaded = ops.convert_to_numpy(loaded_model(x))
+        self.assertAllClose(x_out, x_out_loaded)
+
+    @pytest.mark.extra_large
+    def test_fit(self):
+        model = VideoSwinBackbone(include_rescaling=False)
+        x = np.ones((1, 32, 224, 224, 3))
+        y = np.zeros((1, 16, 7, 7, 768))
+        model.compile(optimizer="adam", loss="mse", metrics=["mse"])
+        model.fit(x, y, epochs=1)
+
+    @pytest.mark.extra_large
+    def test_can_run_in_mixed_precision(self):
+        keras.mixed_precision.set_global_policy("mixed_float16")
+        model = VideoSwinBackbone(
+            include_rescaling=False, image_shape=(8, 224, 224, 3)
+        )
+        x = np.ones((1, 8, 224, 224, 3))
+        y = np.zeros((1, 4, 7, 7, 768))
+        model.compile(optimizer="adam", loss="mse", metrics=["mse"])
+        model.fit(x, y, epochs=1)
+
+    @pytest.mark.extra_large
+    def test_can_run_on_gray_video(self):
+        model = VideoSwinBackbone(
+            include_rescaling=False,
+            image_shape=(96, 96, 96, 1),
+            window_size=[6, 6, 6],
+        )
+        x = np.ones((1, 96, 96, 96, 1))
+        y = np.zeros((1, 48, 3, 3, 768))
+        model.compile(optimizer="adam", loss="mse", metrics=["mse"])
+        model.fit(x, y, epochs=1)
diff --git a/keras_nlp/src/models/video_swin/video_swin_layers.py b/keras_nlp/src/models/video_swin/video_swin_layers.py
new file mode 100644
index 0000000000..20b4feaabf
--- /dev/null
+++ b/keras_nlp/src/models/video_swin/video_swin_layers.py
@@ -0,0 +1,1015 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+import numpy as np
+from keras import layers
+from keras import ops
+from keras import random
+
+
+def window_partition(x, window_size):
+    """Partitions a video tensor into non-overlapping windows of a specified size.
+
+    Args:
+        x: A tensor with shape (B, D, H, W, C), where:
+            - B: Batch size
+            - D: Number of frames (depth) in the video
+            - H: Height of the video frames
+            - W: Width of the video frames
+            - C: Number of channels in the video (e.g., RGB for color)
+        window_size: A tuple of ints of size 3 representing the window size
+            along each dimension (depth, height, width).
+ + Returns: + A tensor with shape (num_windows * B, window_size[0], window_size[1], window_size[2], C), + where each window from the video is a sub-tensor containing the specified + number of frames and the corresponding spatial window. + """ + + input_shape = ops.shape(x) + batch_size, depth, height, width, channel = ( + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + input_shape[4], + ) + + x = ops.reshape( + x, + [ + batch_size, + depth // window_size[0], + window_size[0], + height // window_size[1], + window_size[1], + width // window_size[2], + window_size[2], + channel, + ], + ) + + x = ops.transpose(x, [0, 1, 3, 5, 2, 4, 6, 7]) + windows = ops.reshape( + x, [-1, window_size[0] * window_size[1] * window_size[2], channel] + ) + + return windows + + +def window_reverse(windows, window_size, batch_size, depth, height, width): + """Reconstructs the original video tensor from its partitioned windows. + + This function assumes the windows were created using the `window_partition` function + with the same `window_size`. + + Args: + windows: A tensor with shape (num_windows * batch_size, window_size[0], + window_size[1], window_size[2], channels), where: + - num_windows: Number of windows created during partitioning + - channels: Number of channels in the video (same as in `window_partition`) + window_size: A tuple of ints of size 3 representing the window size used + during partitioning (same as in `window_partition`). + batch_size: Batch size of the original video tensor (same as in `window_partition`). + depth: Number of frames (depth) in the original video tensor (same as in `window_partition`). + height: Height of the video frames in the original tensor (same as in `window_partition`). + width: Width of the video frames in the original tensor (same as in `window_partition`). + + Returns: + A tensor with shape (batch_size, depth, height, width, channels), representing the + original video reconstructed from the provided windows. + """ + x = ops.reshape( + windows, + [ + batch_size, + depth // window_size[0], + height // window_size[1], + width // window_size[2], + window_size[0], + window_size[1], + window_size[2], + -1, + ], + ) + x = ops.transpose(x, [0, 1, 4, 2, 5, 3, 6, 7]) + x = ops.reshape(x, [batch_size, depth, height, width, -1]) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """Computes the appropriate window size and potentially shift size for Swin Transformer. + + This function implements the logic from the Swin Transformer paper by Ze Liu et al. + (https://arxiv.org/abs/2103.14030) to determine suitable window sizes + based on the input size and the provided base window size. + + Args: + x_size: A tuple of ints of size 3 representing the input size (depth, height, width) + of the data (e.g., video). + window_size: A tuple of ints of size 3 representing the base window size + (depth, height, width) to use for partitioning. + shift_size: A tuple of ints of size 3 (optional) representing the window + shifting size (depth, height, width) for shifted window processing + used in Swin Transformer. If not provided, only window size is computed. + + Returns: + A tuple or a pair of tuples: + - If `shift_size` is None, returns a single tuple representing the adjusted + window size that may be smaller than the provided `window_size` to ensure + it doesn't exceed the input size along any dimension. + - If `shift_size` is provided, returns a pair of tuples. 
The first tuple + represents the adjusted window size, and the second tuple represents the + adjusted shift size. The adjustments ensure both window size and shift size + do not exceed the corresponding dimensions in the input data. + """ + + use_window_size = list(window_size) + + if shift_size is not None: + use_shift_size = list(shift_size) + + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +def compute_mask(depth, height, width, window_size, shift_size): + """Computes an attention mask for a sliding window self-attention mechanism + used in Video Swin Transformer. + + This function creates a mask to indicate which windows can attend to each other + during the self-attention operation. It considers non-overlapping and potentially + shifted windows based on the provided window size and shift size. + + Args: + depth (int): Depth (number of frames) of the input video. + height (int): Height of the video frames. + width (int): Width of the video frames. + window_size (tuple[int]): Size of the sliding window in each dimension + (depth, height, width). + shift_size (tuple[int]): Size of the shifting step in each dimension + (depth, height, width). + + Returns: + A tensor of shape (batch_size, num_windows, num_windows), where: + - batch_size: Assumed to be 1 in this function. + - num_windows: Total number of windows covering the entire input based on + the formula: + (depth - window_size[0]) // shift_size[0] + 1) * + (height - window_size[1]) // shift_size[1] + 1) * + (width - window_size[2]) // shift_size[2] + 1) + Each element (attn_mask[i, j]) represents the attention weight between + window i and window j. A value of -100.0 indicates high negative attention + (preventing information flow), 0.0 indicates no mask effect. + """ + + img_mask = np.zeros((1, depth, height, width, 1)) + cnt = 0 + for d in ( + slice(-window_size[0]), + slice(-window_size[0], -shift_size[0]), + slice(-shift_size[0], None), + ): + for h in ( + slice(-window_size[1]), + slice(-window_size[1], -shift_size[1]), + slice(-shift_size[1], None), + ): + for w in ( + slice(-window_size[2]), + slice(-window_size[2], -shift_size[2]), + slice(-shift_size[2], None), + ): + img_mask[:, d, h, w, :] = cnt + cnt = cnt + 1 + mask_windows = window_partition(img_mask, window_size) + mask_windows = ops.squeeze(mask_windows, axis=-1) + attn_mask = ops.expand_dims(mask_windows, axis=1) - ops.expand_dims( + mask_windows, axis=2 + ) + attn_mask = ops.where(attn_mask != 0, -100.0, attn_mask) + attn_mask = ops.where(attn_mask == 0, 0.0, attn_mask) + return attn_mask + + +class MLP(keras.layers.Layer): + """A Multilayer perceptron(MLP) layer. + + Args: + hidden_dim (int): The number of units in the hidden layer. + output_dim (int): The number of units in the output layer. + drop_rate (float): Float between 0 and 1. Fraction of the + input units to drop. + activation (str): Activation to use in the hidden layers. + Default is `"gelu"`. 
+ + References: + - [Video Swin Transformer](https://arxiv.org/abs/2106.13230) + - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer) + """ + + def __init__( + self, hidden_dim, output_dim, drop_rate=0.0, activation="gelu", **kwargs + ): + super().__init__(**kwargs) + self.output_dim = output_dim + self.hidden_dim = hidden_dim + self._activation_identifier = activation + self.drop_rate = drop_rate + self.activation = layers.Activation(self._activation_identifier) + self.fc1 = layers.Dense(self.hidden_dim) + self.fc2 = layers.Dense(self.output_dim) + self.dropout = layers.Dropout(self.drop_rate) + + def build(self, input_shape): + self.fc1.build(input_shape) + self.fc2.build((*input_shape[:-1], self.hidden_dim)) + self.built = True + + def call(self, x, training=None): + x = self.fc1(x) + x = self.activation(x) + x = self.dropout(x, training=training) + x = self.fc2(x) + x = self.dropout(x, training=training) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "output_dim": self.output_dim, + "hidden_dim": self.hidden_dim, + "drop_rate": self.drop_rate, + "activation": self._activation_identifier, + } + ) + return config + + +class VideoSwinPatchingAndEmbedding(keras.Model): + """Video to Patch Embedding layer for Video Swin Transformer models. + + This layer performs the initial step in a Video Swin Transformer architecture by + partitioning the input video into 3D patches and embedding them into a vector + dimensional space. + + Args: + patch_size (int): Size of the patch along each dimension + (depth, height, width). Default: (2,4,4). + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (keras.layers, optional): Normalization layer. Default: None + + References: + - [Video Swin Transformer](https://arxiv.org/abs/2106.13230) + - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer) + """ + + def __init__( + self, patch_size=(2, 4, 4), embed_dim=96, norm_layer=None, **kwargs + ): + super().__init__(**kwargs) + self.patch_size = patch_size + self.embed_dim = embed_dim + self.norm_layer = norm_layer + + def __compute_padding(self, dim, patch_size): + pad_amount = patch_size - (dim % patch_size) + return [0, pad_amount if pad_amount != patch_size else 0] + + def build(self, input_shape): + self.pads = [ + [0, 0], + self.__compute_padding(input_shape[1], self.patch_size[0]), + self.__compute_padding(input_shape[2], self.patch_size[1]), + self.__compute_padding(input_shape[3], self.patch_size[2]), + [0, 0], + ] + + if self.norm_layer is not None: + self.norm = self.norm_layer( + axis=-1, epsilon=1e-5, name="embed_norm" + ) + self.norm.build((None, None, None, None, self.embed_dim)) + + self.proj = layers.Conv3D( + self.embed_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + name="embed_proj", + ) + self.proj.build((None, None, None, None, input_shape[-1])) + self.built = True + + def call(self, x): + x = ops.pad(x, self.pads) + x = self.proj(x) + + if self.norm_layer is not None: + x = self.norm(x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + } + ) + return config + + +class VideoSwinPatchMerging(keras.layers.Layer): + """Patch Merging Layer in Video Swin Transformer models. + + This layer performs a downsampling step by merging four neighboring patches + from the previous layer into a single patch in the output. 
It achieves this + by concatenation and linear projection. + + Args: + input_dim (int): Number of input channels in the feature maps. + norm_layer (keras.layers, optional): Normalization layer. + Default: LayerNormalization + + References: + - [Video Swin Transformer](https://arxiv.org/abs/2106.13230) + - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer) + """ + + def __init__(self, input_dim, norm_layer=None, **kwargs): + super().__init__(**kwargs) + self.input_dim = input_dim + self.norm_layer = norm_layer + + def build(self, input_shape): + batch_size, depth, height, width, channel = input_shape + self.reduction = layers.Dense(2 * self.input_dim, use_bias=False) + self.reduction.build( + (batch_size, depth, height // 2, width // 2, 4 * channel) + ) + + if self.norm_layer is not None: + self.norm = self.norm_layer(axis=-1, epsilon=1e-5) + self.norm.build( + (batch_size, depth, height // 2, width // 2, 4 * channel) + ) + + # compute padding if needed + self.pads = [ + [0, 0], + [0, 0], + [0, ops.mod(height, 2)], + [0, ops.mod(width, 2)], + [0, 0], + ] + self.built = True + + def call(self, x): + # padding if needed + x = ops.pad(x, self.pads) + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = ops.concatenate([x0, x1, x2, x3], axis=-1) # B D H/2 W/2 4*C + + if self.norm_layer is not None: + x = self.norm(x) + + x = self.reduction(x) + return x + + def compute_output_shape(self, input_shape): + batch_size, depth, height, width, _ = input_shape + return (batch_size, depth, height // 2, width // 2, 2 * self.input_dim) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + } + ) + return config + + +class VideoSwinWindowAttention(keras.Model): + """It tackles long-range video dependencies by splitting features into windows + and using relative position bias within each window for focused attention. + It supports both of shifted and non-shifted window. + + Args: + input_dim (int): The number of input channels in the feature maps. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop_rate (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. 
Default: 0.0 + + References: + - [Video Swin Transformer](https://arxiv.org/abs/2106.13230) + - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer) + """ + + def __init__( + self, + input_dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + **kwargs, + ): + super().__init__(**kwargs) + # variables + self.input_dim = input_dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = input_dim // num_heads + self.qk_scale = qk_scale + self.scale = qk_scale or head_dim**-0.5 + self.qkv_bias = qkv_bias + self.attn_drop_rate = attn_drop_rate + self.proj_drop_rate = proj_drop_rate + + def get_relative_position_index( + self, window_depth, window_height, window_width + ): + y_y, z_z, x_x = ops.meshgrid( + ops.arange(window_width), + ops.arange(window_depth), + ops.arange(window_height), + ) + coords = ops.stack([z_z, y_y, x_x], axis=0) + coords_flatten = ops.reshape(coords, [3, -1]) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) + relative_coords = ops.transpose(relative_coords, axes=[1, 2, 0]) + z_z = ( + (relative_coords[:, :, 0] + window_depth - 1) + * (2 * window_height - 1) + * (2 * window_width - 1) + ) + x_x = (relative_coords[:, :, 1] + window_height - 1) * ( + 2 * window_width - 1 + ) + y_y = relative_coords[:, :, 2] + window_width - 1 + relative_coords = ops.stack([z_z, x_x, y_y], axis=-1) + return ops.sum(relative_coords, axis=-1) + + def build(self, input_shape): + self.relative_position_bias_table = self.add_weight( + shape=( + (2 * self.window_size[0] - 1) + * (2 * self.window_size[1] - 1) + * (2 * self.window_size[2] - 1), + self.num_heads, + ), + initializer="zeros", + trainable=True, + name="relative_position_bias_table", + ) + self.relative_position_index = self.get_relative_position_index( + self.window_size[0], self.window_size[1], self.window_size[2] + ) + + # layers + self.qkv = layers.Dense(self.input_dim * 3, use_bias=self.qkv_bias) + self.attn_drop = layers.Dropout(self.attn_drop_rate) + self.proj = layers.Dense(self.input_dim) + self.proj_drop = layers.Dropout(self.proj_drop_rate) + self.qkv.build(input_shape) + self.proj.build(input_shape) + self.built = True + + def call(self, x, mask=None, training=None): + input_shape = ops.shape(x) + batch_size, depth, channel = ( + input_shape[0], + input_shape[1], + input_shape[2], + ) + + qkv = self.qkv(x) + qkv = ops.reshape( + qkv, + [batch_size, depth, 3, self.num_heads, channel // self.num_heads], + ) + qkv = ops.transpose(qkv, [2, 0, 3, 1, 4]) + q, k, v = ops.split(qkv, 3, axis=0) + q = ops.squeeze(q, axis=0) * self.scale + k = ops.squeeze(k, axis=0) + v = ops.squeeze(v, axis=0) + attn = ops.matmul(q, ops.transpose(k, [0, 1, 3, 2])) + + rel_pos_bias = ops.take( + self.relative_position_bias_table, + self.relative_position_index[:depth, :depth], + axis=0, + ) + rel_pos_bias = ops.reshape(rel_pos_bias, [depth, depth, -1]) + rel_pos_bias = ops.transpose(rel_pos_bias, [2, 0, 1]) + attn = attn + rel_pos_bias[None, ...] 
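+        # At this point `attn` has shape
+        # (batch_size * num_windows, num_heads, tokens, tokens), where
+        # tokens = window_depth * window_height * window_width; the learned
+        # relative position bias (num_heads, tokens, tokens) broadcasts over
+        # the leading batch/window axis.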
+ + if mask is not None: + mask_size = ops.shape(mask)[0] + mask = ops.cast(mask, dtype=attn.dtype) + attn = ( + ops.reshape( + attn, + [ + batch_size // mask_size, + mask_size, + self.num_heads, + depth, + depth, + ], + ) + + mask[:, None, :, :] + ) + attn = ops.reshape(attn, [-1, self.num_heads, depth, depth]) + + attn = keras.activations.softmax(attn, axis=-1) + attn = self.attn_drop(attn, training=training) + x = ops.matmul(attn, v) + x = ops.transpose(x, [0, 2, 1, 3]) + x = ops.reshape(x, [batch_size, depth, channel]) + x = self.proj(x) + x = self.proj_drop(x, training=training) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + "window_size": self.window_size, + "num_heads": self.num_heads, + "qk_scale": self.qk_scale, + "qkv_bias": self.qkv_bias, + "attn_drop_rate": self.attn_drop_rate, + "proj_drop_rate": self.proj_drop_rate, + } + ) + return config + + +class VideoSwinBasicLayer(keras.Model): + """A basic Video Swin Transformer layer for one stage. + + Args: + input_dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (keras.layers, optional): Normalization layer. Default: LayerNormalization + downsample (keras.layers | None, optional): Downsample layer at the end of the layer. 
Default: None + + References: + - [Video Swin Transformer](https://arxiv.org/abs/2106.13230) + - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer) + """ + + def __init__( + self, + input_dim, + depth, + num_heads, + window_size=(1, 7, 7), + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=None, + downsampling_layer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.shift_size = tuple([i // 2 for i in window_size]) + self.depth = depth + self.qkv_bias = qkv_bias + self.qk_scale = qk_scale + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.norm_layer = norm_layer + self.downsampling_layer = downsampling_layer + + def __compute_dim_padded(self, input_dim, window_dim_size): + input_dim = ops.cast(input_dim, dtype="float32") + window_dim_size = ops.cast(window_dim_size, dtype="float32") + return ops.cast( + ops.ceil(input_dim / window_dim_size) * window_dim_size, "int32" + ) + + def build(self, input_shape): + self.window_size, self.shift_size = get_window_size( + input_shape[1:-1], self.window_size, self.shift_size + ) + self.depth_pad = self.__compute_dim_padded( + input_shape[1], self.window_size[0] + ) + self.height_pad = self.__compute_dim_padded( + input_shape[2], self.window_size[1] + ) + self.width_pad = self.__compute_dim_padded( + input_shape[3], self.window_size[2] + ) + self.attn_mask = compute_mask( + self.depth_pad, + self.height_pad, + self.width_pad, + self.window_size, + self.shift_size, + ) + + # build blocks + self.blocks = [ + VideoSwinTransformerBlock( + self.input_dim, + num_heads=self.num_heads, + window_size=self.window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop_rate=self.drop_rate, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=( + self.drop_path_rate[i] + if isinstance(self.drop_path_rate, list) + else self.drop_path_rate + ), + norm_layer=self.norm_layer, + ) + for i in range(self.depth) + ] + + if self.downsampling_layer is not None: + self.downsample = self.downsampling_layer( + input_dim=self.input_dim, norm_layer=self.norm_layer + ) + self.downsample.build(input_shape) + + for i in range(self.depth): + self.blocks[i].build(input_shape) + + self.built = True + + def compute_output_shape(self, input_shape): + if self.downsampling_layer is not None: + input_shape = self.downsample.compute_output_shape(input_shape) + return input_shape + + return input_shape + + def call(self, x, training=None): + input_shape = ops.shape(x) + batch_size, depth, height, width, channel = ( + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + input_shape[4], + ) + + for block in self.blocks: + x = block(x, self.attn_mask, training=training) + + x = ops.reshape(x, [batch_size, depth, height, width, channel]) + + if self.downsampling_layer is not None: + x = self.downsample(x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + "window_size": self.window_size, + "num_heads": self.num_heads, + "mlp_ratio": self.mlp_ratio, + "depth": self.depth, + "qkv_bias": self.qkv_bias, + "qk_scale": self.qk_scale, + "drop_rate": self.drop_rate, + "attn_drop_rate": self.attn_drop_rate, + 
"drop_path_rate": self.drop_path_rate, + } + ) + return config + + +class VideoSwinTransformerBlock(keras.Model): + """Video Swin Transformer Block. + + Args: + input_dim (int): Number of feature channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Local window size. Default: (2, 7, 7) + shift_size (tuple[int]): Shift size for SW-MSA. Default: (0, 0, 0) + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + Default: 4.0 + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. + Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + Default: None + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optionalc): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (keras.layers.Activation, optional): Activation layer. Default: gelu + norm_layer (keras.layers, optional): Normalization layer. + Default: LayerNormalization + + References: + - [Video Swin Transformer](https://arxiv.org/abs/2106.13230) + - [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer) + """ + + def __init__( + self, + input_dim, + num_heads, + window_size=(2, 7, 7), + shift_size=(0, 0, 0), + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + activation="gelu", + norm_layer=layers.LayerNormalization, + **kwargs, + ): + super().__init__(**kwargs) + # variables + self.input_dim = input_dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_scale = qk_scale + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.mlp_hidden_dim = int(input_dim * mlp_ratio) + self.norm_layer = norm_layer + self._activation_identifier = activation + + for i, (shift, window) in enumerate( + zip(self.shift_size, self.window_size) + ): + if not (0 <= shift < window): + raise ValueError( + f"shift_size[{i}] must be in the range 0 to less than " + f"window_size[{i}], but got shift_size[{i}]={shift} " + f"and window_size[{i}]={window}." + ) + + def build(self, input_shape): + self.window_size, self.shift_size = get_window_size( + input_shape[1:-1], self.window_size, self.shift_size + ) + self.apply_cyclic_shift = any(i > 0 for i in self.shift_size) + + # layers + self.drop_path = ( + DropPath(self.drop_path_rate) + if self.drop_path_rate > 0.0 + else layers.Identity() + ) + + self.norm1 = self.norm_layer(axis=-1, epsilon=1e-05) + self.norm1.build(input_shape) + + self.attn = VideoSwinWindowAttention( + self.input_dim, + window_size=self.window_size, + num_heads=self.num_heads, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + attn_drop_rate=self.attn_drop_rate, + proj_drop_rate=self.drop_rate, + ) + self.attn.build((None, None, self.input_dim)) + + self.norm2 = self.norm_layer(axis=-1, epsilon=1e-05) + self.norm2.build((*input_shape[:-1], self.input_dim)) + + self.mlp = MLP( + output_dim=self.input_dim, + hidden_dim=self.mlp_hidden_dim, + activation=self._activation_identifier, + drop_rate=self.drop_rate, + ) + self.mlp.build((*input_shape[:-1], self.input_dim)) + + # compute padding if needed. + # pad input feature maps to multiples of window size. 
+ _, depth, height, width, _ = input_shape + pad_l = pad_t = pad_d0 = 0 + self.pad_d1 = ops.mod(-depth + self.window_size[0], self.window_size[0]) + self.pad_b = ops.mod(-height + self.window_size[1], self.window_size[1]) + self.pad_r = ops.mod(-width + self.window_size[2], self.window_size[2]) + self.pads = [ + [0, 0], + [pad_d0, self.pad_d1], + [pad_t, self.pad_b], + [pad_l, self.pad_r], + [0, 0], + ] + self.apply_pad = any( + value > 0 for value in (self.pad_d1, self.pad_r, self.pad_b) + ) + self.built = True + + def first_forward(self, x, mask_matrix, training): + input_shape = ops.shape(x) + batch_size, depth, height, width, _ = ( + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + input_shape[4], + ) + x = self.norm1(x) + + # apply padding if needed. + x = ops.pad(x, self.pads) + + input_shape = ops.shape(x) + depth_pad, height_pad, width_pad = ( + input_shape[1], + input_shape[2], + input_shape[3], + ) + + # cyclic shift + if self.apply_cyclic_shift: + shifted_x = ops.roll( + x, + shift=( + -self.shift_size[0], + -self.shift_size[1], + -self.shift_size[2], + ), + axis=(1, 2, 3), + ) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) + + # get attentions params + attn_windows = self.attn(x_windows, mask=attn_mask, training=training) + + # reverse the swin windows + shifted_x = window_reverse( + attn_windows, + self.window_size, + batch_size, + depth_pad, + height_pad, + width_pad, + ) + + # reverse cyclic shift + if self.apply_cyclic_shift: + x = ops.roll( + shifted_x, + shift=( + self.shift_size[0], + self.shift_size[1], + self.shift_size[2], + ), + axis=(1, 2, 3), + ) + else: + x = shifted_x + + # pad if required + if self.apply_pad: + return x[:, :depth, :height, :width, :] + + return x + + def second_forward(self, x, training): + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x, training=training) + return x + + def call(self, x, mask_matrix=None, training=None): + shortcut = x + x = self.first_forward(x, mask_matrix, training) + x = shortcut + self.drop_path(x) + x = x + self.second_forward(x, training) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + "window_size": self.num_heads, + "num_heads": self.window_size, + "shift_size": self.shift_size, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self.qkv_bias, + "qk_scale": self.qk_scale, + "drop_rate": self.drop_rate, + "attn_drop_rate": self.attn_drop_rate, + "drop_path_rate": self.drop_path_rate, + "mlp_hidden_dim": self.mlp_hidden_dim, + "activation": self._activation_identifier, + } + ) + return config + + +class DropPath(keras.layers.Layer): + """ + Implements the DropPath layer. DropPath randomly drops samples during + training with a probability of `rate`. Note that this layer drops individual + samples within a batch and not the entire batch. DropPath randomly drops + some individual samples from a batch, whereas StochasticDepth + randomly drops the entire batch. + + References: + - [FractalNet](https://arxiv.org/abs/1605.07648v4). + - [rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models/blob/7c67d6aca992f039eece0af5f7c29a43d48c00e4/timm/models/layers/drop.py#L135) + + Args: + rate: float, the probability of the residual branch being dropped. + seed: (Optional) integer. Used to create a random seed. 
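+
+    Example (an illustrative sketch of the intended behavior):
+    ```python
+    drop_path = DropPath(rate=0.1)
+    x = keras.random.normal((8, 197, 768))
+    # In training, roughly `rate` of the samples in the batch are zeroed and
+    # the survivors are rescaled by 1 / (1 - rate); in inference the input is
+    # returned unchanged.
+    y = drop_path(x, training=True)
+    ```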
+ + """ + + def __init__(self, rate=0.5, seed=None, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self._seed_val = seed + self.seed = random.SeedGenerator(seed=seed) + + def call(self, x, training=None): + if self.rate == 0.0 or not training: + return x + else: + batch_size = x.shape[0] or ops.shape(x)[0] + drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1) + drop_map = ops.cast( + random.uniform(drop_map_shape, seed=self.seed) > self.rate, + x.dtype, + ) + x = x / (1.0 - self.rate) + x = x * drop_map + return x + + def get_config(self): + config = super().get_config() + config.update({"rate": self.rate, "seed": self._seed_val}) + return config diff --git a/keras_nlp/src/models/video_swin/video_swin_layers_test.py b/keras_nlp/src/models/video_swin/video_swin_layers_test.py new file mode 100644 index 0000000000..c957cf0444 --- /dev/null +++ b/keras_nlp/src/models/video_swin/video_swin_layers_test.py @@ -0,0 +1,96 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras import ops + +from keras_nlp.src.models.video_swin.video_swin_layers import ( + VideoSwinPatchingAndEmbedding, +) +from keras_nlp.src.models.video_swin.video_swin_layers import ( + VideoSwinPatchMerging, +) +from keras_nlp.src.models.video_swin.video_swin_layers import ( + VideoSwinWindowAttention, +) +from keras_nlp.src.tests.test_case import TestCase + + +class TestVideoSwinPatchingAndEmbedding(TestCase): + def test_patch_embedding_compute_output_shape(self): + patch_embedding_model = VideoSwinPatchingAndEmbedding( + patch_size=(2, 4, 4), embed_dim=96, norm_layer=None + ) + input_array = ops.ones(shape=(1, 16, 32, 32, 3)) + output_shape = patch_embedding_model(input_array).shape + expected_output_shape = (1, 8, 8, 8, 96) + self.assertEqual(output_shape, expected_output_shape) + + def test_patch_embedding_get_config(self): + patch_embedding_model = VideoSwinPatchingAndEmbedding( + patch_size=(4, 4, 4), embed_dim=96 + ) + config = patch_embedding_model.get_config() + assert isinstance(config, dict) + assert config["patch_size"] == (4, 4, 4) + assert config["embed_dim"] == 96 + + +class TestVideoSwinWindowAttention(TestCase): + + def setUp(self): + self.window_attention_model = VideoSwinWindowAttention( + input_dim=32, + window_size=(2, 4, 4), + num_heads=8, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0.1, + proj_drop_rate=0.1, + ) + + def test_window_attention_output_shape(self): + input_shape = (2, 16, 32) + input_array = ops.ones(input_shape) + output_shape = self.window_attention_model(input_array).shape + expected_output_shape = input_shape + self.assertEqual(output_shape, expected_output_shape) + + def test_window_attention_get_config(self): + config = self.window_attention_model.get_config() + # Add assertions based on the specific requirements + assert isinstance(config, dict) + assert config["window_size"] == (2, 4, 4) + assert config["num_heads"] == 8 + assert config["qkv_bias"] is True + assert config["qk_scale"] is None + assert 
config["attn_drop_rate"] == 0.1 + assert config["proj_drop_rate"] == 0.1 + + +class TestVideoSwinPatchMerging(TestCase): + def setUp(self): + self.patch_merging = VideoSwinPatchMerging(input_dim=32) + + def test_output_shape(self): + input_shape = (2, 4, 32, 32, 3) + input_tensor = ops.ones(input_shape) + output_shape = self.patch_merging(input_tensor).shape + expected_shape = ( + input_shape[0], + input_shape[1], + input_shape[2] // 2, + input_shape[3] // 2, + 2 * 32, + ) + self.assertEqual(output_shape, expected_shape) diff --git a/keras_nlp/src/models/video_swin/video_swin_video_classifier.py b/keras_nlp/src/models/video_swin/video_swin_video_classifier.py new file mode 100644 index 0000000000..9a1904c10c --- /dev/null +++ b/keras_nlp/src/models/video_swin/video_swin_video_classifier.py @@ -0,0 +1,100 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.video_swin.video_swin_backbone import VideoSwinBackbone + +@keras_nlp_export("keras_nlp.models.VideoSwinVideoClassifier") +class VideoSwinVideoClassifier(ImageClassifier): + """VideoSwin Video classifier task model. + + Args: + backbone: A `keras_nlp.models.VideoSwinBackbone` instance. + num_classes: int. The number of classes to predict. + activation: `None`, str or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `"softmax"`. + head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Custom backbone. 
+    ```python
+    videos = np.ones((2, 32, 224, 224, 3), dtype="float32")
+    labels = [0, 3]
+    backbone = keras_nlp.models.VideoSwinBackbone(
+        image_shape=(32, 224, 224, 3),
+        include_rescaling=False,
+    )
+    classifier = keras_nlp.models.VideoSwinVideoClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=videos, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = VideoSwinBackbone
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation="softmax",
+        head_dtype=None,
+        preprocessor=None,  # adding this dummy arg for saved model test
+        # TODO: once preprocessor flow is figured out, this needs to be updated
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+
+        # === Layers ===
+        self.backbone = backbone
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = activation
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
diff --git a/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py b/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py
new file mode 100644
index 0000000000..32628225fd
--- /dev/null
+++ b/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest +from keras import ops + +from keras_nlp.src.models.video_swin.video_swin_backbone import ( + VideoSwinBackbone, +) + +from keras_nlp.src.models.video_swin.video_swin_video_classifier import ( + VideoSwinVideoClassifier, +) +from keras_nlp.src.tests.test_case import TestCase + + +class VideoSwinVideoClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 8, 256, 256, 3)) + self.labels = [0, 3] + self.backbone = VideoSwinBackbone( + image_shape=(8, 256, 256, 3), + include_rescaling=True, + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=VideoSwinVideoClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + def test_head_dtype(self): + model = VideoSwinVideoClassifier( + **self.init_kwargs, head_dtype="bfloat16" + ) + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VideoSwinVideoClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From c2a855040e7bcf10a36e750b1b71e0b69a660206 Mon Sep 17 00:00:00 2001 From: ushareng Date: Wed, 4 Sep 2024 20:37:50 +0530 Subject: [PATCH 7/7] code format fixed --- .../src/models/video_swin/video_swin_video_classifier.py | 5 ++++- .../models/video_swin/video_swin_video_classifier_test.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_nlp/src/models/video_swin/video_swin_video_classifier.py b/keras_nlp/src/models/video_swin/video_swin_video_classifier.py index 9a1904c10c..6f5d7ea05e 100644 --- a/keras_nlp/src/models/video_swin/video_swin_video_classifier.py +++ b/keras_nlp/src/models/video_swin/video_swin_video_classifier.py @@ -15,7 +15,10 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.image_classifier import ImageClassifier -from keras_nlp.src.models.video_swin.video_swin_backbone import VideoSwinBackbone +from keras_nlp.src.models.video_swin.video_swin_backbone import ( + VideoSwinBackbone, +) + @keras_nlp_export("keras_nlp.models.VideoSwinVideoClassifier") class VideoSwinVideoClassifier(ImageClassifier): diff --git a/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py b/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py index 32628225fd..4bfccb4343 100644 --- a/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py +++ b/keras_nlp/src/models/video_swin/video_swin_video_classifier_test.py @@ -17,7 +17,6 @@ from keras_nlp.src.models.video_swin.video_swin_backbone import ( VideoSwinBackbone, ) - from keras_nlp.src.models.video_swin.video_swin_video_classifier import ( VideoSwinVideoClassifier, )
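Finally, a small round-trip sketch of the windowing utilities introduced in `video_swin_layers.py` (illustrative only; it assumes the module is importable exactly as in the tests above): partitioning a feature map into `(8, 7, 7)` windows and reversing the operation should reproduce the original tensor whenever the spatial dimensions divide evenly by the window size.

```python
import numpy as np
from keras import ops

from keras_nlp.src.models.video_swin.video_swin_layers import (
    window_partition,
    window_reverse,
)

x = np.random.uniform(size=(1, 8, 56, 56, 96)).astype("float32")
windows = window_partition(x, (8, 7, 7))
print(windows.shape)  # (64, 392, 96): 64 windows of 8 * 7 * 7 = 392 tokens each

y = window_reverse(windows, (8, 7, 7), 1, 8, 56, 56)
print(np.allclose(x, ops.convert_to_numpy(y)))  # True
```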