From 62059abb08a752dc0e3edaaae05b269e5f83c9c4 Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Fri, 25 Mar 2022 16:40:21 +0000 Subject: [PATCH 1/9] Implement BatchNorm with BNRS in the library --- examples/vision/mamlpp/maml++_miniimagenet.py | 2 +- learn2learn/vision/models/__init__.py | 15 +++ learn2learn/vision/models/bnrs.py | 104 ++++++++++++++++++ .../vision/models}/cnn4_bnrs.py | 96 +--------------- 4 files changed, 125 insertions(+), 92 deletions(-) create mode 100644 learn2learn/vision/models/bnrs.py rename {examples/vision/mamlpp => learn2learn/vision/models}/cnn4_bnrs.py (65%) diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py index 78085bf9..db2c073c 100755 --- a/examples/vision/mamlpp/maml++_miniimagenet.py +++ b/examples/vision/mamlpp/maml++_miniimagenet.py @@ -20,7 +20,7 @@ from typing import Tuple from tqdm import tqdm -from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS +from learn2learn.vision.models.cnn4_bnrs import CNN4_BNRS from examples.vision.mamlpp.MAMLpp import MAMLpp diff --git a/learn2learn/vision/models/__init__.py b/learn2learn/vision/models/__init__.py index 54cdeec4..da96669c 100644 --- a/learn2learn/vision/models/__init__.py +++ b/learn2learn/vision/models/__init__.py @@ -31,8 +31,17 @@ def forward(self, x): CNN4Backbone, ) +from .cnn4_bnrs import ( + LinearBlock_BNRS, + ConvBlock_BNRS, + ConvBase_BNRS, + CNN4Backbone_BNRS, + CNN4_BNRS, +) + from .resnet12 import ResNet12, ResNet12Backbone from .wrn28 import WRN28, WRN28Backbone +from .bnrs import BatchNorm_BNRS __all__ = [ 'get_pretrained_backbone', @@ -49,6 +58,12 @@ def forward(self, x): 'ResNet12Backbone', 'WRN28', 'WRN28Backbone', + 'BatchNorm_BNRS', + 'LinearBlock_BNRS', + 'ConvBlock_BNRS', + 'ConvBase_BNRS', + 'CNN4Backbone_BNRS', + 'CNN4_BNRS', ] _BACKBONE_URLS = { diff --git a/learn2learn/vision/models/bnrs.py b/learn2learn/vision/models/bnrs.py new file mode 100644 index 00000000..28abafb7 --- /dev/null +++ b/learn2learn/vision/models/bnrs.py @@ -0,0 +1,104 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# + +""" +BatchNorm layer augmented with Per-Step Batch Normalisation Running Statistics and Per-Step Batch +Normalisation Weights and Biases, as proposed in MAML++ by Antobiou et al. +""" + +import torch +import torch.nn.functional as F + +from copy import deepcopy +from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ + + +class BatchNorm_BNRS(torch.nn.Module): + """ + An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running + Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in + MAML++ by Antoniou et al. It is adapted from the original Pytorch implementation at + https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch, + with heavy refactoring and a bug fix + (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42). + """ + + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + meta_batch_norm=True, + adaptation_steps: int = 1, + ): + super(BatchNorm_BNRS, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.meta_batch_norm = meta_batch_norm + self.num_features = num_features + self.running_mean = torch.nn.Parameter( + torch.zeros(adaptation_steps, num_features), requires_grad=False + ) + self.running_var = torch.nn.Parameter( + torch.ones(adaptation_steps, num_features), requires_grad=False + ) + self.bias = torch.nn.Parameter( + torch.zeros(adaptation_steps, num_features), requires_grad=True + ) + self.weight = torch.nn.Parameter( + torch.ones(adaptation_steps, num_features), requires_grad=True + ) + self.backup_running_mean = torch.zeros(self.running_mean.shape) + self.backup_running_var = torch.ones(self.running_var.shape) + self.momentum = momentum + + def forward( + self, + input, + step, + ): + """ + :param input: input data batch, size either can be any. + :param step: The current inner loop step being taken. This is used when to learn per step params and + collecting per step batch statistics. + :return: The result of the batch norm operation. + """ + assert ( + step < self.running_mean.shape[0] + ), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!" + return F.batch_norm( + input, + self.running_mean[step], + self.running_var[step], + self.weight[step], + self.bias[step], + training=True, + momentum=self.momentum, + eps=self.eps, + ) + + def backup_stats(self): + self.backup_running_mean.data = deepcopy(self.running_mean.data) + self.backup_running_var.data = deepcopy(self.running_var.data) + + def restore_backup_stats(self): + """ + Resets batch statistics to their backup values which are collected after each forward pass. + """ + self.running_mean = torch.nn.Parameter( + self.backup_running_mean, requires_grad=False + ) + self.running_var = torch.nn.Parameter( + self.backup_running_var, requires_grad=False + ) + + def extra_repr(self): + return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format( + **self.__dict__ + ) + + diff --git a/examples/vision/mamlpp/cnn4_bnrs.py b/learn2learn/vision/models/cnn4_bnrs.py similarity index 65% rename from examples/vision/mamlpp/cnn4_bnrs.py rename to learn2learn/vision/models/cnn4_bnrs.py index 28f9666f..8d36a304 100644 --- a/examples/vision/mamlpp/cnn4_bnrs.py +++ b/learn2learn/vision/models/cnn4_bnrs.py @@ -11,101 +11,15 @@ import torch.nn.functional as F from copy import deepcopy +from learn2learn.vision.models.bnrs import BatchNorm_BNRS from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ -class MetaBatchNormLayer(torch.nn.Module): - """ - An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running - Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in - MAML++ by Antoniou et al. It is adapted from the original Pytorch implementation at - https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch, - with heavy refactoring and a bug fix - (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42). - """ - - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - meta_batch_norm=True, - adaptation_steps: int = 1, - ): - super(MetaBatchNormLayer, self).__init__() - self.num_features = num_features - self.eps = eps - self.affine = affine - self.meta_batch_norm = meta_batch_norm - self.num_features = num_features - self.running_mean = torch.nn.Parameter( - torch.zeros(adaptation_steps, num_features), requires_grad=False - ) - self.running_var = torch.nn.Parameter( - torch.ones(adaptation_steps, num_features), requires_grad=False - ) - self.bias = torch.nn.Parameter( - torch.zeros(adaptation_steps, num_features), requires_grad=True - ) - self.weight = torch.nn.Parameter( - torch.ones(adaptation_steps, num_features), requires_grad=True - ) - self.backup_running_mean = torch.zeros(self.running_mean.shape) - self.backup_running_var = torch.ones(self.running_var.shape) - self.momentum = momentum - - def forward( - self, - input, - step, - ): - """ - :param input: input data batch, size either can be any. - :param step: The current inner loop step being taken. This is used when to learn per step params and - collecting per step batch statistics. - :return: The result of the batch norm operation. - """ - assert ( - step < self.running_mean.shape[0] - ), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!" - return F.batch_norm( - input, - self.running_mean[step], - self.running_var[step], - self.weight[step], - self.bias[step], - training=True, - momentum=self.momentum, - eps=self.eps, - ) - - def backup_stats(self): - self.backup_running_mean.data = deepcopy(self.running_mean.data) - self.backup_running_var.data = deepcopy(self.running_var.data) - - def restore_backup_stats(self): - """ - Resets batch statistics to their backup values which are collected after each forward pass. - """ - self.running_mean = torch.nn.Parameter( - self.backup_running_mean, requires_grad=False - ) - self.running_var = torch.nn.Parameter( - self.backup_running_var, requires_grad=False - ) - - def extra_repr(self): - return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format( - **self.__dict__ - ) - - class LinearBlock_BNRS(torch.nn.Module): def __init__(self, input_size, output_size, adaptation_steps): super(LinearBlock_BNRS, self).__init__() self.relu = torch.nn.ReLU() - self.normalize = MetaBatchNormLayer( + self.normalize = BatchNorm_BNRS( output_size, affine=True, momentum=0.999, @@ -143,7 +57,7 @@ def __init__( stride = (1, 1) else: self.max_pool = lambda x: x - self.normalize = MetaBatchNormLayer( + self.normalize = BatchNorm_BNRS( out_channels, affine=True, adaptation_steps=adaptation_steps, @@ -304,7 +218,7 @@ def backup_stats(self): Backup stored batch statistics before running a validation epoch. """ for layer in self.features.modules(): - if type(layer) is MetaBatchNormLayer: + if type(layer) is BatchNorm_BNRS: layer.backup_stats() def restore_backup_stats(self): @@ -312,7 +226,7 @@ def restore_backup_stats(self): Reset stored batch statistics from the stored backup. """ for layer in self.features.modules(): - if type(layer) is MetaBatchNormLayer: + if type(layer) is BatchNorm_BNRS: layer.restore_backup_stats() def forward(self, x, step): From 77eff55153cfda4bbbbe0ec74bcb4b4fe23f225e Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Fri, 25 Mar 2022 16:46:13 +0000 Subject: [PATCH 2/9] List contribution in changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca845a18..c7f6407e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -* New vision example: MAML++. (@[DubiousCactus](https://github.com/DubiousCactus)) +* New BatchNorm layer with per-step running statistics and weights & biases from MAML++. (@[Théo Morales](https://github.com/DubiousCactus)) +* New vision example: MAML++. (@[Théo Morales](https://github.com/DubiousCactus)) * Add tutorial: "Demystifying Task Transforms", ([Varad Pimpalkhute](https://github.com/nightlessbaron/)) ### Changed From 365b30f3a5d4cd95ce78d991d343078a6e8808c8 Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Sat, 26 Mar 2022 14:20:57 +0000 Subject: [PATCH 3/9] Fix unused imports --- learn2learn/vision/models/bnrs.py | 3 --- learn2learn/vision/models/cnn4_bnrs.py | 17 ++++++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/learn2learn/vision/models/bnrs.py b/learn2learn/vision/models/bnrs.py index 28abafb7..44731c4b 100644 --- a/learn2learn/vision/models/bnrs.py +++ b/learn2learn/vision/models/bnrs.py @@ -12,7 +12,6 @@ import torch.nn.functional as F from copy import deepcopy -from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ class BatchNorm_BNRS(torch.nn.Module): @@ -100,5 +99,3 @@ def extra_repr(self): return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format( **self.__dict__ ) - - diff --git a/learn2learn/vision/models/cnn4_bnrs.py b/learn2learn/vision/models/cnn4_bnrs.py index 8d36a304..3f2bf181 100644 --- a/learn2learn/vision/models/cnn4_bnrs.py +++ b/learn2learn/vision/models/cnn4_bnrs.py @@ -8,9 +8,7 @@ """ import torch -import torch.nn.functional as F -from copy import deepcopy from learn2learn.vision.models.bnrs import BatchNorm_BNRS from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ @@ -92,8 +90,13 @@ class ConvBase_BNRS(torch.nn.Sequential): # MiniImagenet: hidden=32, channels=3, max_pool def __init__( - self, hidden=64, channels=1, max_pool=False, layers=4, max_pool_factor=1.0, - adaptation_steps=1 + self, + hidden=64, + channels=1, + max_pool=False, + layers=4, + max_pool_factor=1.0, + adaptation_steps=1, ): core = [ ConvBlock_BNRS( @@ -102,7 +105,7 @@ def __init__( (3, 3), max_pool=max_pool, max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps + adaptation_steps=adaptation_steps, ), ] for _ in range(layers - 1): @@ -113,7 +116,7 @@ def __init__( kernel_size=(3, 3), max_pool=max_pool, max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps + adaptation_steps=adaptation_steps, ) ) super(ConvBase_BNRS, self).__init__(*core) @@ -142,7 +145,7 @@ def __init__( channels=channels, max_pool=max_pool, max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps + adaptation_steps=adaptation_steps, ) def forward(self, x, step): From 11b7cb24cfd58a1a0721f53426c90429b4e98cef Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Thu, 31 Mar 2022 12:04:21 +0100 Subject: [PATCH 4/9] Rename BNRS into MetaBatchNorm --- learn2learn/nn/__init__.py | 1 + .../models/bnrs.py => nn/metabatchnorm.py} | 4 +- learn2learn/vision/models/__init__.py | 24 +++++------ .../{cnn4_bnrs.py => cnn4_metabatchnorm.py} | 41 ++++++++++--------- 4 files changed, 35 insertions(+), 35 deletions(-) rename learn2learn/{vision/models/bnrs.py => nn/metabatchnorm.py} (97%) rename learn2learn/vision/models/{cnn4_bnrs.py => cnn4_metabatchnorm.py} (84%) diff --git a/learn2learn/nn/__init__.py b/learn2learn/nn/__init__.py index 54ea3de3..cf8638b5 100644 --- a/learn2learn/nn/__init__.py +++ b/learn2learn/nn/__init__.py @@ -8,3 +8,4 @@ from .misc import * from .protonet import PrototypicalClassifier from .metaoptnet import SVClassifier +from .metabatchnorm import MetaBatchNorm diff --git a/learn2learn/vision/models/bnrs.py b/learn2learn/nn/metabatchnorm.py similarity index 97% rename from learn2learn/vision/models/bnrs.py rename to learn2learn/nn/metabatchnorm.py index 44731c4b..e94fb6ec 100644 --- a/learn2learn/vision/models/bnrs.py +++ b/learn2learn/nn/metabatchnorm.py @@ -14,7 +14,7 @@ from copy import deepcopy -class BatchNorm_BNRS(torch.nn.Module): +class MetaBatchNorm(torch.nn.Module): """ An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in @@ -33,7 +33,7 @@ def __init__( meta_batch_norm=True, adaptation_steps: int = 1, ): - super(BatchNorm_BNRS, self).__init__() + super(MetaBatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.affine = affine diff --git a/learn2learn/vision/models/__init__.py b/learn2learn/vision/models/__init__.py index da96669c..7c3edb35 100644 --- a/learn2learn/vision/models/__init__.py +++ b/learn2learn/vision/models/__init__.py @@ -31,17 +31,16 @@ def forward(self, x): CNN4Backbone, ) -from .cnn4_bnrs import ( - LinearBlock_BNRS, - ConvBlock_BNRS, - ConvBase_BNRS, - CNN4Backbone_BNRS, - CNN4_BNRS, +from .cnn4_metabatchnorm import ( + LinearBlock_MetaBatchNorm, + ConvBlock_MetaBatchNorm, + ConvBase_MetaBatchNorm, + CNN4Backbone_MetaBatchNorm, + CNN4_MetaBatchNorm, ) from .resnet12 import ResNet12, ResNet12Backbone from .wrn28 import WRN28, WRN28Backbone -from .bnrs import BatchNorm_BNRS __all__ = [ 'get_pretrained_backbone', @@ -58,12 +57,11 @@ def forward(self, x): 'ResNet12Backbone', 'WRN28', 'WRN28Backbone', - 'BatchNorm_BNRS', - 'LinearBlock_BNRS', - 'ConvBlock_BNRS', - 'ConvBase_BNRS', - 'CNN4Backbone_BNRS', - 'CNN4_BNRS', + 'LinearBlock_MetaBatchNorm', + 'ConvBlock_MetaBatchNorm', + 'ConvBase_MetaBatchNorm', + 'CNN4Backbone_MetaBatchNorm', + 'CNN4_MetaBatchNorm', ] _BACKBONE_URLS = { diff --git a/learn2learn/vision/models/cnn4_bnrs.py b/learn2learn/vision/models/cnn4_metabatchnorm.py similarity index 84% rename from learn2learn/vision/models/cnn4_bnrs.py rename to learn2learn/vision/models/cnn4_metabatchnorm.py index 3f2bf181..eb05a2db 100644 --- a/learn2learn/vision/models/cnn4_bnrs.py +++ b/learn2learn/vision/models/cnn4_metabatchnorm.py @@ -4,20 +4,21 @@ # """ -CNN4 extended with Batch-Norm Running Statistics. +CNN4 using a MetaBatchNorm layer allowing to accumulate per-step running statistics and use +per-step bias and variance parameters. """ import torch -from learn2learn.vision.models.bnrs import BatchNorm_BNRS +from learn2learn.nn.metabatchnorm import MetaBatchNorm from learn2learn.vision.models.cnn4 import maml_init_, fc_init_ -class LinearBlock_BNRS(torch.nn.Module): +class LinearBlock_MetaBatchNorm(torch.nn.Module): def __init__(self, input_size, output_size, adaptation_steps): - super(LinearBlock_BNRS, self).__init__() + super(LinearBlock_MetaBatchNorm, self).__init__() self.relu = torch.nn.ReLU() - self.normalize = BatchNorm_BNRS( + self.normalize = MetaBatchNorm( output_size, affine=True, momentum=0.999, @@ -34,7 +35,7 @@ def forward(self, x, step): return x -class ConvBlock_BNRS(torch.nn.Module): +class ConvBlock_MetaBatchNorm(torch.nn.Module): def __init__( self, in_channels, @@ -44,7 +45,7 @@ def __init__( max_pool_factor=1.0, adaptation_steps=1, ): - super(ConvBlock_BNRS, self).__init__() + super(ConvBlock_MetaBatchNorm, self).__init__() stride = (int(2 * max_pool_factor), int(2 * max_pool_factor)) if max_pool: self.max_pool = torch.nn.MaxPool2d( @@ -55,7 +56,7 @@ def __init__( stride = (1, 1) else: self.max_pool = lambda x: x - self.normalize = BatchNorm_BNRS( + self.normalize = MetaBatchNorm( out_channels, affine=True, adaptation_steps=adaptation_steps, @@ -83,7 +84,7 @@ def forward(self, x, step): return x -class ConvBase_BNRS(torch.nn.Sequential): +class ConvBase_MetaBatchNorm(torch.nn.Sequential): # NOTE: # Omniglot: hidden=64, channels=1, no max_pool @@ -99,7 +100,7 @@ def __init__( adaptation_steps=1, ): core = [ - ConvBlock_BNRS( + ConvBlock_MetaBatchNorm( channels, hidden, (3, 3), @@ -110,7 +111,7 @@ def __init__( ] for _ in range(layers - 1): core.append( - ConvBlock_BNRS( + ConvBlock_MetaBatchNorm( hidden, hidden, kernel_size=(3, 3), @@ -119,7 +120,7 @@ def __init__( adaptation_steps=adaptation_steps, ) ) - super(ConvBase_BNRS, self).__init__(*core) + super(ConvBase_MetaBatchNorm, self).__init__(*core) def forward(self, x, step): for module in self: @@ -127,7 +128,7 @@ def forward(self, x, step): return x -class CNN4Backbone_BNRS(ConvBase_BNRS): +class CNN4Backbone_MetaBatchNorm(ConvBase_MetaBatchNorm): def __init__( self, hidden_size=64, @@ -139,7 +140,7 @@ def __init__( ): if max_pool_factor is None: max_pool_factor = 4 // layers - super(CNN4Backbone_BNRS, self).__init__( + super(CNN4Backbone_MetaBatchNorm, self).__init__( hidden=hidden_size, layers=layers, channels=channels, @@ -149,12 +150,12 @@ def __init__( ) def forward(self, x, step): - x = super(CNN4Backbone_BNRS, self).forward(x, step) + x = super(CNN4Backbone_MetaBatchNorm, self).forward(x, step) x = x.reshape(x.size(0), -1) return x -class CNN4_BNRS(torch.nn.Module): +class CNN4_MetaBatchNorm(torch.nn.Module): """ [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/cnn4.py) @@ -197,10 +198,10 @@ def __init__( embedding_size=None, adaptation_steps=1, ): - super(CNN4_BNRS, self).__init__() + super(CNN4_MetaBatchNorm, self).__init__() if embedding_size is None: embedding_size = 25 * hidden_size - self.features = CNN4Backbone_BNRS( + self.features = CNN4Backbone_MetaBatchNorm( hidden_size=hidden_size, channels=channels, max_pool=max_pool, @@ -221,7 +222,7 @@ def backup_stats(self): Backup stored batch statistics before running a validation epoch. """ for layer in self.features.modules(): - if type(layer) is BatchNorm_BNRS: + if type(layer) is MetaBatchNorm: layer.backup_stats() def restore_backup_stats(self): @@ -229,7 +230,7 @@ def restore_backup_stats(self): Reset stored batch statistics from the stored backup. """ for layer in self.features.modules(): - if type(layer) is BatchNorm_BNRS: + if type(layer) is MetaBatchNorm: layer.restore_backup_stats() def forward(self, x, step): From 331bbc9a2e59f08cb5bfa7cb78041537f2312a84 Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Thu, 31 Mar 2022 12:14:48 +0100 Subject: [PATCH 5/9] Update docstring of CNN4_MetaBatchNorm and example --- examples/vision/mamlpp/maml++_miniimagenet.py | 4 +-- learn2learn/nn/metabatchnorm.py | 2 +- .../vision/models/cnn4_metabatchnorm.py | 29 ++++++++++--------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py index db2c073c..9d7a6b52 100755 --- a/examples/vision/mamlpp/maml++_miniimagenet.py +++ b/examples/vision/mamlpp/maml++_miniimagenet.py @@ -20,7 +20,7 @@ from typing import Tuple from tqdm import tqdm -from learn2learn.vision.models.cnn4_bnrs import CNN4_BNRS +from learn2learn.vision.models.cnn4_metabatchnorm import CNN4_MetaBatchNorm from examples.vision.mamlpp.MAMLpp import MAMLpp @@ -72,7 +72,7 @@ def __init__( ) # Model - self._model = CNN4_BNRS(ways, adaptation_steps=steps) + self._model = CNN4_MetaBatchNorm(ways, steps) if self._use_cuda: self._model.cuda() diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py index e94fb6ec..ab800b8d 100644 --- a/learn2learn/nn/metabatchnorm.py +++ b/learn2learn/nn/metabatchnorm.py @@ -27,11 +27,11 @@ class MetaBatchNorm(torch.nn.Module): def __init__( self, num_features, + adaptation_steps, eps=1e-5, momentum=0.1, affine=True, meta_batch_norm=True, - adaptation_steps: int = 1, ): super(MetaBatchNorm, self).__init__() self.num_features = num_features diff --git a/learn2learn/vision/models/cnn4_metabatchnorm.py b/learn2learn/vision/models/cnn4_metabatchnorm.py index eb05a2db..865736a7 100644 --- a/learn2learn/vision/models/cnn4_metabatchnorm.py +++ b/learn2learn/vision/models/cnn4_metabatchnorm.py @@ -20,10 +20,10 @@ def __init__(self, input_size, output_size, adaptation_steps): self.relu = torch.nn.ReLU() self.normalize = MetaBatchNorm( output_size, + adaptation_steps, affine=True, momentum=0.999, eps=1e-3, - adaptation_steps=adaptation_steps, ) self.linear = torch.nn.Linear(input_size, output_size) fc_init_(self.linear) @@ -38,12 +38,12 @@ def forward(self, x, step): class ConvBlock_MetaBatchNorm(torch.nn.Module): def __init__( self, + adaptation_steps, in_channels, out_channels, kernel_size, max_pool=True, max_pool_factor=1.0, - adaptation_steps=1, ): super(ConvBlock_MetaBatchNorm, self).__init__() stride = (int(2 * max_pool_factor), int(2 * max_pool_factor)) @@ -58,8 +58,8 @@ def __init__( self.max_pool = lambda x: x self.normalize = MetaBatchNorm( out_channels, + adaptation_steps, affine=True, - adaptation_steps=adaptation_steps, # eps=1e-3, # momentum=0.999, ) @@ -92,32 +92,32 @@ class ConvBase_MetaBatchNorm(torch.nn.Sequential): def __init__( self, + adaptation_steps, hidden=64, channels=1, max_pool=False, layers=4, max_pool_factor=1.0, - adaptation_steps=1, ): core = [ ConvBlock_MetaBatchNorm( + adaptation_steps, channels, hidden, (3, 3), max_pool=max_pool, max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps, ), ] for _ in range(layers - 1): core.append( ConvBlock_MetaBatchNorm( + adaptation_steps, hidden, hidden, kernel_size=(3, 3), max_pool=max_pool, max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps, ) ) super(ConvBase_MetaBatchNorm, self).__init__(*core) @@ -131,22 +131,22 @@ def forward(self, x, step): class CNN4Backbone_MetaBatchNorm(ConvBase_MetaBatchNorm): def __init__( self, + adaptation_steps, hidden_size=64, layers=4, channels=3, max_pool=True, max_pool_factor=None, - adaptation_steps=1, ): if max_pool_factor is None: max_pool_factor = 4 // layers super(CNN4Backbone_MetaBatchNorm, self).__init__( + adaptation_steps, hidden=hidden_size, layers=layers, channels=channels, max_pool=max_pool, max_pool_factor=max_pool_factor, - adaptation_steps=adaptation_steps, ) def forward(self, x, step): @@ -162,19 +162,22 @@ class CNN4_MetaBatchNorm(torch.nn.Module): **Description** - The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle, 2017. + The convolutional network commonly used for MiniImagenet, as described by Ravi et Larochelle, + 2017, using the MetaBatchNorm layer proposed by Antoniou et al. 2019. This network assumes inputs of shapes (3, 84, 84). - Instantiate `CNN4Backbone` if you only need the feature extractor. + Instantiate `CNN4Backbone_MetaBatchNorm` if you only need the feature extractor. **References** 1. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR. + 2. Antoniou et al. 2019. “How to train your MAML.“ ICLR. **Arguments** * **output_size** (int) - The dimensionality of the network's output. + * **adaptation_steps** (int) - Number of inner-loop adaptation steps. * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden representation. * **layers** (int, *optional*, default=4) - The number of convolutional layers. * **channels** (int, *optional*, default=3) - The number of channels in input. @@ -184,30 +187,30 @@ class CNN4_MetaBatchNorm(torch.nn.Module): **Example** ~~~python - model = CNN4(output_size=20, hidden_size=128, layers=3) + model = CNN4(output_size=20, adaptation_steps=5, hidden_size=128, layers=3) ~~~ """ def __init__( self, output_size, + adaptation_steps, hidden_size=64, layers=4, channels=3, max_pool=True, embedding_size=None, - adaptation_steps=1, ): super(CNN4_MetaBatchNorm, self).__init__() if embedding_size is None: embedding_size = 25 * hidden_size self.features = CNN4Backbone_MetaBatchNorm( + adaptation_steps, hidden_size=hidden_size, channels=channels, max_pool=max_pool, layers=layers, max_pool_factor=4 // layers, - adaptation_steps=adaptation_steps, ) self.classifier = torch.nn.Linear( embedding_size, From d69dd95c32156eb4101571a91c936c0ed3ae8c2e Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Thu, 31 Mar 2022 12:24:42 +0100 Subject: [PATCH 6/9] Add docstring to MetaBatchNorm class --- learn2learn/nn/metabatchnorm.py | 39 +++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py index ab800b8d..d165e24f 100644 --- a/learn2learn/nn/metabatchnorm.py +++ b/learn2learn/nn/metabatchnorm.py @@ -16,12 +16,38 @@ class MetaBatchNorm(torch.nn.Module): """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/metabatchnorm.py) + + **Description** + An extension of Pytorch's BatchNorm layer, with the Per-Step Batch Normalisation Running Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in - MAML++ by Antoniou et al. It is adapted from the original Pytorch implementation at + "How to train your MAML". + It is adapted from the original Pytorch implementation at https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch, with heavy refactoring and a bug fix (https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42). + + **Arguments** + + * **num_features** (int) - number of input features. + * **adaptation_steps** (int) - number of inner-loop adaptation steps. + * **eps** (float, *optional*, default=1e-5) - a value added to the denominator for numerical stability. + * **momentum** (float, *optional*, default=0.1) - the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). + * **affine** (bool, *optional*, default=True) - a boolean value that when set to True, this module has learnable affine parameters. + + **References** + + 1. Antoniou et al. 2019. "How to train your MAML." ICLR. + + **Example** + + ~~~python + batch_norm = MetaBatchNorm(100, 5) + input = torch.randn(20, 100, 35, 45) + for step in range(5): + output = batch_norm(input, step) + ~~~ """ def __init__( @@ -31,13 +57,11 @@ def __init__( eps=1e-5, momentum=0.1, affine=True, - meta_batch_norm=True, ): super(MetaBatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.affine = affine - self.meta_batch_norm = meta_batch_norm self.num_features = num_features self.running_mean = torch.nn.Parameter( torch.zeros(adaptation_steps, num_features), requires_grad=False @@ -61,10 +85,11 @@ def forward( step, ): """ - :param input: input data batch, size either can be any. - :param step: The current inner loop step being taken. This is used when to learn per step params and - collecting per step batch statistics. - :return: The result of the batch norm operation. + **Arguments** + + * **input** (tensor) - Input data batch, size either can be any. + * **step** (int) - The current inner loop step being taken. This is used when to learn per + step params and collecting per step batch statistics. """ assert ( step < self.running_mean.shape[0] From a15c11539c36247064c4f575fc7a2f42e62e44a0 Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Thu, 31 Mar 2022 12:40:48 +0100 Subject: [PATCH 7/9] Remove the step parameter and add an inference argument --- examples/vision/mamlpp/maml++_miniimagenet.py | 10 +++---- learn2learn/nn/metabatchnorm.py | 26 +++++++++-------- .../vision/models/cnn4_metabatchnorm.py | 28 ++++++++++++------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/examples/vision/mamlpp/maml++_miniimagenet.py b/examples/vision/mamlpp/maml++_miniimagenet.py index 9d7a6b52..8e7c9678 100755 --- a/examples/vision/mamlpp/maml++_miniimagenet.py +++ b/examples/vision/mamlpp/maml++_miniimagenet.py @@ -147,19 +147,19 @@ def _training_step( # Adapt the model on the support set for step in range(self._steps): # forward + backward + optimize - pred = learner(s_inputs, step) + pred = learner(s_inputs) support_loss = self._inner_criterion(pred, s_labels) learner.adapt(support_loss, first_order=not second_order, step=step) # Multi-Step Loss if msl: - q_pred = learner(q_inputs, step) + q_pred = learner(q_inputs) query_loss += self._step_weights[step] * self._inner_criterion( q_pred, q_labels ) # Evaluate the adapted model on the query set if not msl: - q_pred = learner(q_inputs, self._steps-1) + q_pred = learner(q_inputs, inference=True) query_loss = self._inner_criterion(q_pred, q_labels) acc = accuracy(q_pred, q_labels).detach() @@ -180,12 +180,12 @@ def _testing_step( # Adapt the model on the support set for step in range(self._steps): # forward + backward + optimize - pred = learner(s_inputs, step) + pred = learner(s_inputs) support_loss = self._inner_criterion(pred, s_labels) learner.adapt(support_loss, step=step) # Evaluate the adapted model on the query set - q_pred = learner(q_inputs, self._steps-1) + q_pred = learner(q_inputs, inference=True) query_loss = self._inner_criterion(q_pred, q_labels).detach() acc = accuracy(q_pred, q_labels) diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py index d165e24f..a4c0d955 100644 --- a/learn2learn/nn/metabatchnorm.py +++ b/learn2learn/nn/metabatchnorm.py @@ -78,32 +78,36 @@ def __init__( self.backup_running_mean = torch.zeros(self.running_mean.shape) self.backup_running_var = torch.ones(self.running_var.shape) self.momentum = momentum + self._steps = adaptation_steps + self._current_step = 0 def forward( self, input, - step, + inference=False, ): """ **Arguments** * **input** (tensor) - Input data batch, size either can be any. - * **step** (int) - The current inner loop step being taken. This is used when to learn per - step params and collecting per step batch statistics. + * **inferencep** (bool, *optional*, default=False) - when set to `True`, uses the final step's parameters and running statistics. When set to `False`, automatically infers the current adaptation step. """ - assert ( - step < self.running_mean.shape[0] - ), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!" - return F.batch_norm( + step = self._current_step + if inference: + step = self._steps - 1 + output = F.batch_norm( input, - self.running_mean[step], - self.running_var[step], - self.weight[step], - self.bias[step], + self.running_mean[self._current_step], + self.running_var[self._current_step], + self.weight[self._current_step], + self.bias[self._current_step], training=True, momentum=self.momentum, eps=self.eps, ) + if not inference: + self._current_step = self._current_step + 1 if self._current_step < (self._steps - 1) else 0 + return output def backup_stats(self): self.backup_running_mean.data = deepcopy(self.running_mean.data) diff --git a/learn2learn/vision/models/cnn4_metabatchnorm.py b/learn2learn/vision/models/cnn4_metabatchnorm.py index 865736a7..f18b188f 100644 --- a/learn2learn/vision/models/cnn4_metabatchnorm.py +++ b/learn2learn/vision/models/cnn4_metabatchnorm.py @@ -28,9 +28,9 @@ def __init__(self, input_size, output_size, adaptation_steps): self.linear = torch.nn.Linear(input_size, output_size) fc_init_(self.linear) - def forward(self, x, step): + def forward(self, x, inference=False): x = self.linear(x) - x = self.normalize(x, step) + x = self.normalize(x, inference=inference) x = self.relu(x) return x @@ -76,9 +76,9 @@ def __init__( ) maml_init_(self.conv) - def forward(self, x, step): + def forward(self, x, inference=False): x = self.conv(x) - x = self.normalize(x, step) + x = self.normalize(x, inference=inference) x = self.relu(x) x = self.max_pool(x) return x @@ -122,9 +122,9 @@ def __init__( ) super(ConvBase_MetaBatchNorm, self).__init__(*core) - def forward(self, x, step): + def forward(self, x, inference=False): for module in self: - x = module(x, step) + x = module(x, inference=inference) return x @@ -149,8 +149,8 @@ def __init__( max_pool_factor=max_pool_factor, ) - def forward(self, x, step): - x = super(CNN4Backbone_MetaBatchNorm, self).forward(x, step) + def forward(self, x, inference=False): + x = super(CNN4Backbone_MetaBatchNorm, self).forward(x, inference=inference) x = x.reshape(x.size(0), -1) return x @@ -236,7 +236,15 @@ def restore_backup_stats(self): if type(layer) is MetaBatchNorm: layer.restore_backup_stats() - def forward(self, x, step): - x = self.features(x, step) + def forward(self, x, inference=False): + """ + **Arguments** + + * **input** (tensor) - Input data batch, size either can be any. + * **inferencep** (bool, *optional*, default=False) - when set to `True`, uses the final + step's parameters and running statistics. When set to `False`, automatically infers the + current adaptation step. + """ + x = self.features(x, inference=inference) x = self.classifier(x) return x From 2beb0b2bee3e57695f77162a1fae9efacf3fdd01 Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Thu, 31 Mar 2022 12:47:59 +0100 Subject: [PATCH 8/9] Lint --- learn2learn/nn/metabatchnorm.py | 18 +++++++++++++----- .../vision/models/cnn4_metabatchnorm.py | 3 ++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py index a4c0d955..d0bbad22 100644 --- a/learn2learn/nn/metabatchnorm.py +++ b/learn2learn/nn/metabatchnorm.py @@ -32,9 +32,13 @@ class MetaBatchNorm(torch.nn.Module): * **num_features** (int) - number of input features. * **adaptation_steps** (int) - number of inner-loop adaptation steps. - * **eps** (float, *optional*, default=1e-5) - a value added to the denominator for numerical stability. - * **momentum** (float, *optional*, default=0.1) - the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). - * **affine** (bool, *optional*, default=True) - a boolean value that when set to True, this module has learnable affine parameters. + * **eps** (float, *optional*, default=1e-5) - a value added to the denominator for numerical + stability. + * **momentum** (float, *optional*, default=0.1) - the value used for the running_mean and + running_var computation. Can be set to None for cumulative moving average (i.e. simple + average). + * **affine** (bool, *optional*, default=True) - a boolean value that when set to True, this + module has learnable affine parameters. **References** @@ -90,7 +94,9 @@ def forward( **Arguments** * **input** (tensor) - Input data batch, size either can be any. - * **inferencep** (bool, *optional*, default=False) - when set to `True`, uses the final step's parameters and running statistics. When set to `False`, automatically infers the current adaptation step. + * **inferencep** (bool, *optional*, default=False) - when set to `True`, uses the final + step's parameters and running statistics. When set to `False`, automatically infers the + current adaptation step. """ step = self._current_step if inference: @@ -106,7 +112,9 @@ def forward( eps=self.eps, ) if not inference: - self._current_step = self._current_step + 1 if self._current_step < (self._steps - 1) else 0 + self._current_step = ( + self._current_step + 1 if self._current_step < (self._steps - 1) else 0 + ) return output def backup_stats(self): diff --git a/learn2learn/vision/models/cnn4_metabatchnorm.py b/learn2learn/vision/models/cnn4_metabatchnorm.py index f18b188f..1ea6f923 100644 --- a/learn2learn/vision/models/cnn4_metabatchnorm.py +++ b/learn2learn/vision/models/cnn4_metabatchnorm.py @@ -178,7 +178,8 @@ class CNN4_MetaBatchNorm(torch.nn.Module): * **output_size** (int) - The dimensionality of the network's output. * **adaptation_steps** (int) - Number of inner-loop adaptation steps. - * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden representation. + * **hidden_size** (int, *optional*, default=64) - The dimensionality of the hidden + representation. * **layers** (int, *optional*, default=4) - The number of convolutional layers. * **channels** (int, *optional*, default=3) - The number of channels in input. * **max_pool** (bool, *optional*, default=True) - Whether ConvBlocks use max-pooling. From f2ddbc432d889494cfc28ec7672b778930721f06 Mon Sep 17 00:00:00 2001 From: Theo Morales Date: Mon, 18 Apr 2022 12:08:16 +0100 Subject: [PATCH 9/9] Fix step being used in forward() --- learn2learn/nn/metabatchnorm.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/learn2learn/nn/metabatchnorm.py b/learn2learn/nn/metabatchnorm.py index d0bbad22..4f4967c5 100644 --- a/learn2learn/nn/metabatchnorm.py +++ b/learn2learn/nn/metabatchnorm.py @@ -98,15 +98,13 @@ def forward( step's parameters and running statistics. When set to `False`, automatically infers the current adaptation step. """ - step = self._current_step - if inference: - step = self._steps - 1 + step = self._current_step if not inference else self._steps - 1 output = F.batch_norm( input, - self.running_mean[self._current_step], - self.running_var[self._current_step], - self.weight[self._current_step], - self.bias[self._current_step], + self.running_mean[step], + self.running_var[step], + self.weight[step], + self.bias[step], training=True, momentum=self.momentum, eps=self.eps,