Skip to content

Commit

Permalink
[PR #8482/62173be backport][3.10] Fix incorrect rejection of ws:// an…
Browse files Browse the repository at this point in the history
…d wss:// urls (#8511)

Co-authored-by: pre-commit-ci[bot]
Co-authored-by: Sviatoslav Sydorenko (Святослав Сидоренко)
Co-authored-by: J. Nick Koston <[email protected]>
Co-authored-by: Sam Bull <[email protected]>
Co-authored-by: AraHaan <[email protected]>
  • Loading branch information
3 people authored Jul 17, 2024
1 parent 1caebc9 commit c12a143
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGES/8481.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed the incorrect rejection of ``ws://`` and ``wss://`` urls
-- by :user:` AraHaan`.
4 changes: 3 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ class ClientTimeout:
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
WS_SCHEMA_SET = frozenset({"ws", "wss"})
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET

_RetType = TypeVar("_RetType")
_CharsetResolver = Callable[[ClientResponse, bytes], str]
Expand Down Expand Up @@ -505,7 +507,7 @@ async def _request(
except ValueError as e:
raise InvalidUrlClientError(str_or_url) from e

if url.scheme not in HTTP_SCHEMA_SET:
if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
raise NonHttpUrlClientError(url)

skip_headers = set(self._skip_auto_headers)
Expand Down
31 changes: 30 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import asyncio
import base64
import os
import socket
import ssl
import sys
from hashlib import md5, sha256
from hashlib import md5, sha1, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from unittest import mock
from uuid import uuid4

import pytest

from aiohttp.http import WS_KEY
from aiohttp.test_utils import loop_context

try:
Expand Down Expand Up @@ -168,6 +171,17 @@ def pipe_name():
return name


@pytest.fixture
def create_mocked_conn(loop: Any):
def _proto_factory(conn_closing_result=None, **kwargs):
proto = mock.Mock(**kwargs)
proto.closed = loop.create_future()
proto.closed.set_result(conn_closing_result)
return proto

yield _proto_factory


@pytest.fixture
def selector_loop():
policy = asyncio.WindowsSelectorEventLoopPolicy()
Expand Down Expand Up @@ -208,3 +222,18 @@ def start_connection():
spec_set=True,
) as start_connection_mock:
yield start_connection_mock


@pytest.fixture
def key_data():
return os.urandom(16)


@pytest.fixture
def key(key_data: Any):
return base64.b64encode(key_data)


@pytest.fixture
def ws_key(key: Any):
return base64.b64encode(sha1(key + WS_KEY).digest()).decode()
56 changes: 55 additions & 1 deletion tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,61 @@ async def create_connection(req, traces, timeout):
c.__del__()


async def test_cookie_jar_usage(loop, aiohttp_client) -> None:
@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"])
async def test_ws_connect_allowed_protocols(
create_session: Any,
create_mocked_conn: Any,
protocol: str,
ws_key: Any,
key_data: Any,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example.com")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)

session = await create_session(request_class=req_factory)

connections = []
original_connect = session._connector.connect

async def connect(req, traces, timeout):
conn = await original_connect(req, traces, timeout)
connections.append(conn)
return conn

async def create_connection(req, traces, timeout):
return create_mocked_conn()

connector = session._connector
with mock.patch.object(connector, "connect", connect), mock.patch.object(
connector, "_create_connection", create_connection
), mock.patch.object(connector, "_release"), mock.patch(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example.com")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
for c in connections:
c.close()
del c

await session.close()


async def test_cookie_jar_usage(loop: Any, aiohttp_client: Any) -> None:
req_url = None

jar = mock.Mock()
Expand Down
18 changes: 2 additions & 16 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import hashlib
import os
from typing import Any
from unittest import mock

import pytest
Expand All @@ -13,22 +14,7 @@
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
def key_data():
return os.urandom(16)


@pytest.fixture
def key(key_data):
return base64.b64encode(key_data)


@pytest.fixture
def ws_key(key):
return base64.b64encode(hashlib.sha1(key + WS_KEY).digest()).decode()


async def test_ws_connect(ws_key, loop, key_data) -> None:
async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
Expand Down

0 comments on commit c12a143

Please sign in to comment.