Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switched from __dataclass_transform__() to typing.dataclass_transform() #1158

Merged
merged 7 commits into from
Jul 6, 2023
Merged
3 changes: 3 additions & 0 deletions changelog.d/1158.change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Type stubs now use `typing.dataclass_transform` to decorate dataclass-like
decorators, instead of the non-standard `__dataclass_transform__` special
form, which is only supported by pyright.
19 changes: 2 additions & 17 deletions docs/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,16 @@ You can only use this trick to tell *Mypy* that a class is actually an *attrs* c

### Pyright

Generic decorator wrapping is supported in [*Pyright*](https:/microsoft/pyright) via `dataclass_transform` / {pep}`689`.
Generic decorator wrapping is supported in [*Pyright*](https:/microsoft/pyright) via `typing.dataclass_transform` / {pep}`689`.

For a custom wrapping of the form:

```
@typing.dataclass_transform(field_specifiers=(attr.attrib, attr.field))
def custom_define(f):
return attr.define(f)
```

This is implemented via a `__dataclass_transform__` type decorator in the custom extension's `.pyi` of the form:

```
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...

@__dataclass_transform__(field_descriptors=(attr.attrib, attr.field))
def custom_define(f): ...
```


## Types

*attrs* offers two ways of attaching type information to attributes:
Expand Down
38 changes: 11 additions & 27 deletions src/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ if sys.version_info >= (3, 10):
else:
from typing_extensions import TypeGuard

if sys.version_info >= (3, 11):
from typing import dataclass_transform
else:
from typing_extensions import dataclass_transform

__version__: str
__version_info__: VersionInfo
__title__: str
Expand Down Expand Up @@ -103,23 +108,6 @@ else:
takes_self: bool = ...,
) -> _T: ...

# Static type inference support via __dataclass_transform__ implemented as per:
# https:/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
frozen_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...

class Attribute(Generic[_T]):
name: str
default: Optional[_T]
Expand Down Expand Up @@ -322,7 +310,7 @@ def field(
type: Optional[type] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
@dataclass_transform(order_default=True, field_specifiers=(attrib, field))
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
Expand Down Expand Up @@ -350,7 +338,7 @@ def attrs(
unsafe_hash: Optional[bool] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
@dataclass_transform(order_default=True, field_specifiers=(attrib, field))
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
Expand Down Expand Up @@ -378,7 +366,7 @@ def attrs(
unsafe_hash: Optional[bool] = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
@dataclass_transform(field_specifiers=(attrib, field))
def define(
maybe_cls: _C,
*,
Expand All @@ -404,7 +392,7 @@ def define(
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
@dataclass_transform(field_specifiers=(attrib, field))
def define(
maybe_cls: None = ...,
*,
Expand Down Expand Up @@ -433,9 +421,7 @@ def define(
mutable = define

@overload
@__dataclass_transform__(
frozen_default=True, field_descriptors=(attrib, field)
)
@dataclass_transform(frozen_default=True, field_specifiers=(attrib, field))
def frozen(
maybe_cls: _C,
*,
Expand All @@ -461,9 +447,7 @@ def frozen(
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(
frozen_default=True, field_descriptors=(attrib, field)
)
@dataclass_transform(frozen_default=True, field_specifiers=(attrib, field))
def frozen(
maybe_cls: None = ...,
*,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def parse_pyright_output(test_file: Path) -> set[PyrightDiagnostic]:

def test_pyright_baseline():
"""
The __dataclass_transform__ decorator allows pyright to determine attrs
decorated class types.
The typing.dataclass_transform decorator allows pyright to determine
attrs decorated class types.
"""

test_file = Path(__file__).parent / "dataclass_transform_example.py"
Expand Down