diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 8fbb0b17a0..a7a749d4a0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -826,7 +826,7 @@ async def on_connect(self) -> None: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") - async def disconnect(self) -> None: + async def disconnect(self, nowait: bool = False) -> None: """Disconnects from the Redis server""" try: async with async_timeout.timeout(self.socket_connect_timeout): @@ -836,8 +836,9 @@ async def disconnect(self) -> None: try: if os.getpid() == self.pid: self._writer.close() # type: ignore[union-attr] - # py3.6 doesn't have this method - if hasattr(self._writer, "wait_closed"): + # wait for close to finish, except when handling errors and + # forcecully disconnecting. + if not nowait: await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass @@ -938,10 +939,10 @@ async def read_response(self, disable_decoding: bool = False): disable_decoding=disable_decoding ) except asyncio.TimeoutError: - await self.disconnect() + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) @@ -950,7 +951,7 @@ async def read_response(self, disable_decoding: bool = False): # is subclass of Exception, not BaseException raise except Exception: - await self.disconnect() + await self.disconnect(nowait=True) raise if self.health_check_interval: