BatchNorm training instability #659

Open
andrewdipper opened this issue Feb 17, 2024 · 5 comments

@andrewdipper

Hi, I think the use of running_mean and running_var at training time in BatchNorm causes training instability and increased sensitivity to the learning rate. With low momentum (say 0.5) the layer works fine, but this is poor for test time because the running average only covers a short window. With momentum at the default of 0.99, training performance is degraded (higher training losses).
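To put a number on "short window": assuming the usual EMA update running = momentum * running + (1 - momentum) * batch, the batch k steps in the past gets weight (1 - momentum) * momentum**k, so the average effectively spans about 1 / (1 - momentum) batches. A minimal sketch (the helper name ema_weights is hypothetical, not part of Equinox):

```python
import jax.numpy as jnp


def ema_weights(momentum, num_past):
    # Weight given to the batch k steps in the past, k = 0 .. num_past - 1,
    # under running = momentum * running + (1 - momentum) * batch.
    k = jnp.arange(num_past)
    return (1.0 - momentum) * momentum**k


for m in (0.5, 0.99):
    w = ema_weights(m, 10)
    # Fraction of the total weight carried by the 10 most recent batches.
    print(f"momentum={m}: window ~ {1 / (1 - m):.0f} batches, "
          f"last 10 batches carry {float(w.sum()):.3f} of the weight")
```

At 0.5 almost all of the weight sits on the last few batches, while at 0.99 the last 10 batches carry under 10% of it.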

I originally observed this in a larger ResNet model but used the simplified introductory Equinox CNN model to verify it. Plotting the L2 norms of the gradients shows them spiking 2-3 iterations in. I haven't fully traced the behavior, but I think it comes from the gradient flow effectively being scaled down by (1 - momentum) when running_mean / running_var are used instead of batch_mean / batch_var.
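A minimal sketch of that scaling, assuming the train-time output is normalised with the freshly updated running statistics: the derivative of the updated running mean with respect to the batch mean is exactly 1 - momentum, so the gradient reaching the batch statistics through this path is multiplied by that factor. The function names are hypothetical and this is not Equinox's implementation.

```python
import jax
import jax.numpy as jnp


def updated_running_mean(batch, running_mean, momentum):
    # EMA update using the mean of the current batch.
    batch_mean = jnp.mean(batch, axis=0)
    return momentum * running_mean + (1.0 - momentum) * batch_mean


def sensitivity(momentum):
    # d(updated running mean) / d(batch mean), obtained by differentiating
    # with respect to a scalar shift added to every element of the batch.
    batch = jnp.array([[0.0], [1.0], [2.0]])
    running_mean = jnp.array([0.5])

    def f(shift):
        return updated_running_mean(batch + shift, running_mean, momentum)[0]

    return jax.grad(f)(0.0)


for m in (0.5, 0.99):
    print(m, float(sensitivity(m)))
```

With momentum 0.99 only about 1% of the gradient flows back into the batch statistics along this path.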

I haven't seen running_mean and running_var used on the training side before. It's nice that the forward behavior is the same at train and test time, but what is the expected behavior for the gradients? I'm not sure if I'm missing something.

I've included my generated plots and a quick test using the introductory CNN example. "EMA batch norm" is eqx.nn.BatchNorm, and the "batch batch norm" plot uses batch_mean and batch_var instead of running_mean and running_var at training time.
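The localbatchnorm module used in the script is not included in the issue. As a rough illustration of what "batch batch norm" means at training time, here is a minimal sketch, assuming a per-example function run under jax.vmap with axis_name="batch" (as the script's lossfn does), that normalises with the current batch's mean and variance via jax.lax.pmean. The name batch_stat_norm is hypothetical, and running statistics and the learnable scale/shift are deliberately omitted.

```python
import jax
import jax.numpy as jnp


def batch_stat_norm(x, eps=1e-5):
    # Per-example function; must run under jax.vmap(..., axis_name="batch").
    # Mean and (biased) variance are averaged across the named batch axis, so
    # gradients flow fully through the batch statistics, with no EMA involved.
    mean = jax.lax.pmean(x, axis_name="batch")
    var = jax.lax.pmean((x - mean) ** 2, axis_name="batch")
    return (x - mean) * jax.lax.rsqrt(var + eps)


xs = jax.random.normal(jax.random.PRNGKey(0), (64, 512)) * 3.0 + 1.0
out = jax.vmap(batch_stat_norm, axis_name="batch")(xs)
print(out.mean(axis=0)[:3], out.std(axis=0)[:3])
```

Each feature of out has roughly zero mean and unit standard deviation across the batch.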

At the very least, it might be worth documenting the BatchNorm behavior: it took me some time to trace my particular performance issue back to BatchNorm.

Thanks!

[Figures: gradient L2 norms and training loss for each of the four configurations]


import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https:/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https:/google/jaxtyping
import localbatchnorm  # author's local BatchNorm using batch_mean/batch_var at train time (not included)

import matplotlib.pyplot as plt

jax.config.update('jax_platform_name', 'cpu')

BATCH_SIZE = 64
LEARNING_RATE = 3e-3
STEPS = 100
PRINT_EVERY = 10
SEED = 5678

key = jax.random.PRNGKey(SEED)

normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)




class CNN(eqx.Module):
    layers: list

    def __init__(self, key, mode=0):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
        ]
        if mode == 1:
            self.layers.append(eqx.nn.BatchNorm(input_size=512, momentum=0.99, axis_name="batch"))
        elif mode == 2:
            self.layers.append(localbatchnorm.BatchNorm(input_size=512, momentum=0.99, axis_name="batch"))
        elif mode == 3:
            self.layers.append(eqx.nn.BatchNorm(input_size=512, momentum=0.5, axis_name="batch"))
        self.layers += [
            jax.nn.relu,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(
        self, x: Float[Array, "1 28 28"], state: eqx.nn.State
    ) -> tuple[Float[Array, "10"], eqx.nn.State]:
        for layer in self.layers:
            # BatchNorm layers read and update the state; all other layers are stateless.
            if isinstance(layer, (eqx.nn.BatchNorm, localbatchnorm.BatchNorm)):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state



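def lossfn(
    model: CNN, state: eqx.nn.State, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> tuple[Float[Array, ""], tuple[Float[Array, ""], eqx.nn.State]]:
    # vmap over the batch; axis_name="batch" lets BatchNorm compute statistics across it.
    # The state is shared across the batch, hence in_axes/out_axes of None for it.
    pred_y, state = jax.vmap(model, in_axes=(0, None), out_axes=(0, None), axis_name="batch")(x, state)
    acc = jnp.mean(y == jnp.argmax(pred_y, axis=1))
    # Negative log-likelihood of the true class (the model outputs log-probabilities).
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y), (acc, state)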





@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    # has_aux=True: lossfn returns (loss, (accuracy, updated BatchNorm state)).
    (loss, (acc, state)), grad = eqx.filter_value_and_grad(lossfn, has_aux=True)(model, state, xs, ys)
    # `optim` is the global optimiser created in the training loop below.
    updates, opt_state = optim.update(grad, opt_state)
    model = eqx.apply_updates(model, updates)

    # L2 norm of each gradient leaf, recorded for plotting.
    leaves = jax.tree_util.tree_leaves(eqx.filter(grad, eqx.is_array))
    grads = [jnp.linalg.norm(g) for g in leaves]

    return model, state, opt_state, loss, grads, acc

def infinite_trainloader():
    while True:
        yield from trainloader
       
titles = ['No batch norm', 'ema batch norm - momentum=0.99', 'batch batch norm', 'ema batch norm - momentum=0.5']
for mode in range(4): 
    print('\n\n' + titles[mode] + ' acc, loss')
    mkey = jax.random.PRNGKey(1)
    model, state = eqx.nn.make_with_state(CNN)(mkey, mode)
    optim = optax.adam(LEARNING_RATE)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
            
    ghist = []
    loss_hist = []
    for step, (x, y) in zip(range(STEPS), infinite_trainloader()):
    
        model, state, opt_state, loss, grads, acc = make_step(model, state, opt_state, x.numpy(), y.numpy())
        ghist.append(grads)
        loss_hist.append(loss)
        if (step + 1) % PRINT_EVERY == 0:
            print(step+1, acc, loss)
    
    fig, axes = plt.subplots(2)
    plt.setp(axes[0], title=titles[mode], ylabel='grad l2 norm')
    plt.setp(axes[1], ylabel='train loss')

    for i in range(len(ghist[0])):
        axes[0].plot([x[i] for x in ghist], label='grad' + str(i))
        
    axes[0].legend(loc='upper right')
    axes[1].plot(loss_hist)
    plt.show()


@patrick-kidger
Owner

Right, so there are like 10 different flavours of BatchNorm! As you've noticed, I chose to implement one of the ones with consistent train vs inference behaviour.

I do recall some of the papers from a few years ago finding that the gradients through BatchNorm were actually important, and that it was not merely the normalising effect that was helpful. So it could perhaps be something to do with the gradients.

The other thing that comes to mind is that on the very first forward pass we simply set running_{mean,var} = batch_{mean,var}, whereas later on we use momentum. I can see that this gives undue weight to the very first batch. Perhaps what we should have done is to use an Adam-style correction, where we initialise at zero, do a running average, and then renormalise.
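A minimal sketch of the Adam-style correction suggested above, assuming the running statistic starts at zero and is updated with the usual EMA: dividing by 1 - momentum**step (step = number of updates so far) removes the bias towards the zero initialisation, so no single batch gets extra weight. The helper names are hypothetical and this is not Equinox's code.

```python
import jax.numpy as jnp


def ema_update(running, batch_stat, momentum):
    return momentum * running + (1.0 - momentum) * batch_stat


def bias_corrected(running, momentum, step):
    # step = number of EMA updates applied so far (>= 1).
    return running / (1.0 - momentum**step)


momentum = 0.99
running = jnp.zeros(3)  # initialise at zero rather than at the first batch
batch_means = [jnp.array([1.0, 2.0, 3.0])] * 5
for step, batch_mean in enumerate(batch_means, start=1):
    running = ema_update(running, batch_mean, momentum)
    corrected = bias_corrected(running, momentum, step)

print(running)    # still heavily biased towards zero after 5 steps
print(corrected)  # approximately equal to the (constant) batch mean
```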

If what I've just said makes sense -- let me know if not -- can you try giving it a go? If it works better for your use cases then I'd be open to adjusting our implementation (whether by default or behind a flag).

@andrewdipper
Author

That makes sense. I looked into the initialization of running_*, but I don't think it was exactly Adam-style. I'll run some tests and get back to you.

@andrewdipper
Author

I looked into using an Adam-style correction: it helps a bit, but there is still instability. I also tried initializing at 0, doing a running average, and then, instead of renormalizing, replacing the zero part of the running average with batch_{mean,var}. For instance, on the kth batch we would have corrected_mean = running_mean + batch_mean * momentum**k. The idea is that we start by using purely batch statistics and then transition to the running statistics as they become fully populated. Call it "smooth start".
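A minimal sketch of the smooth-start estimate as described: after k updates the zero-initialised running average has total weight 1 - momentum**k, and the missing momentum**k is filled with the current batch statistic. The helper names are hypothetical; this illustrates the formula rather than any reference implementation.

```python
import jax.numpy as jnp


def ema_update(running, batch_stat, momentum):
    return momentum * running + (1.0 - momentum) * batch_stat


def smooth_start(running, batch_stat, momentum, step):
    # step = number of EMA updates applied so far, including this batch.
    # The zero-initialised EMA carries weight 1 - momentum**step; the
    # remaining momentum**step is assigned to the current batch statistic.
    return running + batch_stat * momentum**step


momentum = 0.99
running = jnp.zeros(2)
for step, batch_mean in enumerate([jnp.array([1.0, -1.0]), jnp.array([3.0, 1.0])], start=1):
    running = ema_update(running, batch_mean, momentum)
    corrected = smooth_start(running, batch_mean, momentum, step)
    print(step, corrected)
```

On the first step the estimate equals the batch statistic exactly, and it drifts towards the running average only as momentum**k decays.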

This performs better than the Adam-style correction. With a larger network, both are fine with small momentum (0.5, which is essentially using batch statistics), but training becomes unstable and suffers with momentum around 0.99. With momentum around 0.9999, smooth-start training performs similarly to just using batch statistics. My guess is that a longer moving average causes issues when the output distribution of the BatchNorm layer changes (possibly gradient-related?). A very large momentum with smooth start allows this distribution to stabilize somewhat before the running average becomes heavily weighted.
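Some plain arithmetic on this, using the smooth-start formula corrected_mean = running_mean + batch_mean * momentum**k: the current batch keeps weight momentum**k after k steps. At 0.5 that weight is negligible after 10 steps; at 0.99 it is about 0.37 after 100 steps; at 0.9999 it is still about 0.99 after 100 steps and only falls to about 0.37 after 10,000 steps. The helper name is hypothetical.

```python
def batch_stat_weight(momentum: float, step: int) -> float:
    # Weight on the current batch statistic in the smooth-start estimate
    # corrected = running + batch * momentum**step.
    return momentum**step


for m in (0.5, 0.99, 0.9999):
    print(m, [round(batch_stat_weight(m, k), 3) for k in (10, 100, 1000, 10000)])
```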

Below are graphs from the toy example I posted above. The three panels in each chart show the max L2 norm of the gradients, the mean of the BatchNorm output distribution, and the standard deviation of that distribution.

[Figures: max gradient L2 norm, BatchNorm output mean, and BatchNorm output std]

@patrick-kidger
Owner

Ah, this is excellent! This does suggest that we should maybe adjust our BatchNorm implementation, if nothing else so that it does something that (a) works or (b) conforms to other implementations... even if it's less theoretically elegant.

If you feel like taking this on, then I'd be very happy to take a PR for it, perhaps with an API looking something like BatchNorm(..., approach="the name of some strategy"). This could default to None, in which case it keeps the default behaviour whilst raising a warning that a strategy hasn't been explicitly chosen.
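A hypothetical sketch of how such an argument might be resolved. The strategy names "ema", "batch" and "smooth_start", the helper resolve_approach, and the warning text are placeholders, not Equinox's API: approach=None keeps the current running-statistics behaviour but warns that no strategy was chosen explicitly.

```python
import warnings

# Hypothetical strategy names; the final set was left open in the discussion.
_APPROACHES = ("ema", "batch", "smooth_start")
_DEFAULT = "ema"  # current behaviour: normalise with the running statistics


def resolve_approach(approach=None):
    if approach is None:
        warnings.warn(
            f"BatchNorm: no `approach` given; defaulting to {_DEFAULT!r}. "
            "Pass approach=... explicitly to silence this warning.",
            UserWarning,
        )
        return _DEFAULT
    if approach not in _APPROACHES:
        raise ValueError(f"Unknown approach {approach!r}; expected one of {_APPROACHES}.")
    return approach
```

Defaulting to the existing behaviour with a warning keeps current users' results unchanged while nudging them to pick a strategy deliberately.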

I don't have any strong opinions on exactly which strategies we should implement.

@andrewdipper
Author

For sure, I'll check other implementations and come up with something. Multiple strategies seem like a good way to go.
