diff --git a/CHANGES/4736.bugfix b/CHANGES/4736.bugfix new file mode 100644 index 00000000000..8c562571d6b --- /dev/null +++ b/CHANGES/4736.bugfix @@ -0,0 +1,2 @@ +Improve typing annotations for ``web.Request``, ``aiohttp.ClientResponse`` and +``multipart`` module. diff --git a/aiohttp/client.py b/aiohttp/client.py index 27ffb261555..7778d2cb877 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -522,25 +522,25 @@ async def _request( resp.release() try: - r_url = URL( + parsed_url = URL( r_url, encoded=not self._requote_redirect_url) except ValueError: raise InvalidURL(r_url) - scheme = r_url.scheme + scheme = parsed_url.scheme if scheme not in ('http', 'https', ''): resp.close() raise ValueError( 'Can redirect only to http or https') elif not scheme: - r_url = url.join(r_url) + parsed_url = url.join(parsed_url) - if url.origin() != r_url.origin(): + if url.origin() != parsed_url.origin(): auth = None headers.pop(hdrs.AUTHORIZATION, None) - url = r_url + url = parsed_url params = None resp.release() continue @@ -737,10 +737,10 @@ async def _ws_connect( headers=resp.headers) # key calculation - key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') + r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') match = base64.b64encode( hashlib.sha1(sec_key + WS_KEY).digest()).decode() - if key != match: + if r_key != match: raise WSServerHandshakeError( resp.request_info, resp.history, @@ -780,15 +780,16 @@ async def _ws_connect( conn = resp.connection assert conn is not None - proto = conn.protocol - assert proto is not None + conn_proto = conn.protocol + assert conn_proto is not None transport = conn.transport assert transport is not None reader = FlowControlDataQueue( - proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa - proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + conn_proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa + conn_proto.set_parser( + WebSocketReader(reader, max_msg_size), reader) writer = WebSocketWriter( - proto, transport, use_mask=True, + conn_proto, transport, use_mask=True, compress=compress, notakeover=notakeover) except BaseException: resp.close() diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 4e17af53393..162993aec21 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -3,7 +3,7 @@ import asyncio from typing import TYPE_CHECKING, Any, Optional, Tuple, Union -from .typedefs import _CIMultiDict +from .typedefs import LooseHeaders try: import ssl @@ -22,7 +22,6 @@ else: RequestInfo = ClientResponse = ConnectionKey = None - __all__ = ( 'ClientError', @@ -55,7 +54,7 @@ def __init__(self, request_info: RequestInfo, history: Tuple[ClientResponse, ...], *, status: Optional[int]=None, message: str='', - headers: Optional[_CIMultiDict]=None) -> None: + headers: Optional[LooseHeaders]=None) -> None: self.request_info = request_info if status is not None: self.status = status diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 74f05976215..b0829b6f59c 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -25,6 +25,7 @@ Callable, Dict, Generator, + Generic, Iterable, Iterator, List, @@ -66,6 +67,11 @@ except ImportError: from typing_extensions import ContextManager +if PY_38: + from typing import Protocol +else: + from typing_extensions import Protocol # type: ignore + def all_tasks( loop: Optional[asyncio.AbstractEventLoop] = None @@ -79,6 +85,7 @@ def all_tasks( _T = TypeVar('_T') +_S = TypeVar('_S') sentinel = object() # type: Any @@ -382,7 +389,11 @@ def is_expected_content_type(response_content_type: str, return expected_content_type in response_content_type -class reify: +class _TSelf(Protocol): + _cache: Dict[str, Any] + + +class reify(Generic[_T]): """Use as a class method decorator. It operates almost exactly like the Python `@property` decorator, but it puts the result of the method it decorates into the instance dict after the first call, @@ -391,12 +402,12 @@ class reify: """ - def __init__(self, wrapped: Callable[..., Any]) -> None: + def __init__(self, wrapped: Callable[..., _T]) -> None: self.wrapped = wrapped self.__doc__ = wrapped.__doc__ self.name = wrapped.__name__ - def __get__(self, inst: Any, owner: Any) -> Any: + def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T: try: try: return inst._cache[self.name] @@ -409,7 +420,7 @@ def __get__(self, inst: Any, owner: Any) -> Any: return self raise - def __set__(self, inst: Any, value: Any) -> None: + def __set__(self, inst: _TSelf, value: _T) -> None: raise AttributeError("reified property is read-only") diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index ab11938f5a7..6b845a10b1e 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -152,7 +152,8 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None: _WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?') -def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]: +def ws_ext_parse(extstr: Optional[str], + isserver: bool=False) -> Tuple[int, bool]: if not extstr: return 0, False diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index aeee338b1ee..777ad314a53 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -39,6 +39,7 @@ set_result, ) from .http_parser import RawRequestMessage +from .http_writer import HttpVersion from .multipart import BodyPartReader, MultipartReader from .streams import EmptyStreamReader, StreamReader from .typedefs import ( @@ -343,7 +344,7 @@ def method(self) -> str: return self._method @reify - def version(self) -> Tuple[int, int]: + def version(self) -> HttpVersion: """Read only property for getting HTTP version of request. Returns aiohttp.protocol.HttpVersion instance. @@ -434,7 +435,7 @@ def raw_headers(self) -> RawHeaders: return self._message.raw_headers @staticmethod - def _http_date(_date_str: str) -> Optional[datetime.datetime]: + def _http_date(_date_str: Optional[str]) -> Optional[datetime.datetime]: """Process a date string, return a datetime object """ if _date_str is not None: @@ -618,6 +619,7 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]': field_ct = field.headers.get(hdrs.CONTENT_TYPE) if isinstance(field, BodyPartReader): + assert field.name is not None if field.filename and field_ct: # store file in temp file tmp = tempfile.TemporaryFile() diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 7712ee3c6d5..c4d7ed9ca14 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -264,7 +264,7 @@ async def _default_expect_handler(request: Request) -> None: Just send "100 Continue" to client. raise HTTPExpectationFailed if value of header is not "100-continue" """ - expect = request.headers.get(hdrs.EXPECT) + expect = request.headers.get(hdrs.EXPECT, "") if request.version == HttpVersion11: if expect.lower() == "100-continue": await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") @@ -749,7 +749,9 @@ def validation(self, domain: str) -> str: async def match(self, request: Request) -> bool: host = request.headers.get(hdrs.HOST) - return host and self.match_domain(host) + if not host: + return False + return self.match_domain(host) def match_domain(self, host: str) -> bool: return host.lower() == self._domain