Make _cuda_path more reliable #24139

Merged: 1 commit, Oct 7, 2024

@hartikainen (Contributor) commented Oct 5, 2024

See my comment in #22590 (comment). This keeps the original behavior of locating the cuda_nvcc path relative to jaxlib as the default, and applies the new check only if the original lookup fails. Pulling the path directly from cuda_nvcc seems less error prone and, to me at least, preferable to the original approach, but keeping this order should be more backwards compatible.
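A minimal sketch of the ordering described above, not the PR's actual code: try the jaxlib-relative lookup first and consult the `cuda_nvcc` module only when that fails. The helper names, the `jaxlib_dir` and `module_name` parameters, and the assumed `<package root>/nvidia/cuda_nvcc` layout are illustrative assumptions.

```python
import importlib
import pathlib


def try_jaxlib_relative_path(jaxlib_dir):
    """Assumed original approach: look for nvidia/cuda_nvcc beside jaxlib."""
    candidate = pathlib.Path(jaxlib_dir).parent / "nvidia" / "cuda_nvcc"
    return str(candidate) if candidate.is_dir() else None


def try_module_path(module_name="nvidia.cuda_nvcc"):
    """Proposed fallback: ask the imported module where it lives."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    module_file = getattr(module, "__file__", None)
    if module_file is None:  # e.g. a namespace package without __init__.py
        return None
    return str(pathlib.Path(module_file).parent)


def cuda_path(jaxlib_dir, module_name="nvidia.cuda_nvcc"):
    """Return the first lookup that succeeds, keeping the original one first."""
    if (path := try_jaxlib_relative_path(jaxlib_dir)) is not None:
        return path
    return try_module_path(module_name)
```

Passing `jaxlib_dir` and `module_name` as arguments is only there to make the sketch easy to exercise without jaxlib or the NVIDIA wheels installed.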

@PhilipVinc (Contributor)

#22590 was about using cuda12_local, which covers the case where cuda_nvcc is NOT installed through PyPI. So this PR does not solve it (I think?)

To solve #22590, you would need to add a further fallback that checks $CUDA_PATH.

@hartikainen (Contributor, Author)

Ah, you're right! I got distracted by the comment about nvidia-cuda-nvcc-cu12 in the thread. I'll update my comment above to remove the fix comment.

cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent
return str(cuda_nvcc_path)

if (path := _try_jaxlib_relative_path()) is not None:
Collaborator

I think I'd be comfortable simply removing the jaxlib-relative path. It never was the right thing to do, given, for example, you might have installed cuda_nvcc in the system package root but jaxlib in a different package root.
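To illustrate the concern, here is a small hypothetical reproduction: two separate package roots, with `jaxlib` in one and `nvidia/cuda_nvcc` in the other. The directory names and the `sibling_lookup` helper are assumptions made for the sketch; the point is only that a lookup relative to `jaxlib` never inspects the second root.

```python
import pathlib
import tempfile


def sibling_lookup(jaxlib_dir):
    """Hypothetical jaxlib-relative lookup: <root>/jaxlib -> <root>/nvidia/cuda_nvcc."""
    candidate = pathlib.Path(jaxlib_dir).parent / "nvidia" / "cuda_nvcc"
    return candidate if candidate.is_dir() else None


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        tmp = pathlib.Path(tmp)
        # Assumed layout: jaxlib in a user root, cuda_nvcc in a system root.
        jaxlib_dir = tmp / "user_root" / "jaxlib"
        nvcc_dir = tmp / "system_root" / "nvidia" / "cuda_nvcc"
        jaxlib_dir.mkdir(parents=True)
        nvcc_dir.mkdir(parents=True)
        found = sibling_lookup(jaxlib_dir)
        return found is not None, nvcc_dir.is_dir()


sibling_found, nvcc_installed = demo()
print(sibling_found, nvcc_installed)
```

In this layout the sibling lookup comes back empty even though the `cuda_nvcc` directory exists, which is the situation the comment describes.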

Contributor Author

Sounds good! I'll remove it. Do you think it would make sense to add a check for $CUDA_ROOT? That should fix Filippo's issue in #22590. I'm happy to add that here while we're at it.

Collaborator

Yes, I think that would be a great idea.

Contributor Author

I just pushed the $CUDA_ROOT change. Let me know if anything needs to be tweaked.

@hawkinsp self-assigned this Oct 7, 2024
@hawkinsp added the "NVIDIA GPU" label (Issues specific to NVIDIA GPUs) Oct 7, 2024
@hawkinsp (Collaborator) left a comment

Two more changes:

  • I think you need to add nvidia to the mypy unknown module list here:
    module = [
  • Please squash your commits.

Thanks!

- Remove jax-relative module path test
- Use `$CUDA_ROOT` environment variable if available
- Use `cuda_nvcc` module's path if installed
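
The behavior summarized in this commit message could look roughly like the sketch below. It is not the merged code: the function names, the `environ` and `module_name` parameters, the namespace-package handling, and the choice to check `$CUDA_ROOT` before the module are all assumptions made for illustration.

```python
import importlib
import os
import pathlib


def try_cuda_root(environ=None):
    """Return the value of $CUDA_ROOT, or None if the variable is unset."""
    environ = os.environ if environ is None else environ
    return environ.get("CUDA_ROOT")


def try_cuda_nvcc_module(module_name="nvidia.cuda_nvcc"):
    """Return the directory of the installed cuda_nvcc module, or None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    module_file = getattr(module, "__file__", None)
    if module_file is not None:
        # Regular package or module: use the directory containing its file.
        return str(pathlib.Path(module_file).parent)
    # Namespace package (assumption): take the first entry that exists on disk.
    for entry in getattr(module, "__path__", []):
        if pathlib.Path(entry).is_dir():
            return str(pathlib.Path(entry))
    return None


def find_cuda_path(environ=None, module_name="nvidia.cuda_nvcc"):
    """Assumed order: the environment variable first, then the module."""
    if (path := try_cuda_root(environ)) is not None:
        return path
    return try_cuda_nvcc_module(module_name)
```

Calling `find_cuda_path()` with no arguments reads the real environment and the real `nvidia.cuda_nvcc` module; the parameters exist so the lookup can be tried out without either being present.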
@hartikainen (Contributor, Author)

Thanks for the comments, @hawkinsp. I just squashed my changes and rebased the branch.

@hawkinsp added the "pull ready" label (Ready for copybara import and testing) Oct 7, 2024
@copybara-service (bot) merged commit 8473391 into jax-ml:main Oct 7, 2024
12 of 13 checks passed
@hartikainen deleted the fix-cuda_path branch October 7, 2024 19:04