diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b9e7c14f4..2c17db39e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ * Python: Added ZINTERCARD command ([#1418](https://github.com/aws/glide-for-redis/pull/1418)) * Python: Added ZMPOP command ([#1417](https://github.com/aws/glide-for-redis/pull/1417)) * Python: Added SMOVE command ([#1421](https://github.com/aws/glide-for-redis/pull/1421)) +* Python: Added SUNIONSTORE command ([#1423](https://github.com/aws/glide-for-redis/pull/1423)) #### Fixes diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 78410ed4cf..43466af244 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -1570,6 +1570,37 @@ async def smove( ), ) + async def sunionstore( + self, + destination: str, + keys: List[str], + ) -> int: + """ + Stores the members of the union of all given sets specified by `keys` into a new set at `destination`. + + See https://valkey.io/commands/sunionstore for more details. + + Note: + When in cluster mode, all keys in `keys` and `destination` must map to the same hash slot. + + Args: + destination (str): The key of the destination set. + keys (List[str]): The keys from which to retrieve the set members. + + Returns: + int: The number of elements in the resulting set. + + Examples: + >>> await client.sadd("set1", ["member1"]) + >>> await client.sadd("set2", ["member2"]) + >>> await client.sunionstore("my_set", ["set1", "set2"]) + 2 # Two elements were stored in "my_set", and those two members are the union of "set1" and "set2". + """ + return cast( + int, + await self._execute_command(RequestType.SUnionStore, [destination] + keys), + ) + async def ltrim(self, key: str, start: int, end: int) -> TOK: """ Trim an existing list so that it will contain only the specified range of elements specified. diff --git a/python/python/glide/async_commands/transaction.py b/python/python/glide/async_commands/transaction.py index d9e44f10c9..d9b4c3ff5a 100644 --- a/python/python/glide/async_commands/transaction.py +++ b/python/python/glide/async_commands/transaction.py @@ -1033,6 +1033,25 @@ def smove( """ return self.append_command(RequestType.SMove, [source, destination, member]) + def sunionstore( + self: TTransaction, + destination: str, + keys: List[str], + ) -> TTransaction: + """ + Stores the members of the union of all given sets specified by `keys` into a new set at `destination`. + + See https://valkey.io/commands/sunionstore for more details. + + Args: + destination (str): The key of the destination set. + keys (List[str]): The keys from which to retrieve the set members. + + Command response: + int: The number of elements in the resulting set. + """ + return self.append_command(RequestType.SUnionStore, [destination] + keys) + def ltrim(self: TTransaction, key: str, start: int, end: int) -> TTransaction: """ Trim an existing list so that it will contain only the specified range of elements specified. diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 370f62d97e..095ec654e8 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -1143,6 +1143,53 @@ async def test_smove(self, redis_client: TRedisClient): with pytest.raises(RequestError): await redis_client.smove(string_key, key1, "_") + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_sunionstore(self, redis_client: TRedisClient): + key1 = f"{{testKey}}:1-{get_random_string(10)}" + key2 = f"{{testKey}}:2-{get_random_string(10)}" + key3 = f"{{testKey}}:3-{get_random_string(10)}" + key4 = f"{{testKey}}:4-{get_random_string(10)}" + string_key = f"{{testKey}}:4-{get_random_string(10)}" + non_existing_key = f"{{testKey}}:5-{get_random_string(10)}" + + assert await redis_client.sadd(key1, ["a", "b", "c"]) == 3 + assert await redis_client.sadd(key2, ["c", "d", "e"]) == 3 + assert await redis_client.sadd(key3, ["e", "f", "g"]) == 3 + + # store union in new key + assert await redis_client.sunionstore(key4, [key1, key2]) == 5 + assert await redis_client.smembers(key4) == {"a", "b", "c", "d", "e"} + + # overwrite existing set + assert await redis_client.sunionstore(key1, [key4, key2]) == 5 + assert await redis_client.smembers(key1) == {"a", "b", "c", "d", "e"} + + # overwrite one of the source keys + assert await redis_client.sunionstore(key2, [key4, key2]) == 5 + assert await redis_client.smembers(key1) == {"a", "b", "c", "d", "e"} + + # union with a non existing key + assert await redis_client.sunionstore(key2, [non_existing_key]) == 0 + assert await redis_client.smembers(key2) == set() + + # key exists, but it is not a sorted set + assert await redis_client.set(string_key, "value") == OK + with pytest.raises(RequestError): + await redis_client.sunionstore(key4, [string_key, key1]) + + # overwrite destination when destination is not a set + assert await redis_client.sunionstore(string_key, [key1, key3]) == 7 + assert await redis_client.smembers(string_key) == { + "a", + "b", + "c", + "d", + "e", + "f", + "g", + } + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_ltrim(self, redis_client: TRedisClient): @@ -3081,6 +3128,7 @@ async def test_multi_key_command_returns_cross_slot_error( redis_client.bzpopmin(["abc", "zxy", "lkn"], 0.5), redis_client.bzpopmax(["abc", "zxy", "lkn"], 0.5), redis_client.smove("abc", "def", "_"), + redis_client.sunionstore("abc", ["zxy", "lkn"]), ] if not check_if_server_version_lt(redis_client, "7.0.0"): diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index 341450247a..5856aef735 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -203,6 +203,8 @@ async def transaction_test( args.append("bar") transaction.sadd(key7, ["foo", "bar"]) args.append(2) + transaction.sunionstore(key7, [key7, key7]) + args.append(2) transaction.spop_count(key7, 4) args.append({"foo", "bar"}) transaction.smove(key7, key7, "non_existing_member")