Skip to content

Commit

Permalink
illustrate how to propagate type information through various asliases…
Browse files Browse the repository at this point in the history
… to get a correct type error
  • Loading branch information
glyph committed Sep 19, 2024
1 parent d7c16de commit 634edbd
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 121 deletions.
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
python_requires=">=3.7",
install_requires=[
"incremental",
"requests >= 2.1.0",
"hyperlink >= 21.0.0",
"Twisted[tls] >= 22.10.0", # For #11635
"attrs",
Expand Down
2 changes: 2 additions & 0 deletions src/treq/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) The treq Authors.
# See LICENSE for details.
import io
from .cookies import TreqieJar
from http.cookiejar import CookieJar
from typing import Any, Dict, Iterable, List, Mapping, Tuple, Union

Expand Down Expand Up @@ -48,6 +49,7 @@ class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver):
]

_CookiesType = Union[
TreqieJar,
CookieJar,
Mapping[str, str],
]
Expand Down
15 changes: 15 additions & 0 deletions src/treq/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
from __future__ import absolute_import, division, print_function

from typing import Callable, Concatenate, ParamSpec, TypeVar

from twisted.web.client import Agent, HTTPConnectionPool

from treq._types import _URLType
from treq.client import HTTPClient

P = ParamSpec("P")
R = TypeVar("R")


def _like(
method: Callable[Concatenate[HTTPClient, _URLType, P], R]
) -> Callable[
[Callable[Concatenate[_URLType, P], R]], Callable[Concatenate[_URLType, P], R]
]:
return lambda x: x


def head(url, **kwargs):
"""
Expand All @@ -14,6 +28,7 @@ def head(url, **kwargs):
return _client(kwargs).head(url, _stacklevel=4, **kwargs)


@_like(HTTPClient.get)
def get(url, headers=None, **kwargs):
"""
Make a ``GET`` request.
Expand Down
56 changes: 42 additions & 14 deletions src/treq/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# -*- test-case-name: treq.test.test_client -*-
from __future__ import annotations
import io
import mimetypes
import uuid
Expand All @@ -7,20 +9,21 @@
from typing import (
Any,
Callable,
Concatenate,
Iterable,
Iterator,
List,
Mapping,
Optional,
ParamSpec,
Tuple,
TypeVar,
Union,
)
from urllib.parse import quote_plus
from urllib.parse import urlencode as _urlencode

from hyperlink import DecodedURL, EncodedURL
from requests.cookies import merge_cookies
from treq.cookies import scoped_cookie
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IProtocol
from twisted.python.components import proxyForInterface, registerAdapter
Expand Down Expand Up @@ -50,8 +53,14 @@
_URLType,
)
from treq.auth import add_auth
from treq.cookies import scoped_cookie
from treq.response import _Response

from .cookies import TreqieJar

P = ParamSpec("P")
R = TypeVar("R")


class _Nothing:
"""Type of the sentinel `_NOTHING`"""
Expand All @@ -67,22 +76,28 @@ def urlencode(query: _ParamsType, doseq: bool) -> bytes:

def _scoped_cookiejar_from_dict(
url_object: EncodedURL, cookie_dict: Optional[Mapping[str, str]]
) -> CookieJar:
) -> TreqieJar:
"""
Create a CookieJar from a dictionary whose cookies are all scoped to the
given URL's origin.
@note: This does not scope the cookies to any particular path, only the
host, port, and scheme of the given URL.
"""
cookie_jar = CookieJar()
cookie_jar = TreqieJar()
if cookie_dict is None:
return cookie_jar
for k, v in cookie_dict.items():
cookie_jar.set_cookie(scoped_cookie(url_object, k, v))
return cookie_jar


def _merge_cookies(left: TreqieJar, right: CookieJar) -> TreqieJar:
for cookie in right:
left.set_cookie(cookie)
return left


class _BodyBufferingProtocol(proxyForInterface(IProtocol)): # type: ignore
def __init__(self, original, buffer, finished):
self.original = original
Expand Down Expand Up @@ -130,26 +145,29 @@ def deliverBody(self, protocol):
self._waiters.append(protocol)


P2 = ParamSpec("P2")


def _like(c: Callable[Concatenate[HTTPClient, str, _URLType, P], R]) -> Callable[
[Callable[Concatenate[HTTPClient, _URLType, P], R]],
Callable[Concatenate[HTTPClient, _URLType, P], R],
]:
return lambda x: x


