Skip to content

Commit

Permalink
don't wait for disconnect() when handling errors.
Browse files Browse the repository at this point in the history
This can result in other errors such as timeouts.
  • Loading branch information
kristjanvalur committed Aug 24, 2022
1 parent 918d104 commit 4c65513
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ async def on_connect(self) -> None:
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

async def disconnect(self) -> None:
async def disconnect(self, nowait: bool = False) -> None:
"""Disconnects from the Redis server"""
try:
async with async_timeout.timeout(self.socket_connect_timeout):
Expand All @@ -836,8 +836,10 @@ async def disconnect(self) -> None:
try:
if os.getpid() == self.pid:
self._writer.close() # type: ignore[union-attr]
# wait for close to finish, except when handling errors and
# forcecully disconnecting.
# py3.6 doesn't have this method
if hasattr(self._writer, "wait_closed"):
if not nowait and hasattr(self._writer, "wait_closed"):
await self._writer.wait_closed() # type: ignore[union-attr]
except OSError:
pass
Expand Down Expand Up @@ -892,10 +894,10 @@ async def send_packed_command(
self._writer.writelines(command)
await self._writer.drain()
except asyncio.TimeoutError:
await self.disconnect()
await self.disconnect(nowait=True)
raise TimeoutError("Timeout writing to socket") from None
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
if len(e.args) == 1:
err_no, errmsg = "UNKNOWN", e.args[0]
else:
Expand All @@ -907,7 +909,7 @@ async def send_packed_command(
except asyncio.CancelledError:
raise # is Exception and not BaseException in 3.7 and earlier
except Exception:
await self.disconnect()
await self.disconnect(nowait=True)
raise

async def send_command(self, *args: Any, **kwargs: Any) -> None:
Expand All @@ -923,7 +925,7 @@ async def can_read(self, timeout: float = 0):
try:
return await self._parser.can_read(timeout)
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port}: {e.args}"
)
Expand All @@ -942,17 +944,17 @@ async def read_response(self, disable_decoding: bool = False):
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
await self.disconnect()
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
raise # is Exception and not BaseException in 3.7 and earlier
except Exception:
await self.disconnect()
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand All @@ -976,17 +978,17 @@ async def read_response_without_lock(self, disable_decoding: bool = False):
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
await self.disconnect()
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
await self.disconnect()
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
raise # is Exception and not BaseException in 3.7 and earlier
except Exception:
await self.disconnect()
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand Down

0 comments on commit 4c65513

Please sign in to comment.