[FSDP2][1/n] construct NF4Tensor from bf16/fp16/fp32 #118
Conversation
Great! Thank you! Can you try rebasing to see if CI runs green?

rebasing now
torchao/dtypes/nf4tensor.py (outdated)

@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
    return args[0].get_original_weight().to(args[1])
nit: there are a few ways you can call `to` that this isn't robust to, I would imagine
do you mean the 3 ways to call `.to`? I can raise NotImplementedError for the unsupported ones if it helps:

Tensor.to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format)
Tensor.to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format)
Tensor.to(other, non_blocking=False, copy=False)
yeah, I think in theory they can all be supported; it's just that the other args/kwargs are getting dropped as currently implemented
got you. I will see if I can pass the args/kwargs through instead of dropping them
updated the PR to pass args and kwargs through instead of dropping them, and to fall back to dispatch for call patterns that are not implemented
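A minimal sketch of the forwarding approach discussed above, reusing the `implements_torch_function` decorator and `get_original_weight` from the snippet; the registry name is hypothetical and this is not the PR's exact code:

import torch

NF4_TORCH_FUNCTIONS = {}  # hypothetical registry mapping torch functions to handlers

def implements_torch_function(torch_function):
    # Register a handler for a torch function, as in the snippet above.
    def decorator(func):
        NF4_TORCH_FUNCTIONS[torch_function] = func
        return func
    return decorator

@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
    if len(args) > 1 and isinstance(args[1], torch.dtype):
        # Tensor.to(dtype, non_blocking=..., copy=..., memory_format=...):
        # dequantize, then forward the remaining args/kwargs instead of dropping them.
        return args[0].get_original_weight().to(*args[1:], **kwargs)
    # Other overloads (device=..., other tensor) are left to the caller's
    # __torch_function__ to handle, per the discussion above.
    return NotImplemented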
Overall looks good, thanks for pushing this through!
def test_smoketest_linear_compile(self):
    for dtype in [torch.bfloat16, torch.float16]:
        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
            self.skipTest("test requires SM capability of at least (8, 0).")
`test_smoketest_linear_compile` was always skipped before: on pre-SM80 GPUs we exit the whole test with `self.skipTest` on the first loop iteration (`torch.bfloat16`), so `torch.float16` never had a chance to run. It is fixed in this version by using `@parametrize`.
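A minimal sketch of the parametrized version, assuming the `parametrize` helper from `torch.testing._internal.common_utils`; the exact class name and decorator in the PR may differ:

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class TestNF4Linear(TestCase):
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_smoketest_linear_compile(self, dtype):
        # Each dtype is now its own test case, so skipping bfloat16 on
        # pre-SM80 GPUs no longer prevents the float16 case from running.
        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
            self.skipTest("test requires SM capability of at least (8, 0).")
        ...  # test body unchanged

instantiate_parametrized_tests(TestNF4Linear)

if __name__ == "__main__":
    run_tests()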
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
    self.skipTest("test requires SM capability of at least (8, 0).")
if version.parse(torch.__version__) < version.parse("2.3.0"):
    self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
starting from 2.3.0 we can trace a subclass whose inner tensors have different shapes than the outer wrapper class. Specifically, we use `symbolic_context.inner_contexts` instead of the `symbolic_context` from the outer wrapper class: https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/meta_utils.py#L649
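A minimal sketch of what this enables, assuming torchao's `to_nf4` constructor and the `get_original_weight` method from the snippet above; tracing this is expected to work only on torch >= 2.3.0:

import torch
from torchao.dtypes.nf4tensor import to_nf4

# NF4Tensor's inner tensors (quantized data, scalers) have different shapes
# than the outer wrapper, which torch.compile could not trace before 2.3.0.
weight = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))

@torch.compile
def dequantize(w):
    return w.get_original_weight()

out = dequantize(weight)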
Summary:
- support `to(bf16/fp16/fp32)` with `__torch_function__`, as suggested by @cpuhrsch
- `NF4Tensor.dtype` is the dtype from construction. For example, `to_nf4(fp32_tensor)` will return an NF4Tensor with dtype `fp32`
- tested `dtype=fp16/bf16/fp32` in `pytest test/dtypes/test_nf4.py` and with `pytest tests -m integration_test`
- @rohan-varma it brings 2 benefits, e.g. `grad.dtype == param.dtype`
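A minimal sketch of the behavior described above, assuming torchao's `to_nf4`; names and shapes are illustrative:

import torch
from torchao.dtypes.nf4tensor import to_nf4

fp32_tensor = torch.randn(512, 512, dtype=torch.float32)

# Construction: NF4Tensor.dtype reflects the dtype it was built from.
nf4_weight = to_nf4(fp32_tensor)
print(nf4_weight.dtype)  # torch.float32

# .to(dtype) dequantizes back to a regular tensor of the requested dtype.
restored = nf4_weight.to(torch.bfloat16)
print(restored.dtype)  # torch.bfloat16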