class HTTPClient:
def __init__(
self,
agent: IAgent,
cookiejar: Optional[CookieJar] = None,
cookiejar: Optional[TreqieJar] = None,
data_to_body_producer: Callable[[Any], IBodyProducer] = IBodyProducer,
) -> None:
self._agent = agent
if cookiejar is None:
cookiejar = CookieJar()
cookiejar = TreqieJar()
self._cookiejar = cookiejar
self._data_to_body_producer = data_to_body_producer

def get(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]":
"""
See :func:`treq.get()`.
"""
kwargs.setdefault("_stacklevel", 3)
return self.request("GET", url, **kwargs)

def put(
self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any
) -> "Deferred[_Response]":
Expand Down Expand Up @@ -246,7 +264,7 @@ def request(
if not isinstance(cookies, CookieJar):
cookies = _scoped_cookiejar_from_dict(parsed_url, cookies)

merge_cookies(self._cookiejar, cookies)
_merge_cookies(self._cookiejar, cookies)
wrapped_agent: IAgent = CookieAgent(self._agent, self._cookiejar)

if allow_redirects:
Expand Down Expand Up @@ -283,6 +301,16 @@ def gotResult(result):

return d.addCallback(_Response, self._cookiejar)

@_like(request)
def get(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]":
"""
See :func:`treq.get()`.
"""
kwargs.setdefault("_stacklevel", 3)
return self.request("GET", url, **kwargs)

reveal_type(get)

def _request_headers(
self, headers: Optional[_HeadersType], stacklevel: int
) -> Headers:
Expand Down
9 changes: 9 additions & 0 deletions src/treq/cookies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- test-case-name: treq.test.test_integration -*-
"""
Convenience helpers for :mod:`http.cookiejar`
"""
Expand All @@ -8,6 +9,14 @@
from hyperlink import EncodedURL


class TreqieJar(CookieJar):
def __getitem__(self, name: str) -> str:
for cookie in self:
if cookie.name == name and cookie.value is not None:
return cookie.value
raise KeyError(name)


def scoped_cookie(origin: Union[str, EncodedURL], name: str, value: str) -> Cookie:
"""
Create a cookie scoped to a given URL's origin.
Expand Down
9 changes: 2 additions & 7 deletions src/treq/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any, Callable, List
from requests.cookies import cookiejar_from_dict
from http.cookiejar import CookieJar
from twisted.internet.defer import Deferred
from twisted.python import reflect
Expand All @@ -16,7 +15,7 @@ class _Response(proxyForInterface(IResponse)): # type: ignore
"""

original: IResponse
_cookiejar: CookieJar
_cookiejar: TreqieJar

def __init__(self, original: IResponse, cookiejar: CookieJar):
self.original = original
Expand Down Expand Up @@ -107,11 +106,7 @@ def cookies(self) -> CookieJar:
"""
Get a copy of this response's cookies.
"""
# NB: This actually returns a RequestsCookieJar, but we type it as a
# regular CookieJar because we want to ditch requests as a dependency.
# Full deprecation deprecation will require a subclass or wrapper that
# warns about the RequestCookieJar extensions.
jar: CookieJar = cookiejar_from_dict({})
jar = CookieJar()

for cookie in self._cookiejar:
jar.set_cookie(cookie)
Expand Down
15 changes: 8 additions & 7 deletions src/treq/test/test_cookies.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from http.cookiejar import CookieJar, Cookie
from http.cookiejar import Cookie, CookieJar

import attrs
from twisted.internet.testing import StringTransport
from treq._agentspy import RequestRecord, agent_spy
from treq.client import HTTPClient
from treq.cookies import scoped_cookie, search
from twisted.internet.interfaces import IProtocol
from twisted.trial.unittest import SynchronousTestCase
from twisted.internet.testing import StringTransport
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase
from twisted.web.client import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IClientRequest, IResponse
from zope.interface import implementer

from treq._agentspy import agent_spy, RequestRecord
from treq.client import HTTPClient
from treq.cookies import scoped_cookie, search
from ..cookies import TreqieJar


@implementer(IClientRequest)
Expand Down Expand Up @@ -135,7 +136,7 @@ class HTTPClientCookieTests(SynchronousTestCase):

def setUp(self) -> None:
self.agent, self.requests = agent_spy()
self.cookiejar = CookieJar()
self.cookiejar = TreqieJar()
self.client = HTTPClient(self.agent, self.cookiejar)

def test_cookies_in_jars(self) -> None:
Expand Down
Loading

0 comments on commit 634edbd

Please sign in to comment.