
Remove float32 dtype assumption #1803

Closed
wants to merge 8 commits from the default_dtypes branch

Conversation

NeilGirdhar
Contributor

@NeilGirdhar NeilGirdhar commented Jan 24, 2022

Fixes #1777

Checklist

  • This change is discussed in a GitHub issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@NeilGirdhar NeilGirdhar changed the title Remove dtype float32 dtype assumption Remove float32 dtype assumption Jan 24, 2022
flax/linen/module.py (outdated review thread, resolved)
@NeilGirdhar NeilGirdhar force-pushed the default_dtypes branch 2 times, most recently from 1acac78 to 9de2d91 Compare January 25, 2022 15:54

@NeilGirdhar NeilGirdhar force-pushed the default_dtypes branch 12 times, most recently from 6d87556 to 925e849 Compare January 25, 2022 21:01
@PhilipVinc PhilipVinc mentioned this pull request Jan 27, 2022
@NeilGirdhar NeilGirdhar force-pushed the default_dtypes branch 3 times, most recently from ac1d33a to 01fdbf7 Compare January 28, 2022 21:59
@NeilGirdhar
Contributor Author

NeilGirdhar commented Jan 28, 2022

Also, if possible, it would be nice to get #1812 in before this pull request to make type checking easier.

@NeilGirdhar NeilGirdhar force-pushed the default_dtypes branch 8 times, most recently from 4cf86a3 to 23f094d Compare February 7, 2022 04:17
@NeilGirdhar NeilGirdhar force-pushed the default_dtypes branch 2 times, most recently from 66948f6 to 3a0d8d9 Compare February 28, 2022 15:29
* Infer dtypes from inputs where possible.
* LSTM dtype assumption persists; this is repaired in a separate pull
  request.
@PhilipVinc
Contributor

Some internal issues are holding this back?
I'm waiting on this to release a new version of NetKet... do you have a rough timeline for when this will be merged?

@jheek
Member

jheek commented Mar 3, 2022

Some internal issues are holding this back?

Currently this PR is blocked on internal testing. I'm looking into it today, though.

flax/linen/activation.py (outdated review thread, resolved)
flax/linen/activation.py (outdated review thread, resolved)
flax/linen/attention.py (outdated review thread, resolved)
param_dtype: Optional[InexactDType],
computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType, InexactDType]:
returned_param_dtype = input_dtype if param_dtype is None else param_dtype
Member

param_dtype should still be float32 by default (see FLIP)

Contributor Author

Okay, sorry for missing this!

Contributor Author

@NeilGirdhar NeilGirdhar Mar 4, 2022

Do you think it makes more sense to simply make the param_dtype and computation_dtype default to float32 on the modules? That way these helper functions can remain completely type agnostic, and just defer to the input dtype if None is passed for the other types.

Also, a downside to this change is that making a module anything narrower than float32 now requires:

  • changing both dtype and param_dtype, and
  • passing the appropriate inputs.

With the original PR, only the input would have to be made narrow, and everything else would be inferred.
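For concreteness, here is a minimal sketch of the type-agnostic helper described above (the name and signature are illustrative, not this PR's actual code):

import jax.numpy as jnp

def canonicalize_dtypes(input_dtype, param_dtype=None, dtype=None):
  # Defer to the input dtype whenever a dtype is not given explicitly;
  # modules could still pass float32 here as their own default.
  param_dtype = input_dtype if param_dtype is None else param_dtype
  dtype = jnp.result_type(input_dtype, param_dtype) if dtype is None else dtype
  return param_dtype, dtype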


assert jnp.issubdtype(input_dtype, jnp.number)
if jnp.issubdtype(input_dtype, jnp.complexfloating):
  assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating)
Member

complex numbers can still be projected by a real transformation

Contributor Author

@NeilGirdhar NeilGirdhar Mar 4, 2022

Are you sure that differentiation with respect to such parameters will work in that case? Consider:

from jax import grad
from functools import partial
import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.complex64)

def f(x, w):
    return jnp.sum(w @ x)

grad(partial(f, x), holomorphic=True)(jnp.eye(3, dtype=jnp.complex64))  # Okay.
grad(partial(f, x), holomorphic=True)(jnp.eye(3))  # TypeError: grad with holomorphic=True requires inputs with complex dtype, but got float32.

Anyway, I've removed the assertions. I guess it's the user's problem if she tries to differentiate heterogeneous types.

assert jnp.issubdtype(input_dtype, jnp.number)
if jnp.issubdtype(input_dtype, jnp.complexfloating):
  assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating)
  assert jnp.issubdtype(dtype, jnp.complexfloating)
Member

casting back to real is currently allowed. We should look into disabling this in JAX instead of doing it inconsistently in this PR, I think.
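For reference, a quick illustration of the current behaviour (whether a warning is emitted depends on the JAX version):

import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j], dtype=jnp.complex64)
y = x.astype(jnp.float32)  # allowed: keeps the real part, drops the imaginary part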

Contributor Author

@NeilGirdhar NeilGirdhar Mar 4, 2022

Oh, good point! For some reason I thought it was a narrowing error.

I agree that we shouldn't try to fix this here. Is there a JAX or NumPy issue somewhere that you know of? Maybe we should at least create one?

flax/linen/linear.py (outdated review thread, resolved)
return returned_param_dtype, dtype


def _canonicalize_numeric_dtypes(
Member

I think it would be nice to expose and document these APIs publicly so people can implement their own layers according to the spec.

I think you can have a canonicalize and canonicalize_dtype_inexact.

Also, I think you should use the true dtypes of the params and not the param_dtype, because users can cast params after construction.

Contributor Author

@NeilGirdhar NeilGirdhar Mar 4, 2022

I think it would be nice to expose and document these APIs publicly so people can implement their own layers according to the spec.

Yes, great point.

Did you want them in a new file? I'll expose them here for now, but let me know if you want them somewhere else.

Also I think you should use the true dtypes of the params

Sorry, I'm not sure what you want me to do here? The param_dtype is an attribute of the module, which is supposed to be used to create the parameters. How can I get the "true dtype of the params" before I create them?

Member

I would either make a dtypes.py or put it in module.py

Sorry, I'm not sure what you want me to do here?

You can pass the parameter values into jnp.result_type(*inputs, *params)

Contributor Author

@NeilGirdhar NeilGirdhar Mar 7, 2022

I would either make a dtypes.py or put it in module.py

Sure, I added dtypes.py since module.py is getting large.

You can pass the parameter values into jnp.result_type(*inputs, *params)

Sorry, I still don't understand what you mean. The parameters don't exist at the point that canonicalize is called. The parameter creation depends on the output of canonicalize.

Member

the canonicalize_dtype should only infer the dtype if it is None. param_dtype defaults to float32 and cannot be inferred (so it's not Optional[Dtype] but Dtype). This way you can init the params and then use the params + inputs to infer the dtype
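A rough sketch of that flow (illustrative only, not this PR's code): create the params at param_dtype, then infer the computation dtype from the actual inputs and params whenever dtype is None.

import jax.numpy as jnp

def infer_dtype(inputs, params, dtype=None):
  # dtype is only inferred when it is None; param_dtype itself stays
  # float32 by default and is never inferred.
  if dtype is None:
    dtype = jnp.result_type(*inputs, *params)
  return dtype

# Params created at float32 with a complex64 input -> complex64 computation.
kernel = jnp.zeros((4, 2), jnp.float32)
x = jnp.ones((4,), jnp.complex64)
assert infer_dtype([x], [kernel]) == jnp.complex64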

Contributor Author

@NeilGirdhar NeilGirdhar Mar 8, 2022

the canonicalize_dtype should only infer the dtype if it is None

Right, that's what it does.

This way you can init the params and then use the params + inputs to infer the dtype

Sorry, I'm really trying to understand you here, but I still don't see it. What is the difference between

  • passing the parameter dtype to the canonicalize function, versus
  • initializing the parameters using the parameter dtype, and passing the parameters to the canonicalize function?

Won't the exact same type inference happen either way?

param_dtype defaults to float32 and cannot be inferred (so it's not Optional[Dtype] but Dtype).

We can do that, but I don't see how that's an advantage. It just makes the canonicalize function less flexible without changing the behaviour when the parameter dtype is provided.

We can make the parameter dtype required on the modules though.

Member

The param_dtype argument is not the same as the parameter dtype. In 99% of the cases users will stick to the default float32. For example, during eval the two can diverge because a user does something like:

eval_params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)

The reasoning behind this common trick is that small updates during training don't accumulate well in half precision but during eval the weights are static and using half precision is a free gain in most cases.

The core idea behind canonicalize dtype is to reproduce what you normally expect from an equivalent numpy function. So imagine that we wrote nn.Dense as a pure numpy function:

def dense(input, kernel, bias, dtype=None):
  dtype = dtype or jnp.result_type(input, kernel, bias)
  # cast everything to the (possibly inferred) computation dtype
  input, kernel, bias = (jnp.asarray(x, dtype) for x in (input, kernel, bias))
  return input @ kernel + bias

Here the user has to provide kernel and bias so the dtypes of the params are determined by the user and don't depend on input.dtype. The default dtype is just np.result_type(input, kernel, bias).
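For instance, calling the dense sketch above with a half-precision input and f32 params:

x = jnp.ones((4,), dtype=jnp.bfloat16)
kernel = jnp.ones((4, 2), dtype=jnp.float32)
bias = jnp.zeros((2,), dtype=jnp.float32)
y = dense(x, kernel, bias)  # result_type(bfloat16, float32, float32) -> float32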

The reason we allow this dtype to be inferred is that it preserves the precision of the computation.
When inferring param dtypes, however, you are inferring behavior: a complex linear mapping is something different from a real linear mapping, even if you have a complex input. Similarly, a learned half-precision dense layer does something very different in practice from an f32 dense layer, even if the inputs are in half precision.

A second argument against a None default for param_dtype is that it would cause a big backwards-incompatible change, because many users rely on f32 defaults for half-precision inputs.

Contributor Author

@NeilGirdhar NeilGirdhar Mar 9, 2022

The reasoning behind this common trick is that small updates during training don't accumulate well in half precision but during eval the weights are static and using half precision is a free gain in most cases.

Okay! I see what you're saying now.

I've started this work here, but it gets tricky when there are submodules (how do I take into account the final DenseGeneral that depends on the first set of parameters, and itself has other parameters?). I'm not sure that this approach will work in general even if it looks okay for the simple modules. I'd have to canonicalize multiple times and produce various intermediate dtypes. And it gets even more complicated with the recursive modules that have a carry. This would be a mess.

Honestly, I think recasting the parameters that were created by the modules is a hack. You're essentially bypassing the public interface and accessing the parameters, which are akin to private variables. And it's also bad because it only works when you have the dtype attribute set to None. If dtype is set, then changing the parameter dtypes has no effect on the computation.

I think a much nicer solution for people who want to do "evaluation in half precision" is to reconstruct the modules with the dtype and param_dtype attributes set to half precision, and then transform the parameters through a public interface.

A complex linear mapping is something different than a linear real mapping even if you have a complex input.

Yes that's true, but (unless very special care is taken) the parameter cotangents will still be complex, so I imagine that in most cases, complex inputs imply complex parameters.

A second argument against a None default for param_dtype is that it would cause a big backwards incompatible change because many users really on f32 defaults for half precision inputs.

Right, that's why I switched all the defaults to float32. This way, this PR just provides the ability to specify computation and parameter dtype. It doesn't change behavior.

However, I still think you should deprecate this default in another PR. As you say: "The core idea behind canonicalize dtype is to reproduce what you normally expect from an equivalent numpy function." And the equivalent numpy function always produces outputs based on its input types. It doesn't silently widen everything to float32.

Also, if someone does want half-precision (or complex, or double-precision) computation throughout their network, they will need to select it for every module. I read the FLIP, but I think things can change in the future. Double precision is already starting to be as performant as single precision on some GPUs. It would be better not to bake in defaults that you can't easily remove.

@jheek
Member

jheek commented Mar 4, 2022

@NeilGirdhar I finally found the time to look at the changes. Please have a look.

@NeilGirdhar
Contributor Author

@jheek Thanks for the thorough review. It shouldn't take long to make these changes, but I'd like to wait on @avital to get back to me before I do that. I emailed him a couple weeks ago, and haven't heard anything. Just keeping you in the loop!

@marcvanzee
Collaborator

@jheek Thanks for the thorough review. It shouldn't take long to make these changes, but I'd like to wait on @avital to get back to me before I do that. I emailed him a couple weeks ago, and haven't heard anything. Just keeping you in the loop!

Just a FYI that Avital is OOO for a few weeks, so I suppose he hasn't been checking his emails.

@NeilGirdhar
Contributor Author

Just a FYI that Avital is OOO for a few weeks, so I suppose he hasn't been checking his emails.

Ah okay! No problem, I wasn't sure. I'll try to get this done today then 😄

@NeilGirdhar
Contributor Author

@jheek First pass of the review is done. Please let me know about the above questions when you find more time 😄

@NeilGirdhar
Contributor Author

NeilGirdhar commented Mar 10, 2022

I've been thinking about this some more, and I think a big source of our back-and-forth was the confusion that

kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype)

can have a dtype other than param_dtype.

While that is true, I think this should be prevented.

The parameters and variables are created by module code that the user does not have access to. And as such the user should not be counting on any types or shapes of such parameters. All modules should be free to, in some future version of Flax, create structured parameter types like

def kernel_init(rng, shape, dtype) -> KernelParameter: ...

where KernelParameter is an arbitrary dataclass. Or they can widen or narrow these parameters as they see fit.

At heart, the problem is that there are a variety of workflows that Flax protects despite them not being crystallized in public interfaces. In our above discussion, it's the ability to change the behavior of a module by recasting the parameters and variables it has created. This workflow might be a useful trick, but I think it should be exposed in a public interface.

And are you even testing these workflows? I don't remember seeing a test where you initialize variables with a module M, transform them to have different dtypes, and then push them through apply on that original module M to verify that it still does the right thing. Allowing this workflow multiplies the amount of testing you need to do tremendously because there are all kinds of changes that a user could make to the variables.

I propose the following:

  • Add an unchecked invariant that the pytree created by Module.init and that is sent to Module.apply is not transformed by the user except through a public interface. The FrozenVariableDict should be as opaque as possible.
  • Add a public interface for transforming variables (that accomplishes the retyping among other things):
T = TypeVar('T')

class Module:
  @traceback_util.api_boundary
  def transform_variables(self,
                          old_variables: FrozenVariableDict,
                          *args,
                          f: Callable[[T, T], T],
                          method: Optional[Callable[..., Any]] = None,
                          mutable: CollectionFilter = DenyList("intermediates"),
                          **kwargs: Any) -> FrozenVariableDict:
    """Transforms the variables created by a module method and returns modified
    variables.
    """

This just runs initialization as usual, but after creating a variable new_variable, it calls f(old_variable, new_variable) to produce the new variables. For arrays, this would just be old_variable.astype(new_variable.dtype), but in general it could do all kinds of transformations.

The workflow described above would then be

my_module = SomeModule(...)
variables = my_module.init(...)
# train variables using my_module...
new_module = SomeModule(...)  # with different types
new_variables = new_module.transform_variables(variables, ...)
new_module.apply(new_variables, ...)  # guaranteed to work

In short, init initializes variables, apply uses them, and transform_variables transforms them.

What do you think?

@jheek
Member

jheek commented Mar 14, 2022

And as such the user should not be counting on any types or shapes of such parameters.

The variable collections are far from opaque. You make assumptions about their structure when doing transformations, taking gradients, when optimizing, etc. At a minimum they must be PyTrees of JAX arrays.

This approach is much more flexible than you would think at first glance, though. For example, you can "box" a param in an arbitrary dataclass and add as much metadata as you want to parameters. Still, once you do a jax.tree_map, another part of the code can transform it and skip over the internal metadata without a problem.
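A hedged sketch of that boxing idea (illustrative names, not Flax's actual API): register a small dataclass as a pytree so the metadata travels as static aux data and jax.tree_map only touches the array leaf.

from dataclasses import dataclass

import jax
import jax.numpy as jnp

@dataclass
class Boxed:
  value: jnp.ndarray
  note: str  # arbitrary metadata

jax.tree_util.register_pytree_node(
    Boxed,
    lambda b: ((b.value,), b.note),                 # children, aux data
    lambda note, children: Boxed(children[0], note))

params = {'kernel': Boxed(jnp.ones((2, 2)), 'created at f32')}
half = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)  # metadata preserved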

@NeilGirdhar
Contributor Author

NeilGirdhar commented Mar 22, 2022

@jheek

Sorry, I've been meaning to get back to you on this, but I've had a lot of work to do.

The variable collections are far from opaque. You make assumptions about its structure when doing transformations, taking gradients, when optimizing, etc. At the minimum they must be PyTree's of JAX arrays.

Yes, you're right that you can treat the variables as a PyTree. What I was trying to get at by making it as opaque as possible is to make it so that users of a module X should not make assumptions about the objects that are placed into the variables tree by X. Otherwise, the user code will break when X changes. This ensures the separation of concerns.

Users of a module X can configure its functionality by adjusting its dataclass fields. What you're proposing is a second way of changing its functionality by changing the data types of the contents of the variables.

I understand that people are already doing this, but it is absolutely horrible. It's only one extra line of code to rebuild an appropriate module with the configuration you want, and then use that. Trying to make all of the modules robust to changes in dataclass types is going to be way too much work for your team. Besides actually making this work everywhere, you will need to test this (and there are no tests).

It's also not easy to make canonicalize do this. You suggested that you can just put the array objects from the variables into the type inference. That does work for simple modules like Dense. It does not work for complex modules like the recursive network modules. These modules have submodules that depend on the results of earlier computations (which need to know the dtype), and also on other variables. You'd have to collect all of the variables of all of the submodules, and there's no easy way to do that since there's no interface for it. So, even if you wanted to do it this way, you can't do it for complex modules.

Therefore, I think this is bad design. You should infer the computation dtype based on the module fields alone. If people want to cast the parameters, then they'll also have to reconstruct the module. Anyone who was relying on this undocumented behaviour will have to slightly modify their code, unfortunately.

What do you think?

@PhilipVinc
Contributor

I hate to bump this up again but... what's the status?

@NeilGirdhar
Contributor Author

NeilGirdhar commented Apr 12, 2022

@PhilipVinc I was still waiting on a reply to my last comment. Unfortunately, I've decided not to spend any more time working on any flax pull requests. You're welcome to lift my code into a pull request of your own if you like.

@jheek
Member

jheek commented Apr 14, 2022

@PhilipVinc @NeilGirdhar I'm taking over the implementation of the default dtype FLIP

@andsteing andsteing added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Apr 26, 2022
@NeilGirdhar NeilGirdhar closed this Jun 1, 2022
@NeilGirdhar NeilGirdhar deleted the default_dtypes branch February 21, 2024 16:34