BatchNorm training instability #659
Right, so there are like 10 different flavours of BatchNorm out there. I do recall some of the papers from a few years ago finding that taking gradients through the running statistics can behave quite differently to taking them through the batch statistics. The other thing that comes to mind is that on the very first forward pass we simply set the running statistics equal to the batch statistics, rather than initialising them at zero and applying an Adam-style debiasing correction. If what I've just said makes sense -- let me know if not -- can you try giving it a go? If it works better for your use cases then I'd be open to adjusting our implementation (whether by default or behind a flag).
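For concreteness, here is a minimal sketch of the two ideas mentioned above - an Adam-style debiasing of the running statistics, and copying the batch statistics on the first step. The function and variable names are illustrative, not Equinox's API:

```python
import jax.numpy as jnp

def ema_update(running, batch_stat, momentum, step):
    # One EMA update of a running statistic (step counts from 1).
    running = momentum * running + (1 - momentum) * batch_stat
    # Adam-style correction: the zero init still carries weight
    # momentum**step, so dividing by (1 - momentum**step) removes the
    # bias towards zero (at step 1 this returns the batch statistic).
    debiased = running / (1 - momentum ** step)
    return running, debiased

# The alternative mentioned above: on the very first forward pass,
# simply set the running statistics equal to the batch statistics,
# i.e. running_mean, running_var = batch_mean, batch_var.
```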
That makes sense. I looked into the initialization of running_* but I don't think it was exactly Adam-style. I'll run some tests and get back.
I looked into using an Adam-style correction - it helps a bit but there is still instability. I also looked into initializing at 0, doing a running average, and then, instead of renormalizing, replacing the zero-initialized part of the running average with the current batch statistics ("smooth start"). This performs better than the Adam-style correction. With a larger network both are fine with small momentum (0.5 - this is essentially just using batch stats) but show instability / degraded training with momentum at ~0.99. When momentum is ~0.9999, smooth-start training performs similarly to just using batch statistics. My guess is that a longer moving average causes issues when the output distribution of the BatchNorm layer changes (could be gradient related?). A very large momentum with smooth start allows this distribution to stabilize a bit before the running average is heavily weighted. Below are graphs from the toy example in my original post. The 3 curves in each chart are: max l2 norm of the gradients, mean of the output distribution from BatchNorm, and std of the same distribution.
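As I read it, the "smooth start" variant described above can be sketched as follows (my naming and structure, not an actual implementation): the running statistic is still a zero-initialised EMA, but the weight that would otherwise sit on the zero init is handed to the current batch statistic instead of being renormalized away.

```python
import jax.numpy as jnp

def smooth_start_update(running, batch_stat, momentum, step):
    # Plain EMA update with zero-initialised `running` (step counts from 1).
    running = momentum * running + (1 - momentum) * batch_stat
    # After `step` updates the zero init still carries weight momentum**step.
    zero_weight = momentum ** step
    # Instead of dividing by (1 - zero_weight) (Adam-style), give that
    # weight to the current batch statistic: early in training this is
    # close to the batch statistic, and it smoothly approaches the EMA.
    effective = running + zero_weight * batch_stat
    return running, effective
```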
Ah, this is excellent! This does suggest that we should maybe adjust our BatchNorm implementation. If you feel like taking this on, then I'd be very happy to take a PR on this. Perhaps with an API looking something like an extra argument on BatchNorm selecting the strategy for the running statistics. I don't have any strong opinions on exactly which strategies we should implement.
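A rough sketch of what such a strategy-selecting argument might look like - the argument name and option strings here are purely illustrative, not the actual eqx.nn.BatchNorm signature:

```python
import equinox as eqx

# Hypothetical API: `approach` and its option strings are illustrative only.
norm = eqx.nn.BatchNorm(
    input_size=64,
    axis_name="batch",
    momentum=0.99,
    approach="smooth_start",  # e.g. "ema", "adam_debias", "smooth_start", "batch"
)
```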
For sure, I'll check other implementations and come up with something. Multiple strategies seem like a good way to go.
Hi, I think the use of running_mean and running_var during training time in BatchNorm causes training instability and increased learning-rate sensitivity. With a low momentum (say 0.5) the layer works fine, but this is poor for test time due to the short effective window of the running average. When momentum is at the default 0.99, training performance is degraded (higher train losses).
I originally observed this in a larger ResNet model but used the simplified introductory Equinox CNN model to verify. Plotting the l2 norms of the gradients shows that they spike 2-3 iterations in. I haven't fully traced the behavior, but I think it comes from the gradient flow effectively being clipped by (1 - momentum) when running_mean / running_var are used as opposed to batch_mean / batch_var.
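To make the "(1 - momentum)" point concrete, here is a tiny mean-only sketch (not the Equinox implementation): during training the normalisation uses the freshly updated running mean, so the sensitivity of that statistic to the current batch mean is scaled down by (1 - momentum).

```python
import jax
import jax.numpy as jnp

momentum = 0.99
running_mean = jnp.array(0.0)
x = jnp.arange(4.0)
batch_mean = x.mean()

# Statistic actually used for normalisation during training: the
# freshly updated running mean rather than the batch mean itself.
stat_used = lambda bm: momentum * running_mean + (1 - momentum) * bm

# Sensitivity of that statistic to the current batch mean:
print(jax.grad(stat_used)(batch_mean))  # 0.01 == 1 - momentum

# Any gradient flowing back through the normalisation statistics is
# therefore attenuated by (1 - momentum) relative to normalising with
# batch_mean directly.
```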
I haven't seen running_mean and running_var used on the training side elsewhere. It's nice that the forward behavior is the same for train and test time, but what is the expected behavior for the gradients? I'm not sure if I'm missing something.
I included my generated plots and a quick test with the introductory CNN example. "EMA batch norm" is eqx.nn.BatchNorm as-is, and the "batch batch norm" plot uses batch_mean and batch_var instead of running_mean and running_var during train time.
At least it might be worth documenting this behavior - it took some time for me to chase my particular performance issue down to BatchNorm.
Thanks!