
TF: XLA stable softmax #16892

Merged: 8 commits merged into huggingface:main (from the stable_softmax branch) on Apr 25, 2022

Conversation

@gante (Member) commented Apr 22, 2022

What does this PR do?

As discussed in the thread about XLA problems (#16838), this PR adds a stable wrapper for the softmax operation and replaces tf.nn.softmax with the wrapped function.

This PR:

  • Adds the wrapped softmax, named stable_softmax, in tf_utils.py; its docstring explains why it is needed and why the new operation is valid (a sketch follows this description);
  • Adds tests for the wrapped softmax, including XLA tests;
  • Replaces tf.nn.softmax with stable_softmax everywhere except in the doctests (it would overcomplicate the examples, and no XLA should be needed there);
  • Removes the skipIf for XLA tests, as they can now be executed successfully on a CPU.

Closes #16838
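As a rough sketch of the idea behind the wrapper (the exact implementation and docstring live in src/transformers/tf_utils.py; the constant, default axis, and type hints here are illustrative):

```python
import tensorflow as tf


def stable_softmax(logits: tf.Tensor, axis: int = -1, name: str = None) -> tf.Tensor:
    """Thin wrapper around tf.nn.softmax that avoids an XLA-on-CPU problem.

    Adding a tiny constant to the logits does not change the result in
    float32 (softmax is shift-invariant, and 1e-9 is far below float32
    resolution for typical logit magnitudes), but it is enough to steer XLA
    away from the code path that produced wrong outputs.
    """
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
```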

@HuggingFaceDocBuilderDev commented Apr 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante gante marked this pull request as ready for review April 22, 2022 15:27
@sgugger (Collaborator) left a comment

Great find for the bug and thanks a lot for fixing all models!

src/transformers/tf_utils.py: 3 resolved review comments (outdated)
```python
xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
```

Collaborator:

Could we have a test for batch_size > 1?

gante (Member Author):

Added batch size > 1 👍
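For reference, a sketch of what the batched version of the quoted test could look like (the masked_softmax / xla_masked_softmax helpers and the shapes are illustrative, and it assumes stable_softmax can be imported from transformers.tf_utils once this PR is in):

```python
import tensorflow as tf

from transformers.tf_utils import stable_softmax  # added by this PR


def masked_softmax(x, boolean_mask):
    # Convert the boolean mask into a large negative additive mask, then softmax.
    numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=x.dtype)) * -1e9
    return stable_softmax(x + numerical_mask, axis=-1)


# XLA-compiled version of the same function.
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)

x = tf.random.normal((2, 3, 5))  # batch_size = 2
boolean_mask = tf.broadcast_to(
    tf.constant([[[1, 1, 1, 0, 0]]], dtype=tf.float32), (2, 3, 5)
)

xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
```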

@ydshieh (Collaborator) left a comment

Good for me, just left 2 nits.
(didn't check the changes in TFGPT2, TFT5 tests though. Let me know if you prefer me to check those too.)

Thank you, @gante 💯

@Rocketknight1 (Member) commented:

This looks good to me! Do you think it would be better to change stable_softmax to only add the offset if we're running on CPU? It makes very little difference either way, but we could hide the complexity of that inside stable_softmax and keep our code paths entirely unchanged on GPU. I'm not certain, though - since it's such a small change maybe we can just do it everywhere.

@ydshieh (Collaborator) commented Apr 22, 2022

> This looks good to me! Do you think it would be better to change stable_softmax to only add the offset if we're running on CPU? It makes very little difference either way, but we could hide the complexity of that inside stable_softmax and keep our code paths entirely unchanged on GPU. I'm not certain, though - since it's such a small change maybe we can just do it everywhere.

Good point! Hopefully this won't affect tests on GPU (at least not the PT/TF equivalence tests, which use a 1e-5 tolerance). Let's see!

@gante (Member Author) commented Apr 22, 2022

@Rocketknight1 @ydshieh if you run the test and print the difference between stable_softmax and tf.nn.softmax, it is exactly 0.0 -- I don't think we need to worry about that :D
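For anyone who wants to reproduce that check, a small illustrative snippet (arbitrary shapes; it assumes the merged stable_softmax is importable from transformers.tf_utils):

```python
import tensorflow as tf

from transformers.tf_utils import stable_softmax  # added by this PR

x = tf.random.normal((4, 10))
diff = tf.reduce_max(tf.abs(stable_softmax(x, axis=-1) - tf.nn.softmax(x, axis=-1)))
print(float(diff))  # reported above to be exactly 0.0 in float32
```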

@ydshieh (Collaborator) commented Apr 22, 2022

@gante With this, do we still have issues regarding sampling in generate()? Sorry, I didn't really follow the issue about sampling, but I would like to know a bit more 😄

@gante (Member Author) commented Apr 22, 2022

@ydshieh After this fix, the errors related to generate() are gone -- they came from the forward pass of the models, which was in turn affected by the issue this PR solves.

@ydshieh (Collaborator) commented Apr 22, 2022

(I might be completely wrong below)

I could imagine that we (will) have tests like:

  • testing non-XLA and XLA generate() with sampling
    • even with this PR, could the differences in the output logits between the two still be as large as, say, 1e-3?
    • if so, might the sampling give different results?
    • if not, what is the magnitude of the difference we get after this PR?
  • testing PT and TF generate() with sampling
    • same potential issue as above?

Thanks 🙏

@ydshieh (Collaborator) commented Apr 22, 2022

OK, I saw your previous comment:

> I've spun up an Nvidia T4 (= no TF32 format) and got an error < 1e-5 for all cases

@Rocketknight1 (Member) commented:

Based on the testing results, I'm happy for this to be merged now! If this is an XLA bug, though, we should make sure to revert our changes once none of the TF versions we support are affected by it anymore.

Should we add a TODO to the masked_softmax function or a reminder somewhere to make sure that we document why this change is here, and when it can be removed?

@patrickvonplaten (Contributor) left a comment

Great - the solution is clean!

@gante (Member Author) commented Apr 25, 2022

@Rocketknight1 added a TODO with instructions related to when to deprecate 👍

@gante gante merged commit e03966e into huggingface:main Apr 25, 2022
@gante gante deleted the stable_softmax branch April 25, 2022 19:10
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
Linked issue closed by this PR: TF: XLA model output differs when certain outputs are passed (#16838)