Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PubSub with RESP3 parser #2721

Merged
merged 5 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import safe_str, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
Expand Down Expand Up @@ -656,6 +656,7 @@ def __init__(
shard_hint: Optional[str] = None,
ignore_subscribe_messages: bool = False,
encoder=None,
push_handler_func: Optional[Callable] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really happy to see the same option name ❤️

):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -664,6 +665,7 @@ def __init__(
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
if self.encoder.decode_responses:
Expand All @@ -676,6 +678,8 @@ def __init__(
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
if self.push_handler_func is None:
_set_info_logger()
self.channels = {}
self.pending_unsubscribe_channels = set()
self.patterns = {}
Expand Down Expand Up @@ -755,6 +759,8 @@ async def connect(self):
self.connection.register_connect_callback(self.on_connect)
else:
await self.connection.connect()
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)

async def _disconnect_raise_connect(self, conn, error):
"""
Expand Down Expand Up @@ -795,7 +801,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
await conn.connect()

read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
response = await self._execute(
conn, conn.read_response, timeout=read_timeout, push_request=True
)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down Expand Up @@ -925,15 +933,19 @@ def ping(self, message=None) -> Awaitable:
"""
Ping the Redis server
"""
message = "" if message is None else message
return self.execute_command("PING", message)
args = ["PING", message] if message is not None else ["PING"]
return self.execute_command(*args)

async def handle_message(self, response, ignore_subscribe_messages=False):
"""
Parses a pub/sub message. If the channel or pattern was subscribed to
with a message handler, the handler is invoked instead of a parsed
message being returned.
"""
if response is None:
return None
if isinstance(response, bytes):
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
Expand Down
16 changes: 15 additions & 1 deletion redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,15 +485,29 @@ async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
push_request: Optional[bool] = False,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
if read_timeout is not None:
if (
read_timeout is not None
and self.protocol == "3"
and not HIREDIS_AVAILABLE
):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
elif read_timeout is not None:
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol == "3" and not HIREDIS_AVAILABLE:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
response = await self._parser.read_response(
disable_decoding=disable_decoding
Expand Down
16 changes: 12 additions & 4 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import safe_str, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes

SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
Expand Down Expand Up @@ -1429,6 +1429,7 @@ def __init__(
shard_hint=None,
ignore_subscribe_messages=False,
encoder=None,
push_handler_func=None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -1438,13 +1439,16 @@ def __init__(
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
self.health_check_response = [b"pong", self.health_check_response_b]
if self.push_handler_func is None:
_set_info_logger()
self.reset()

def __enter__(self):
Expand Down Expand Up @@ -1515,6 +1519,8 @@ def execute_command(self, *args):
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
Expand Down Expand Up @@ -1580,7 +1586,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(push_request=True)

response = self._execute(conn, try_read)

Expand Down Expand Up @@ -1739,8 +1745,8 @@ def ping(self, message=None):
"""
Ping the Redis server
"""
message = "" if message is None else message
return self.execute_command("PING", message)
args = ["PING", message] if message is not None else ["PING"]
return self.execute_command(*args)

def handle_message(self, response, ignore_subscribe_messages=False):
"""
Expand All @@ -1750,6 +1756,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
"""
if response is None:
return None
if isinstance(response, bytes):
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
Expand Down
12 changes: 9 additions & 3 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,18 @@ def can_read(self, timeout=0):
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

def read_response(self, disable_decoding=False):
def read_response(self, disable_decoding=False, push_request=False):
"""Read the response from a previously sent command"""

host_error = self._host_error()

try:
response = self._parser.read_response(disable_decoding=disable_decoding)
if self.protocol == "3" and not HIREDIS_AVAILABLE:
response = self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
Expand Down Expand Up @@ -705,8 +710,9 @@ def _connect(self):
class UnixDomainSocketConnection(AbstractConnection):
"Manages UDS communication to and from a Redis server"

def __init__(self, path="", **kwargs):
def __init__(self, path="", socket_timeout=None, **kwargs):
self.path = path
self.socket_timeout = socket_timeout
super().__init__(**kwargs)

def repr_pieces(self):
Expand Down
81 changes: 74 additions & 7 deletions redis/parsers/resp3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import getLogger
from typing import Any, Union

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
Expand All @@ -9,18 +10,29 @@
class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""

def read_response(self, disable_decoding=False):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response

def handle_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response

def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos()
try:
result = self._read_response(disable_decoding=disable_decoding)
result = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
except BaseException:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result

def _read_response(self, disable_decoding=False):
def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -77,31 +89,64 @@ def _read_response(self, disable_decoding=False):
response = {
self._read_response(
disable_decoding=disable_decoding
): self._read_response(disable_decoding=disable_decoding)
): self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func


class _AsyncRESP3Parser(_AsyncRESPBase):
async def read_response(self, disable_decoding: bool = False):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response

def handle_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response

async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(disable_decoding=disable_decoding)
response = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response

async def _read_response(
self, disable_decoding: bool = False
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -166,9 +211,31 @@ async def _read_response(
)
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
(
await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return await (
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
else:
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
14 changes: 14 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import contextmanager
from functools import wraps
from typing import Any, Dict, Mapping, Union
Expand Down Expand Up @@ -117,3 +118,16 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def _set_info_logger():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring or comment on why it exists, please.

"""
Set up a logger that log info logs to stdout.
(This is used by the default push response handler)
"""
if "push_response" not in logging.root.manager.loggerDict.keys():
logger = logging.getLogger("push_response")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
Loading