From caa7e7abf388fcb2574567b25233c48016714b55 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 17 Aug 2024 10:45:56 -0500 Subject: [PATCH] Fix exceptions from WebSocket ping task not being consumed (#8685) (cherry picked from commit e7c02ca49eac8d9dce197fdcd6273531d0008649) --- CHANGES/8685.bugfix.rst | 3 ++ aiohttp/client_ws.py | 25 +++++---- aiohttp/web_ws.py | 18 +++++-- tests/test_client_ws_functional.py | 30 +++++++++++ tests/test_web_websocket_functional.py | 73 +++++++++++++++++++++++++- 5 files changed, 136 insertions(+), 13 deletions(-) create mode 100644 CHANGES/8685.bugfix.rst diff --git a/CHANGES/8685.bugfix.rst b/CHANGES/8685.bugfix.rst new file mode 100644 index 00000000000..8bd20196ee3 --- /dev/null +++ b/CHANGES/8685.bugfix.rst @@ -0,0 +1,3 @@ +Fixed unconsumed exceptions raised by the WebSocket heartbeat -- by :user:`bdraco`. + +If the heartbeat ping raised an exception, it would not be consumed and would be logged as an warning. diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 7fd141248bd..7b3a5bf952d 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -141,21 +141,28 @@ def _send_heartbeat(self) -> None: if not ping_task.done(): self._ping_task = ping_task ping_task.add_done_callback(self._ping_task_done) + else: + self._ping_task_done(ping_task) def _ping_task_done(self, task: "asyncio.Task[None]") -> None: """Callback for when the ping task completes.""" + if not task.cancelled() and (exc := task.exception()): + self._handle_ping_pong_exception(exc) self._ping_task = None def _pong_not_received(self) -> None: - if not self._closed: - self._set_closed() - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - self._exception = ServerTimeoutError() - self._response.close() - if self._waiting and not self._closing: - self._reader.feed_data( - WSMessage(WSMsgType.ERROR, self._exception, None) - ) + self._handle_ping_pong_exception(ServerTimeoutError()) + + def _handle_ping_pong_exception(self, exc: BaseException) -> None: + """Handle exceptions raised during ping/pong processing.""" + if self._closed: + return + self._set_closed() + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._exception = exc + self._response.close() + if self._waiting and not self._closing: + self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None)) def _set_closed(self) -> None: """Set the connection to closed. diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 98f26cc48c6..382223097ea 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -164,16 +164,28 @@ def _send_heartbeat(self) -> None: if not ping_task.done(): self._ping_task = ping_task ping_task.add_done_callback(self._ping_task_done) + else: + self._ping_task_done(ping_task) def _ping_task_done(self, task: "asyncio.Task[None]") -> None: """Callback for when the ping task completes.""" + if not task.cancelled() and (exc := task.exception()): + self._handle_ping_pong_exception(exc) self._ping_task = None def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: - self._set_closed() - self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) - self._exception = asyncio.TimeoutError() + self._handle_ping_pong_exception(asyncio.TimeoutError()) + + def _handle_ping_pong_exception(self, exc: BaseException) -> None: + """Handle exceptions raised during ping/pong processing.""" + if self._closed: + return + self._set_closed() + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + self._exception = exc + if self._waiting and not self._closing and self._reader is not None: + self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None)) def _set_closed(self) -> None: """Set the connection to closed. diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 907ae232e9a..274092a189a 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -600,6 +600,36 @@ async def handler(request): assert ping_received +async def test_heartbeat_connection_closed(aiohttp_client: AiohttpClient) -> None: + """Test that the connection is closed while ping is in progress.""" + + async def handler(request: web.Request) -> NoReturn: + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + ping_count = 0 + # We patch write here to simulate a connection reset error + # since if we closed the connection normally, the client would + # would cancel the heartbeat task and we wouldn't get a ping + assert resp._conn is not None + with mock.patch.object( + resp._conn.transport, "write", side_effect=ConnectionResetError + ), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping: + await resp.receive() + ping_count = ping.call_count + # Connection should be closed roughly after 1.5x heartbeat. + await asyncio.sleep(0.2) + assert ping_count == 1 + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None: """Test that the connection is closed if no pong is received without sending messages.""" ping_received = False diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 6540f134da8..2be54486ee9 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -4,7 +4,8 @@ import contextlib import sys import weakref -from typing import Any, Optional +from typing import Any, NoReturn, Optional +from unittest import mock import pytest @@ -724,6 +725,76 @@ async def handler(request): await ws.close() +async def test_heartbeat_connection_closed( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + """Test that the connection is closed while ping is in progress.""" + ping_count = 0 + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_count + ws_server = web.WebSocketResponse(heartbeat=0.05) + await ws_server.prepare(request) + # We patch write here to simulate a connection reset error + # since if we closed the connection normally, the server would + # would cancel the heartbeat task and we wouldn't get a ping + with mock.patch.object( + ws_server._req.transport, "write", side_effect=ConnectionResetError + ), mock.patch.object( + ws_server._writer, "ping", wraps=ws_server._writer.ping + ) as ping: + try: + await ws_server.receive() + finally: + ping_count = ping.call_count + assert False + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED + assert msg.extra is None + assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE + assert ping_count == 1 + await ws.close() + + +async def test_heartbeat_failure_ends_receive( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + """Test that no heartbeat response to the server ends the receive call.""" + ws_server_close_code = None + ws_server_exception = None + + async def handler(request: web.Request) -> NoReturn: + nonlocal ws_server_close_code, ws_server_exception + ws_server = web.WebSocketResponse(heartbeat=0.05) + await ws_server.prepare(request) + try: + await ws_server.receive() + finally: + ws_server_close_code = ws_server.close_code + ws_server_exception = ws_server.exception() + assert False + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED + assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE + assert ws_server_close_code == WSCloseCode.ABNORMAL_CLOSURE + assert isinstance(ws_server_exception, asyncio.TimeoutError) + await ws.close() + + async def test_heartbeat_no_pong_send_many_messages( loop: Any, aiohttp_client: Any ) -> None: