diff --git a/python/Cargo.toml b/python/Cargo.toml index 842526ed8d..589c0c28ee 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -12,6 +12,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "^0.20", features = ["extension-module", "num-bigint"] } +bytes = { version = "1.6.0" } redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager","tokio-rustls-comp"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } logger_core = {path = "../logger_core"} diff --git a/python/python/glide/glide.pyi b/python/python/glide/glide.pyi index d155757bbd..fde1ac0d99 100644 --- a/python/python/glide/glide.pyi +++ b/python/python/glide/glide.pyi @@ -1,10 +1,11 @@ from collections.abc import Callable from enum import Enum -from typing import Optional +from typing import List, Optional from glide.constants import TResult DEFAULT_TIMEOUT_IN_MILLISECONDS: int = ... +MAX_REQUEST_ARGS_LEN: int = ... class Level(Enum): Error = 0 @@ -23,5 +24,6 @@ class Script: def start_socket_listener_external(init_callback: Callable) -> None: ... def value_from_pointer(pointer: int) -> TResult: ... def create_leaked_value(message: str) -> int: ... +def create_leaked_bytes_vec(args_vec: List[bytes]) -> int: ... def py_init(level: Optional[Level], file_name: Optional[str]) -> Level: ... def py_log(log_level: Level, log_identifier: str, message: str) -> None: ... diff --git a/python/python/glide/redis_client.py b/python/python/glide/redis_client.py index edb9fd4122..c759a90c04 100644 --- a/python/python/glide/redis_client.py +++ b/python/python/glide/redis_client.py @@ -1,6 +1,7 @@ # Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 import asyncio +import sys import threading from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast @@ -29,6 +30,8 @@ from .glide import ( DEFAULT_TIMEOUT_IN_MILLISECONDS, + MAX_REQUEST_ARGS_LEN, + create_leaked_bytes_vec, start_socket_listener_external, value_from_pointer, ) @@ -194,6 +197,46 @@ async def _write_buffered_requests_to_socket(self) -> None: self._writer.write(b_arr) await self._writer.drain() + # TODO: change `str` to `TEncodable` where `TEncodable = Union[str, bytes]` + def _encode_arg(self, arg: str) -> bytes: + """ + Converts a string argument to bytes. + + Args: + arg (str): An encodable argument. + + Returns: + bytes: The encoded argument as bytes. + """ + + # TODO: Allow passing different encoding options + return bytes(arg, encoding="utf8") + + # TODO: change `List[str]` to `List[TEncodable]` where `TEncodable = Union[str, bytes]` + def _encode_and_sum_size( + self, + args_list: Optional[List[str]], + ) -> Tuple[List[bytes], int]: + """ + Encodes the list and calculates the total memory size. + + Args: + args_list (Optional[List[str]]): A list of strings to be converted to bytes. + If None or empty, returns ([], 0). + + Returns: + int: The total memory size of the encoded arguments in bytes. + """ + args_size = 0 + encoded_args_list: List[bytes] = [] + if not args_list: + return (encoded_args_list, args_size) + for arg in args_list: + encoded_arg = self._encode_arg(arg) + encoded_args_list.append(encoded_arg) + args_size += sys.getsizeof(encoded_arg) + return (encoded_args_list, args_size) + async def _execute_command( self, request_type: RequestType.ValueType, @@ -207,9 +250,13 @@ async def _execute_command( request = RedisRequest() request.callback_idx = self._get_callback_index() request.single_command.request_type = request_type - request.single_command.args_array.args[:] = [ - bytes(elem, encoding="utf8") for elem in args - ] # TODO - use arg pointer + (encoded_args, args_size) = self._encode_and_sum_size(args) + if args_size < MAX_REQUEST_ARGS_LEN: + request.single_command.args_array.args[:] = encoded_args + else: + request.single_command.args_vec_pointer = create_leaked_bytes_vec( + encoded_args + ) set_protobuf_route(request, route) return await self._write_request_await_response(request) @@ -229,8 +276,12 @@ async def _execute_transaction( command = Command() command.request_type = requst_type # For now, we allow the user to pass the command as array of strings - # we convert them here into bytearray (the datatype that our rust core expects) - command.args_array.args[:] = [bytes(elem, encoding="utf8") for elem in args] + # we convert them here into bytes (the datatype that our rust core expects) + (encoded_args, args_size) = self._encode_and_sum_size(args) + if args_size < MAX_REQUEST_ARGS_LEN: + command.args_array.args[:] = encoded_args + else: + command.args_vec_pointer = create_leaked_bytes_vec(encoded_args) transaction_commands.append(command) request.transaction.commands.extend(transaction_commands) set_protobuf_route(request, route) diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 24407e4592..5a62d1d54f 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -110,10 +110,13 @@ async def test_register_client_name_and_version(self, redis_client: TRedisClient @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) - async def test_send_and_receive_large_values(self, redis_client: TRedisClient): - length = 2**16 - key = get_random_string(length) - value = get_random_string(length) + async def test_send_and_receive_large_values(self, request, cluster_mode, protocol): + redis_client = await create_client( + request, cluster_mode=cluster_mode, protocol=protocol, timeout=5000 + ) + length = 2**25 # 33mb + key = "0" * length + value = "0" * length assert len(key) == length assert len(value) == length await redis_client.set(key, value) diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index 23f5be1d30..013dead362 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -655,6 +655,23 @@ async def test_can_return_null_on_watch_transaction_failures( await client2.close() + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_transaction_large_values(self, request, cluster_mode, protocol): + redis_client = await create_client( + request, cluster_mode=cluster_mode, protocol=protocol, timeout=5000 + ) + length = 2**25 # 33mb + key = "0" * length + value = "0" * length + transaction = Transaction() + transaction.set(key, value) + transaction.get(key) + result = await redis_client.exec(transaction) + assert isinstance(result, list) + assert result[0] == OK + assert result[1] == value + @pytest.mark.parametrize("cluster_mode", [False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_standalone_transaction(self, redis_client: RedisClient): diff --git a/python/src/lib.rs b/python/src/lib.rs index 4380b064c9..161b72c2a6 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,15 +1,18 @@ +use bytes::Bytes; /** * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ use glide_core::start_socket_listener; +use glide_core::MAX_REQUEST_ARGS_LENGTH; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyFloat, PyList, PySet}; +use pyo3::types::{PyBool, PyBytes, PyDict, PyFloat, PyList, PySet}; use pyo3::Python; use redis::Value; pub const DEFAULT_TIMEOUT_IN_MILLISECONDS: u32 = glide_core::client::DEFAULT_RESPONSE_TIMEOUT.as_millis() as u32; +pub const MAX_REQUEST_ARGS_LEN: u32 = MAX_REQUEST_ARGS_LENGTH as u32; #[pyclass] #[derive(PartialEq, Eq, PartialOrd, Clone)] @@ -60,6 +63,7 @@ fn glide(_py: Python, m: &PyModule) -> PyResult<()> { "DEFAULT_TIMEOUT_IN_MILLISECONDS", DEFAULT_TIMEOUT_IN_MILLISECONDS, )?; + m.add("MAX_REQUEST_ARGS_LEN", MAX_REQUEST_ARGS_LEN)?; #[pyfn(m)] fn py_log(log_level: Level, log_identifier: String, message: String) { @@ -168,6 +172,19 @@ fn glide(_py: Python, m: &PyModule) -> PyResult<()> { let value = Value::SimpleString(message); Box::leak(Box::new(value)) as *mut Value as usize } + + #[pyfn(m)] + pub fn create_leaked_bytes_vec(args_vec: Vec<&PyBytes>) -> usize { + // Convert the bytes vec -> Bytes vector + let bytes_vec: Vec = args_vec + .iter() + .map(|v| { + let bytes = v.as_bytes(); + Bytes::from(bytes.to_vec()) + }) + .collect(); + Box::leak(Box::new(bytes_vec)) as *mut Vec as usize + } Ok(()) }