
[FSDP2][1/n] construct NF4Tensor from bf16/fp16/fp32 #118

Merged
merged 22 commits into pytorch:main
Apr 19, 2024

Conversation

weifengpy
Contributor

@weifengpy weifengpy commented Apr 4, 2024

  • handle .to(bf16/fp16/fp32) via __torch_function__, as suggested by @cpuhrsch
  • NF4Tensor.dtype is the dtype used at construction; for example, to_nf4(fp32_tensor) returns an NF4Tensor with dtype fp32
  • covered dtype=fp16/bf16/fp32 in pytest test/dtypes/test_nf4.py
  • tested in torchtune with pytest tests -m integration_test (cc @rohan-varma)

It brings two benefits.
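
A minimal usage sketch of the construction/conversion path described above (the import path and exact call sites are assumptions based on this PR, not a verified torchao API reference):

import torch
from torchao.dtypes.nf4tensor import to_nf4  # import path assumed; may differ across versions

# Construct an NF4Tensor from a full-precision weight; per this PR the
# NF4Tensor records the dtype it was constructed from.
weight = torch.randn(512, 512, dtype=torch.float32)
nf4_weight = to_nf4(weight)
assert nf4_weight.dtype == torch.float32  # dtype reflects the construction dtype

# .to(dtype) is intercepted via __torch_function__ and returns the
# dequantized weight in the requested dtype.
restored = nf4_weight.to(torch.bfloat16)
assert restored.dtype == torch.bfloat16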

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 4, 2024
@weifengpy weifengpy marked this pull request as draft April 4, 2024 01:21
@weifengpy weifengpy changed the title proof of concept for FSDP2 + NF4Tensor [In Progress] FSDP2 + NF4Tensor Apr 4, 2024
cpuhrsch and others added 3 commits April 4, 2024 10:53

weifengpy and others added 8 commits April 16, 2024 14:13
@weifengpy weifengpy changed the title [In Progress] FSDP2 + NF4Tensor [FSDP2][1/n] construct NF4Tensor from bf16/fp16/fp32 Apr 17, 2024
@weifengpy
Contributor Author

weifengpy commented Apr 17, 2024

The failing test test/kernel/test_galore_downproj.py seems unrelated: it throws an error when parsing nvidia-smi. cc @msaroufim @jeromeku since it was newly added this week.

@weifengpy weifengpy marked this pull request as ready for review April 17, 2024 22:03
@cpuhrsch
Contributor

Great! Thank you! Can you try rebasing to see if CI runs green?

@weifengpy
Contributor Author

Great! Thank you! Can you try rebasing to see if CI runs green?

Rebasing now.


@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
    return args[0].get_original_weight().to(args[1])
Contributor


Nit: there are a few ways you can call .to that this isn't robust to, I would imagine.

Contributor Author

@weifengpy weifengpy Apr 17, 2024


Do you mean the 3 ways to call .to? I can raise an error for the unimplemented overloads if it helps:

  • Tensor.to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format)
  • Tensor.to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format)
  • Tensor.to(other, non_blocking=False, copy=False)

Contributor


Yeah, I think in theory they can all be supported; it's just that the other args/kwargs are getting dropped as currently implemented.

Contributor Author


Got it. I'll see if I can pass the args/kwargs through instead of dropping them.

Contributor Author


Updated the PR to pass args and kwargs through instead of dropping them, and to fall back to dispatch for overloads that are not implemented.
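
For reference, a simplified sketch of a handler that forwards args/kwargs for the dtype overload (a hedged illustration only, not the PR's exact code; implements_torch_function and get_original_weight are the NF4Tensor helpers seen in the snippet under review, and the import path is assumed):

import torch
from torchao.dtypes.nf4tensor import implements_torch_function  # import path assumed

@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
    nf4_tensor, rest = args[0], args[1:]
    if rest and isinstance(rest[0], torch.dtype):
        # to(dtype, non_blocking=..., copy=..., memory_format=...): dequantize,
        # then forward the remaining args/kwargs instead of dropping them.
        return nf4_tensor.get_original_weight().to(*rest, **kwargs)
    # to(device, ...) and to(other, ...) are not handled in this sketch; the PR
    # instead falls back to the dispatch path for overloads it does not implement.
    raise NotImplementedError("only the dtype overload of Tensor.to is handled in this sketch")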

Contributor

@drisspg drisspg left a comment


Overall looks good, thanks for pushing this through!


@weifengpy weifengpy marked this pull request as draft April 18, 2024 00:56
weifengpy and others added 4 commits April 18, 2024 16:12
def test_smoketest_linear_compile(self):
    for dtype in [torch.bfloat16, torch.float16]:
        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
            self.skipTest("test requires SM capability of at least (8, 0).")
Contributor Author

@weifengpy weifengpy Apr 19, 2024


test_smoketest_linear_compile was effectively always skipped before: the loop exited via self.skipTest on torch.bfloat16 and never got a chance to test torch.float16. This version fixes it by using @parameterize, so each dtype runs as its own test case.
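
For illustration, a hedged sketch of the parametrized shape of the test (using the parameterized package here; the decorator, class name, and test body in the actual PR may differ):

import unittest
import torch
from parameterized import parameterized

class TestNF4Linear(unittest.TestCase):
    # Each dtype is its own test case, so skipping bf16 on pre-SM80 GPUs
    # no longer prevents the fp16 case from running.
    @parameterized.expand([(torch.bfloat16,), (torch.float16,)])
    def test_smoketest_linear_compile(self, dtype):
        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
            self.skipTest("test requires SM capability of at least (8, 0).")
        # ... construct the NF4 linear module and run the torch.compile smoke test for this dtype ...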

if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
    self.skipTest("test requires SM capability of at least (8, 0).")
if version.parse(torch.__version__) < version.parse("2.3.0"):
    self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
Contributor Author

@weifengpy weifengpy Apr 19, 2024


Starting from 2.3.0 we can trace a subclass whose inner tensors have different shapes than the outer wrapper class. Specifically, tracing uses symbolic_context.inner_contexts instead of the symbolic_context of the outer wrapper class: https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/meta_utils.py#L649

@weifengpy weifengpy marked this pull request as ready for review April 19, 2024 01:53
@cpuhrsch cpuhrsch merged commit a7ff835 into pytorch:main Apr 19, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024