Skip to content

Commit

Permalink
Catch Exception and not BaseException in the Connection (#2104)
Browse files Browse the repository at this point in the history
* Add failing unittests for passing BaseException through

* Resolve failing unittest

* Remove redundant checks for asyncio.CancelledError
  • Loading branch information
kristjanvalur authored Sep 29, 2022
1 parent fbf68dd commit 9fe8366
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 7 deletions.
8 changes: 3 additions & 5 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,6 @@ async def read_from_socket(
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
except asyncio.CancelledError:
raise
except (socket.timeout, asyncio.TimeoutError):
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket") from None
Expand Down Expand Up @@ -721,7 +719,7 @@ async def connect(self):
lambda: self._connect(), lambda error: self.disconnect()
)
except asyncio.CancelledError:
raise
raise # in 3.7 and earlier, this is an Exception, not BaseException
except (socket.timeout, asyncio.TimeoutError):
raise TimeoutError("Timeout connecting to server")
except OSError as e:
Expand Down Expand Up @@ -916,7 +914,7 @@ async def send_packed_command(
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except BaseException:
except Exception:
await self.disconnect()
raise

Expand Down Expand Up @@ -958,7 +956,7 @@ async def read_response(self, disable_decoding: bool = False):
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except BaseException:
except Exception:
await self.disconnect()
raise

Expand Down
4 changes: 2 additions & 2 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def send_packed_command(self, command, check_health=True):
errno = e.args[0]
errmsg = e.args[1]
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
except BaseException:
except Exception:
self.disconnect()
raise

Expand Down Expand Up @@ -804,7 +804,7 @@ def read_response(self, disable_decoding=False):
except OSError as e:
self.disconnect()
raise ConnectionError(f"Error while reading from {hosterr}" f" : {e.args}")
except BaseException:
except Exception:
self.disconnect()
raise

Expand Down
74 changes: 74 additions & 0 deletions tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import functools
import socket
import sys
from typing import Optional
from unittest.mock import patch

import async_timeout
import pytest
Expand Down Expand Up @@ -914,3 +916,75 @@ async def loop_step_listen(self):
return True
except asyncio.TimeoutError:
return False


@pytest.mark.onlynoncluster
class TestBaseException:
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="requires python 3.8 or higher"
)
async def test_outer_timeout(self, r: redis.Redis):
"""
Using asyncio_timeout manually outside the inner method timeouts works.
This works on Python versions 3.8 and greater, at which time asyncio.
CancelledError became a BaseException instead of an Exception before.
"""
pubsub = r.pubsub()
await pubsub.subscribe("foo")
assert pubsub.connection.is_connected

async def get_msg_or_timeout(timeout=0.1):
async with async_timeout.timeout(timeout):
# blocking method to return messages
while True:
response = await pubsub.parse_response(block=True)
message = await pubsub.handle_message(
response, ignore_subscribe_messages=False
)
if message is not None:
return message

# get subscribe message
msg = await get_msg_or_timeout(10)
assert msg is not None
# timeout waiting for another message which never arrives
assert pubsub.connection.is_connected
with pytest.raises(asyncio.TimeoutError):
await get_msg_or_timeout()
# the timeout on the read should not cause disconnect
assert pubsub.connection.is_connected

async def test_base_exception(self, r: redis.Redis):
"""
Manually trigger a BaseException inside the parser's .read_response method
and verify that it isn't caught
"""
pubsub = r.pubsub()
await pubsub.subscribe("foo")
assert pubsub.connection.is_connected

async def get_msg():
# blocking method to return messages
while True:
response = await pubsub.parse_response(block=True)
message = await pubsub.handle_message(
response, ignore_subscribe_messages=False
)
if message is not None:
return message

# get subscribe message
msg = await get_msg()
assert msg is not None
# timeout waiting for another message which never arrives
assert pubsub.connection.is_connected
with patch("redis.asyncio.connection.PythonParser.read_response") as mock1:
mock1.side_effect = BaseException("boom")
with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2:
mock2.side_effect = BaseException("boom")

with pytest.raises(BaseException):
await get_msg()

# the timeout on the read should not cause disconnect
assert pubsub.connection.is_connected
42 changes: 42 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,45 @@ def loop_step_listen(self):
for message in self.pubsub.listen():
self.messages.put(message)
return True


@pytest.mark.onlynoncluster
class TestBaseException:
def test_base_exception(self, r: redis.Redis):
"""
Manually trigger a BaseException inside the parser's .read_response method
and verify that it isn't caught
"""
pubsub = r.pubsub()
pubsub.subscribe("foo")

def is_connected():
return pubsub.connection._sock is not None

assert is_connected()

def get_msg():
# blocking method to return messages
while True:
response = pubsub.parse_response(block=True)
message = pubsub.handle_message(
response, ignore_subscribe_messages=False
)
if message is not None:
return message

# get subscribe message
msg = get_msg()
assert msg is not None
# timeout waiting for another message which never arrives
assert is_connected()
with patch("redis.connection.PythonParser.read_response") as mock1:
mock1.side_effect = BaseException("boom")
with patch("redis.connection.HiredisParser.read_response") as mock2:
mock2.side_effect = BaseException("boom")

with pytest.raises(BaseException):
get_msg()

# the timeout on the read should not cause disconnect
assert is_connected()

0 comments on commit 9fe8366

Please sign in to comment.