
Reconsider sum/prod/trace upcasting for floating-point dtypes #731

Closed
rgommers opened this issue Jan 17, 2024 · 16 comments · Fixed by #744
Labels
topic: Type Promotion

Comments

@rgommers
Member

The requirement to upcast sum(x) to the default floating-point dtype when dtype=None (the default) currently reads (from the sum spec):

If x has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.

The rationale given is that the "keyword argument is intended to help prevent data type overflows". This came up again in the review of NEP 56 (numpy/numpy#25542), and it is basically the only part of the standard that was flagged as problematic and explicitly rejected.

I agree that the standard's choice here is problematic, at least from a practical perspective: no array library does this, and none are planning to implement it. The rationale is also pretty weak; it does not apply to floating-point dtypes to the same extent that it does to integer dtypes (and for integers, array libraries do implement the upcasting). Examples:

>>> # NumPy:
>>> np.sum(np.ones(3, dtype=np.float32)).dtype
dtype('float32')
>>> np.sum(np.ones(3, dtype=np.int32)).dtype
dtype('int64')

>>> # PyTorch:
>>> torch.sum(torch.ones(2, dtype=torch.bfloat16)).dtype
torch.bfloat16
>>> torch.sum(torch.ones(2, dtype=torch.int16)).dtype
torch.int64

>>> # JAX:
>>> jnp.sum(jnp.ones(4, dtype=jnp.float16)).dtype
dtype('float16')
>>> jnp.sum(jnp.ones(4, dtype=jnp.int16)).dtype
dtype('int32')

>>> # CuPy:
>>> cp.sum(cp.ones(5, dtype=cp.float16)).dtype
dtype('float16')
>>> cp.sum(cp.ones(5, dtype=cp.int32)).dtype
dtype('int64')

>>> # Dask:
>>> da.sum(da.ones(6, dtype=np.float32)).dtype
dtype('float32')
>>> da.sum(da.ones(6, dtype=np.int32)).dtype
dtype('int64')

The most relevant conversation is #238 (comment). There were some further minor tweaks (without much discussion) in gh-666.

Proposed resolution: align the standard with what all known array libraries implement today.

@seberg
Contributor

seberg commented Jan 17, 2024

As I mentioned a few times before, I agree with not specifying this. Partly because I think it is just asking too much from NumPy (and apparently others).
But even from an Array API perspective I don't think it is helpful, because the "default type" is effectively just "unspecified" as well (if you sum a float32 array, you don't know whether you'll get a float32 or a float64).
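As a quick illustration of that point (NumPy defaults to float64, while JAX defaults to float32 unless x64 mode is enabled): the "default real-valued floating-point data type" the old wording refers to is itself library-dependent, so under that rule the result dtype of summing a float32 array could legitimately be either float32 or float64 depending on the library:

>>> np.asarray(1.0).dtype    # NumPy's default float dtype
dtype('float64')
>>> jnp.asarray(1.0).dtype   # JAX's default float dtype (x64 disabled)
dtype('float32')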

@rgommers
Member Author

Partly because I think it is just asking too much from NumPy (and apparently others).

Since all libraries appear to do exactly the same thing as of today, what's the problem with encoding that? Nothing is being asked of anyone at that point; it's basically just documenting the status quo.

@seberg
Contributor

seberg commented Jan 17, 2024

I wouldn't have been surprised if someone upcast for float16, but if not, then fine. Mainly, I am not sure I would mind the old proposal if we were starting from scratch, so I don't have an opinion about allowing it (i.e., not caring that the result may have higher precision).

@mhvk

mhvk commented Jan 17, 2024

My 2¢ is that it is good to codify the current behaviour for the various float dtypes. It is really surprising if the dtype of a reduction depends on the operation.

p.s. Indeed, I think this is true even for integers. At least, to me, the following is neither logical nor expected:

In [17]: np.subtract.reduce(np.arange(4, dtype='i2')).dtype
Out[17]: dtype('int16')

In [18]: np.add.reduce(np.arange(4, dtype='i2')).dtype
Out[18]: dtype('int64')

Explicit is better than implicit and all that. And for reductions, it might be quite reasonable to do the operation at higher precision and check for overflow before downcasting.

@asmeurer
Member

And for reductions, it might be quite reasonable to do the operation at higher precision and check for overflow before downcasting.

If you're suggesting a value-based result type, that's even worse. That's the sort of thing we're trying to get away from with the standard.

@asmeurer
Member

The rationale given is that the "keyword argument is intended to help prevent data type overflows". This came up again in the review of NEP 56 (numpy/numpy#25542), and it is basically the only part of the standard that was flagged as problematic and explicitly rejected.

That PR discussion is huge and you didn't point to a specific comment, so I don't know what was already said. But it makes sense to me to treat floats differently from ints, because floats give inf when they overflow, which is a very clear indication to the user that they need to manually upcast.
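For instance (a minimal NumPy illustration; the exact scalar repr varies by NumPy version, and an overflow RuntimeWarning may also be emitted):

>>> x = np.full(4, 30000, dtype=np.float16)  # true sum is 120000, beyond float16's max of ~65504
>>> np.sum(x)                                # overflows to inf in the float16 result, which is hard to miss
inf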

@rgommers
Member Author

That PR discussion is huge and you didn't point to a specific comment, so I don't know what was already said

There are several comments on it. The main one is this comment. It also got mixed in with the comment on in-place operator behavior in this comment. And in this comment @seberg said "(I explicitly did not give a thumbs-up for the type promotion changes in that meeting)" (type promotion meaning the sum/prod ones).

I did write it down as one requirement among many (I didn't quite agree with what I wrote myself, but forgot to revisit it), so it didn't stand out in the text. It's telling that it was flagged quickly by both @seberg and @mhvk as too impactful to change.

And for reductions, it might be quite reasonable to do the operation at higher precision and check for overflow before downcasting.

If you're suggesting a value-based result type, that's even worse. That's the sort of thing we're trying to get away from with the standard.

Internal upcasting is regularly done, and perfectly fine. I assume the intent was "warn or raise on integer overflow", rather than value-based casting.

@mhvk

mhvk commented Jan 17, 2024

If you're suggesting a value-based result type, that's even worse. That's the sort of thing we're trying to get away from with the standard.

No, not a different type; that would indeed be awful! But an over/underflow error/warning, just like we can currently get for floating-point operations. For regular ufuncs, that is too much of a performance hit, but for reductions it should not be. And reductions are a bit special already, since it definitely makes sense to do things at higher precision internally before casting back to the original precision.
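A rough sketch of that idea, outside of any particular library (the helper name checked_sum_int16 is made up for illustration): reduce at higher precision, warn if the exact result does not fit the input dtype, then cast back:

import warnings
import numpy as np

def checked_sum_int16(x):
    # hypothetical helper: reduce at higher precision, then cast back to the input dtype
    total = np.add.reduce(x, dtype=np.int64)   # exact for any realistic int16 input
    info = np.iinfo(x.dtype)
    if total < info.min or total > info.max:
        warnings.warn("integer overflow in sum", RuntimeWarning)
    return total.astype(x.dtype)               # same dtype as the input (wraps on overflow)

# checked_sum_int16(np.full(5, 30000, dtype=np.int16)) warns, since the exact sum 150000
# does not fit in int16, and then returns the wrapped int16 value.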

@asmeurer
Member

p.s. Indeed, I think this is true even for integers. At least, to me, the following is neither logical nor expected:

ufunc.reduce is not part of the standard, so it's not really relevant here, but FWIW, I agree with you that it's quite surprising for ufunc.reduce to not return the same dtype as the ufunc itself in some cases. I think of ufunc methods as being somewhat "low-level" things that shouldn't try to be overly smart, at least in terms of behavior (ufunc-specific performance optimizations are another thing).

OTOH, sum is a distinct function from add and is more of an end-user function, so I don't know if the argument applies there.

@leofang
Contributor

leofang commented Jan 23, 2024

Proposed resolution: align the standard with what all known array libraries implement today.

@rgommers What would be the new wording that you'd like to change it to?

@rgommers
Member Author

The current wording that is problematic is:

  • if x has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
  • if x has a complex floating-point data type, the returned array must have the default complex floating-point data type.

I suggest a change like this:

  • if x has a real-valued or complex floating-point data type, the returned array should have either the same dtype as x (recommended) or a higher-precision dtype of the same kind as the dtype of x

This loosens the spec, recommends what the current behavior of all known libraries is, and still allows upcasting if an implementation desires to do so for (reasons).
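To make concrete what that wording would and would not permit (xp here is just a stand-in for any conforming array namespace, not a prescription of actual library behavior):

x = xp.ones(10, dtype=xp.float32)
y = xp.sum(x)
# conforming (recommended): y.dtype == xp.float32, the same dtype as x
# also conforming:          y.dtype == xp.float64, a higher-precision dtype of the same kind
# not conforming:           an integer or complex dtype, or a lower-precision floating-point dtype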

@seberg
Contributor

seberg commented Jan 25, 2024

the returned array should have either the same dtype as x (recommended) or a higher-precision dtype of the same kind as the dtype of x

Thanks, looks good to me. Maybe it would be slightly clearer to replace the "or ..." with a new sentence: "If it is not the same dtype, it must be a higher-precision ..."?
(because if you also apply the should to the "or ..." part, it would become a must)

EDIT: Or just replace the should with a must; to me it seems to apply to the full construct, so must is correct, and the "(recommended)" already covers the "should" part.

@asmeurer
Member

asmeurer commented Jan 26, 2024

I personally don't see value in hedging with "recommended" or "should" if no one actually does that now and we don't even have a concrete reason for anyone to do so. It feels like our only real rationale is some misunderstanding in the original discussion about int dtypes. Not being precise about dtypes has disadvantages. For instance, it makes it harder to reason about dtypes statically (#728). Everywhere else, the standard uses "must" for output dtypes (correct me if I'm wrong on that).

And I also disagree that upcasting is not a big deal. When you're explicitly using a lower-precision float, silent or unexpected upcasting can have a very real performance impact. Here's an example where fixing an unexpected float64 upcast made some code 5x faster: jaymody/picoGPT#12.
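A rough way to see the effect locally (only a sketch with machine-dependent numbers, not the benchmark from the linked issue; the float64 sum moves twice as many bytes through memory):

import time
import numpy as np

x32 = np.ones(20_000_000, dtype=np.float32)
x64 = x32.astype(np.float64)   # the "unexpected upcast" case

for arr in (x32, x64):
    start = time.perf_counter()
    arr.sum()
    print(arr.dtype, time.perf_counter() - start)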

@seberg
Contributor

seberg commented Jan 27, 2024

I am fine with being strict here and saying it must be the same: it is the only version that I see giving any clarity to libraries supporting multiple implementations (which is always my main emphasis here, as opposed to thinking about the ideal end-user API).

But there must have been some feeling that float16 and float32 lose precision quickly and that users need protecting, which is how this ended up being written down. And I am happy to accept the opinion that it may be a reasonable choice for end-users.
Although I can see the argument that this is really about the intermediate precision of the reduction, an argument that could even be made for integers: so long as you detect the overflow (by aggregating at high/arbitrary precision), forcing users to upcast manually isn't that terrible.

treat floats differently from ints, because floats give inf when they overflow, which is a very clear indication to the user that they need to manually upcast.

N.B.: To clarify, for summation, overflows are actually not the main issue! The issue is extreme loss of precision unless you have a high-precision intermediate (at least float64). If you sum float32(1.) naively, the running total just caps at around 2**24 == 16777216, and you may see a decent amount of loss earlier. A million elements is not that odd in many contexts.
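A quick demonstration of that cap (NumPy's own sum fares better than a naive loop because it uses pairwise summation, and an explicit dtype upcasts the accumulator; the float() wrappers are only there to keep the scalar reprs version-independent):

>>> float(np.float32(2**24) + np.float32(1.0))  # 2**24 + 1 rounds back to 2**24 in float32
16777216.0
>>> x = np.ones(20_000_000, dtype=np.float32)
>>> float(np.sum(x))                            # pairwise summation keeps this accurate
20000000.0
>>> np.sum(x, dtype=np.float64).dtype           # explicit higher-precision accumulation/result
dtype('float64')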

rgommers added a commit to rgommers/array-api that referenced this issue Feb 10, 2024
Closes data-apis/array-api#731. This is a backwards-incompatible change. It seems
justified and necessary because array libraries already behave as described in this
commit, are not planning to change, and the initial rationale for the "upcast float"
requirement wasn't strong. See the discussion in data-apis/array-api#731 for more details.
@rgommers
Member Author

Okay, seems like there is support for "must", and I agree that that is nicer. PR for that: gh-744.

kgryte added a commit that referenced this issue Feb 13, 2024
This commit modifies type promotion behavior in `sum`, `prod`, `cumulative_sum`, and `linalg.trace` when the input array has a floating-point data type. Previously, the specification required that conforming implementations upcast to the default floating-point data type when the input array data type was of lower precision. This commit revises that guidance to require that conforming libraries return an array having the same data type as the input array. This revision stems from feedback from implementing libraries, whose current behavior already matches the changes in this commit and who have little desire to change it. As such, the specification is amended to match this reality.

Closes: #731
PR-URL: 	#744
Co-authored-by: Athan Reines <[email protected]>
Reviewed-by: Athan Reines <[email protected]>
@asmeurer
Member

N.B.: To clarify, for summation, overflows are actually not the main issue! The issue is extreme loss of precision unless you have a high-precision intermediate (at least float64). If you sum float32(1.) naively, the running total just caps at around 2**24 == 16777216, and you may see a decent amount of loss earlier. A million elements is not that odd in many contexts.

This can be solved by using a higher intermediate precision, or by using a smarter summation algorithm. My point is that the only reason you'd need a higher result precision is if there is an overflow.
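For reference, a minimal sketch of one such "smarter" algorithm, compensated (Kahan) summation, keeping both the accumulator and the result in float32 (this is only to illustrate the idea; NumPy's sum actually uses pairwise summation, and a Python-level loop like this is slow):

import numpy as np

def kahan_sum_f32(values):
    """Sum float32 values in float32 while compensating for rounding error."""
    total = np.float32(0.0)
    comp = np.float32(0.0)        # running compensation for lost low-order bits
    for v in np.asarray(values, dtype=np.float32):
        y = v - comp
        t = total + y
        comp = (t - total) - y    # the part of y that was rounded away in (total + y)
        total = t
    return total                  # still float32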
