BatchNorm training instability fix #675

Open · wants to merge 6 commits into main

Conversation

andrewdipper

This is in reference to issue #659.

I modified BatchNorm to have two approaches, "batch" and "ema". "batch" just uses the batch statistics during training. If approach is not specified, it defaults to "batch" with a warning. It's robust and seems to be the standard choice - it's far less likely to kill a model just by adding it.

"ema" is based of the smooth start method in the above issue. So keep a running mean and variance but instead of renormalizing Adam style the parts of the running averages that are zeroed are filled with the batch statistics. The problem is it's still not robust - the momentum parameter is simultaneously specifying a warmup period (when we're expecting the input distribution to change significantly) and how long we want the running average to be. So I added a linear warmup period.

Now, for any choice of momentum there seems to be a warmup_period choice that gives good results, and in my tests validation performance was at least as good as with batch mode. However, I don't see a good default for warmup_period.
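
For concreteness, here's a minimal usage sketch of what this looks like from the caller's side, assuming the constructor arguments in this PR and Equinox's usual stateful-layer pattern (shapes and values are made up):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# `approach` (and, for "ema", `warmup_period`) are the new arguments in this PR.
bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=3, axis_name="batch", approach="batch"
)
x = jnp.ones((8, 3))
# BatchNorm sees the whole batch via the named vmap axis; the state is threaded
# through unbatched.
out, state = jax.vmap(
    bn, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)
```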

Some considerations:

  • Having approach="batch" alongside the common axis_name="batch" is a little awkward.
  • There's an example using BatchNorm - it will start raising a warning and should probably be changed.
  • The current BatchNorm behavior can't be exactly replicated; ema / momentum=0.99 / warmup_period=1 is close, but different at the start.
  • There's one more piece of state, hence the test_stateful.py change. This could be conditionally removed for approach="batch" if desired.

Let me know what you think, or if any changes or tests need to be added.


self.inference = inference
self.axis_name = axis_name
self.input_size = input_size
self.eps = eps
self.channelwise_affine = channelwise_affine
self.momentum = momentum
self.warmup_period = max(1, warmup_period)
Owner

Why the max? Perhaps it would be better to just error out on values that are too small?

Author

warmup_period=0 seemed natural for "off" - changed it to check and error out instead.


@jax.named_scope("eqx.nn.BatchNorm")
def __call__(
Owner

It's not completely obvious to me that the ema implementation, with default arguments, reproduces the previous behaviour. (For example, we have warmup_period=1000 by default?)

Can you add some comments explaining what each approach corresponds to?

Author

ema with warmup_period=1 approximately reproduces the previous behavior. As I noted, the start is different because of how the running statistics are initially populated. With warmup_period=1 there's no interpolation between the batch and running stats - the running stats are always used, as with the previous behavior. I can give an exact replication with an extra approach if necessary.
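
A quick way to see this from the warmup update that appears later in the diff: with `warmup_period=1`,

$$\texttt{warmup\_frac} = \frac{\min(\texttt{count}+1,\ \texttt{warmup\_period})}{\texttt{warmup\_period}} = 1,$$

so the `(1 - warmup_frac) * batch_mean` term vanishes and the normalisation always uses `zero_frac * batch_mean + running_mean`, i.e. the (zero-fraction-filled) running statistics.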

Added some explanation to the documentation.

Owner

I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.

Author

Makes sense - it was different enough that I added it as "ema_compatibility". I changed the warning to rather strongly recommend against using "ema_compatibility". I haven't found a use case where I wouldn't expect to see the instability (at least with a larger learning rate), but that could very much be due to a lack of imagination on my part. That part can definitely change if needed.

state = state.set(self.first_time_index, jnp.array(False))
momentum = self.momentum
zero_frac = state.get(self.zero_frac_index)
zero_frac *= momentum
Owner

Stylistic nit: I tend not to use the in-place operations in JAX code. This (a) fits with the functional style a bit better, and (b) emphasises that we're definitely falling back to the zero_frac = zero_frac * momentum interpretation of the syntax. (Gosh, why does Python have two different meanings for the same syntax?)
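
(A tiny illustration of the two meanings, for anyone reading along: NumPy's `*=` mutates the array in place, whereas JAX arrays are immutable, so `*=` simply rebinds the name.)

```python
import numpy as np
import jax.numpy as jnp

a = np.ones(3)
b = a
a *= 2
print(b)  # [2. 2. 2.] -- NumPy mutated the shared buffer in place

x = jnp.ones(3)
y = x
x *= 2
print(y)  # [1. 1. 1.] -- JAX arrays are immutable; `x` was rebound to a new array
```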

Author

Makes sense, done.


batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.state_index)
momentum = self.momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
Owner

These don't appear to be used on the batch branch. I think the lines here can be reorganised to keep each approach only using the things it needs.

Author

These are used by the batch branch when we're in inference mode, so they still need to be computed and stored.

Comment on lines 195 to 203
warmup_count = state.get(self.count_index)
warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period)
state = state.set(self.count_index, warmup_count)

warmup_frac = warmup_count / self.warmup_period
norm_mean = zero_frac * batch_mean + running_mean
norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean
norm_var = zero_frac * batch_var + running_var
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var
Owner

I'm definitely going to have to sit down and grok what's going on here more carefully! As above it would be good to have some comments / docstrings / references / etc. describing what each approach is meant to do.

(Cf. something like the MultiheadAttention docstring for an example of how to use LaTeX, if it'd be helpful.)

Author

Added some commentary and tried to make it a bit cleaner.

But overall, batch mode should follow the cited paper. Ema follows the prior behavior, but changes the initialization of the running stats and adds interpolation so that it can stay stable during training.
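
A self-contained sketch of the ema bookkeeping being discussed here (names simplified from the diff; illustrative only, not the PR's exact code). The variance follows the same pattern, with the debias coefficient from the hunk below.

```python
import jax.numpy as jnp

def ema_update(batch_mean, running_mean, zero_frac, step,
               momentum=0.99, warmup_period=1000):
    # Ordinary exponential moving average, initialised at zero.
    running_mean = (1 - momentum) * batch_mean + momentum * running_mean
    # Fraction of the running average still sitting at its zero initialisation.
    zero_frac = zero_frac * momentum
    # Fill the still-zero fraction with the batch statistic rather than
    # renormalising (the "smooth start" idea from issue #659).
    filled_mean = zero_frac * batch_mean + running_mean
    # Linear warmup: early on, lean directly on the batch statistic.
    warmup_frac = jnp.minimum(step + 1, warmup_period) / warmup_period
    norm_mean = (1 - warmup_frac) * batch_mean + warmup_frac * filled_mean
    return norm_mean, running_mean, zero_frac
```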

debias_coef = axis_size / jnp.maximum(axis_size - 1, self.eps)
running_var = (
    1 - momentum
) * debias_coef * batch_var + momentum * running_var
Author

I neglected to use the unbiased variance, so I've corrected that here.
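
For context on the coefficient: with a batch axis of size $n$, the batch variance computed with denominator $n$ is biased, since

$$\mathbb{E}\!\left[\frac{1}{n}\sum_{i=1}^{n}(x_i-\bar{x})^2\right] = \frac{n-1}{n}\,\sigma^2,$$

so scaling by $n/(n-1)$ gives an unbiased estimate; the `jnp.maximum(axis_size - 1, self.eps)` just guards against division by zero when the batch axis has size 1.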

@@ -202,8 +259,15 @@ def _stats(y):
    norm_var = zero_frac * batch_var + running_var
    norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var
else:
    axis_size = jax.lax.psum(jnp.array(1.0), self.axis_name)
Author

I'm using this to get the length of the "batch" axis, but I'm not sure it's the best / correct way.

Owner

I think this is the correct way! IIRC psum(1) is actually special-cased for this purpose.
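
A tiny check of that behaviour, independent of this PR (a scalar `psum` of 1 over the named axis comes back as the axis size):

```python
import jax
import jax.numpy as jnp

def axis_size(_):
    # Summing 1 across the named axis yields the number of elements in it.
    return jax.lax.psum(jnp.array(1.0), "batch")

x = jnp.zeros((8, 3))
print(jax.vmap(axis_size, axis_name="batch")(x))  # eight copies of 8.0
```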

Comment on lines 135 to 138
- `approach`: The approach to use for the running statistics. If `approach=None`
  a warning will be raised and approach will default to `"batch"`. During
  training, `"batch"` only uses batch statistics while `"ema"` uses the running
  statistics.
Owner

So continuing from my previous comment -- probably the default should be ema if approach=None.

Owner

@patrick-kidger left a comment

Okay! Sorry for taking so long to get back around to reviewing this.

Lmk once you're happy that the previous behaviour is replicated by default, and I'll sit down with a pen and paper and satisfy myself that the calculations all look reasonable!

@andrewdipper
Author

All good - I got caught up in other things myself!

From my tests the replication is now exact. Doing so added another approach that is very similar to "ema", but that seemed like the most reasonable way to organize it. Let me know if anything isn't clear.
