Skip to content

Commit

Permalink
Fix incorrect rejection of ws:// and wss:// urls (#8482)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sviatoslav Sydorenko (Святослав Сидоренко) <[email protected]>
Co-authored-by: J. Nick Koston <[email protected]>
Co-authored-by: Sam Bull <[email protected]>
  • Loading branch information
5 people authored Jul 17, 2024
1 parent b860848 commit 62173be
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGES/8481.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed the incorrect rejection of ``ws://`` and ``wss://`` urls
-- by :user:` AraHaan`.
4 changes: 3 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class ClientTimeout:
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
WS_SCHEMA_SET = frozenset({"ws", "wss"})
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET

_RetType = TypeVar("_RetType")
_CharsetResolver = Callable[[ClientResponse, bytes], str]
Expand Down Expand Up @@ -452,7 +454,7 @@ async def _request(
except ValueError as e:
raise InvalidUrlClientError(str_or_url) from e

if url.scheme not in HTTP_SCHEMA_SET:
if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
raise NonHttpUrlClientError(url)

skip_headers = set(self._skip_auto_headers)
Expand Down
19 changes: 18 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# type: ignore
import asyncio
import base64
import os
import socket
import ssl
import sys
from hashlib import md5, sha256
from hashlib import md5, sha1, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, List
Expand All @@ -13,6 +14,7 @@

import pytest

from aiohttp.http import WS_KEY
from aiohttp.test_utils import loop_context

try:
Expand Down Expand Up @@ -218,3 +220,18 @@ def start_connection():
spec_set=True,
) as start_connection_mock:
yield start_connection_mock


@pytest.fixture
def key_data():
return os.urandom(16)


@pytest.fixture
def key(key_data: Any):
return base64.b64encode(key_data)


@pytest.fixture
def ws_key(key: Any):
return base64.b64encode(sha1(key + WS_KEY).digest()).decode()
54 changes: 54 additions & 0 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,60 @@ async def create_connection(req, traces, timeout):
c.__del__()


@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"])
async def test_ws_connect_allowed_protocols(
create_session: Any,
create_mocked_conn: Any,
protocol: str,
ws_key: Any,
key_data: Any,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example.com")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)

session = await create_session(request_class=req_factory)

connections = []
original_connect = session._connector.connect

async def connect(req, traces, timeout):
conn = await original_connect(req, traces, timeout)
connections.append(conn)
return conn

async def create_connection(req, traces, timeout):
return create_mocked_conn()

connector = session._connector
with mock.patch.object(connector, "connect", connect), mock.patch.object(
connector, "_create_connection", create_connection
), mock.patch.object(connector, "_release"), mock.patch(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example.com")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
for c in connections:
c.close()
del c

await session.close()


async def test_cookie_jar_usage(loop: Any, aiohttp_client: Any) -> None:
req_url = None

Expand Down
15 changes: 0 additions & 15 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,6 @@
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
def key_data():
return os.urandom(16)


@pytest.fixture
def key(key_data: Any):
return base64.b64encode(key_data)


@pytest.fixture
def ws_key(key: Any):
return base64.b64encode(hashlib.sha1(key + WS_KEY).digest()).decode()


async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
resp = mock.Mock()
resp.status = 101
Expand Down

0 comments on commit 62173be

Please sign in to comment.