Skip to content

Commit

Permalink
Merge pull request #28 from Yipit/feature/restore
Browse files Browse the repository at this point in the history
Add RESTORE command
  • Loading branch information
hltbra authored Jun 7, 2019
2 parents 600106c + de72330 commit 4ac3102
Show file tree
Hide file tree
Showing 13 changed files with 466 additions and 101 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ TYPE key | Keys
KEYS pattern | Keys
EXISTS key [key ...] | Keys
DUMP key | Keys
RESTORE key ttl serialized-value [REPLACE]\**| Keys
PING [msg] | Connection
SELECT db | Connection
SET key value | Strings
Expand Down Expand Up @@ -146,6 +147,7 @@ HINCRBY key field increment | Hashes
HGETALL key | Hashes
\* `COMMAND`'s reply is incompatible at the moment, it returns a flat array with command names (their arity, flags, positions, or step count are not returned).
\** `RESTORE` doesn't work with Redis strings compressed with LZF or encoded as `OBJ_ENCODING_INT`; also doesn't work with sets encoded as `OBJ_ENCODING_INTSET`, nor hashes and sorted sets encoded as `OBJ_ENCODING_ZIPLIST`.


## How is DRedis implemented
Expand Down Expand Up @@ -186,7 +188,7 @@ We rely on the backends' consistency properties and we use batches/transactions

### Cluster mode & Replication

Replication, key distribution, and cluster mode isn't supported.
Replication, key distribution, and cluster mode are not supported.
If you want higher availability you can create multiple servers that share or replicate a disk (consistency may suffer when replicating).
Use DNS routing or a network load balancer to route requests properly.

Expand Down
12 changes: 12 additions & 0 deletions dredis/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ def cmd_dump(keyspace, key):
return keyspace.dump(key)


@command('RESTORE', arity=-4)
def cmd_restore(keyspace, key, ttl, payload, *args):
replace = False
if args:
if len(args) == 1 and args[0].lower() == 'replace':
replace = True
else:
raise SYNTAXERR
keyspace.restore(key, ttl, payload, replace)
return SimpleString('OK')


"""
***********************
* Connection commands *
Expand Down
16 changes: 9 additions & 7 deletions dredis/crc64.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,19 @@
"""

import ctypes
import struct


UINT64_BITMASK = 2 ** 64 - 1
UINT8_BITMASK = 2 ** 8 - 1


def uint64_t(x):
return ctypes.c_uint64(x).value
return x & UINT64_BITMASK


def uint8_t(x):
return ctypes.c_uint8(x).value
return x & UINT8_BITMASK


UINT64_C = uint64_t
Expand Down Expand Up @@ -189,12 +192,11 @@ def uint8_t(x):
]


def crc64(crc, s):
for char in s:
byte = uint8_t(ord(char))
def crc64(crc, bytes_):
for byte in bytes_:
crc = uint64_t(crc64_tab[uint8_t(crc) ^ byte] ^ (crc >> 8))
return crc


def checksum(payload):
return struct.pack('<Q', crc64(0, payload))
return struct.pack('<Q', crc64(0, bytearray(payload)))
44 changes: 25 additions & 19 deletions dredis/keyspace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import fnmatch

from dredis import crc64, rdb
from dredis import rdb
from dredis.db import DB_MANAGER, KEY_CODEC
from dredis.lua import LuaRunner
from dredis.utils import to_float
Expand Down Expand Up @@ -161,7 +161,7 @@ def zadd(self, key, score, value):
zset_length += 1
batch.put(KEY_CODEC.encode_zset(key), bytes(zset_length))

batch.put(KEY_CODEC.encode_zset_value(key, value), bytes(score))
batch.put(KEY_CODEC.encode_zset_value(key, value), to_float_string(score))
batch.put(KEY_CODEC.encode_zset_score(key, value, score), bytes(''))
batch.write()

Expand Down Expand Up @@ -316,13 +316,13 @@ def zunionstore(self, destination, keys, weights):
return result

def type(self, key):
if self._db.get(KEY_CODEC.encode_string(key)):
if self._db.get(KEY_CODEC.encode_string(key)) is not None:
return 'string'
if self._db.get(KEY_CODEC.encode_set(key)):
if self._db.get(KEY_CODEC.encode_set(key)) is not None:
return 'set'
if self._db.get(KEY_CODEC.encode_hash(key)):
if self._db.get(KEY_CODEC.encode_hash(key)) is not None:
return 'hash'
if self._db.get(KEY_CODEC.encode_zset(key)):
if self._db.get(KEY_CODEC.encode_zset(key)) is not None:
return 'zset'
return 'none'

Expand Down Expand Up @@ -406,7 +406,7 @@ def hkeys(self, key):
def hvals(self, key):
result = []
if self._db.get(KEY_CODEC.encode_hash(key)) is not None:
for db_key, db_value in self._get_db_iterator(KEY_CODEC.get_min_hash_field(key)):
for _, db_value in self._get_db_iterator(KEY_CODEC.get_min_hash_field(key)):
result.append(db_value)
return result

Expand All @@ -424,12 +424,13 @@ def hincrby(self, key, field, increment):
return new_value

def hgetall(self, key):
keys = self.hkeys(key)
values = self.hvals(key)
result = []
for (k, v) in zip(keys, values):
result.append(k)
result.append(v)
if self._db.get(KEY_CODEC.encode_hash(key)) is not None:
for db_key, db_value in self._get_db_iterator(KEY_CODEC.get_min_hash_field(key)):
_, length, field_key = KEY_CODEC.decode_key(db_key)
field = field_key[length:]
result.append(field)
result.append(db_value)
return result

@property
Expand All @@ -441,13 +442,18 @@ def dump(self, key):
if key_type == 'none':
return None
else:
payload = (
rdb.object_type(key_type) +
rdb.object_value(self, key, key_type) +
rdb.get_rdb_version()
)
checksum = crc64.checksum(payload)
return payload + checksum
return rdb.generate_payload(self, key, key_type)

def restore(self, key, ttl, payload, replace):
# TODO: there's no TTL support at the moment
object_type = self.type(key)
if object_type != 'none':
if replace:
self.delete(key)
else:
raise KeyError('BUSYKEY Target key name already exists')
rdb.verify_payload(payload)
rdb.load_object(self, key, payload)


class ScoreRange(object):
Expand Down
128 changes: 128 additions & 0 deletions dredis/rdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import struct

from dredis import crc64

RDB_TYPE_STRING = 0
RDB_TYPE_SET = 2
RDB_TYPE_ZSET = 3
Expand All @@ -27,6 +29,8 @@

RDB_VERSION = 7

BAD_DATA_FORMAT_ERR = ValueError("Bad data format")


def object_type(type_name):
"""
Expand Down Expand Up @@ -88,6 +92,35 @@ def save_len(len):
return struct.pack('>BL', (RDB_32BITLEN << 6), len)


def load_len(data, index):
"""
:param data: str
:return: (int, str). the length of the string and the data after the string
Based on rdbLoadLen() in rdb.c
"""

def get_byte(i):
return struct.unpack('>B', data[index + i])[0]

def get_long(start, end):
return struct.unpack('>L', data[index + start:index + end])[0]

len_type = (get_byte(0) & 0xC0) >> 6
if len_type == RDB_6BITLEN:
length = get_byte(0) & 0x3F
index += 1
elif len_type == RDB_14BITLEN:
length = ((get_byte(0) & 0x3F) << 8) | get_byte(1)
index += 2
elif len_type == RDB_32BITLEN:
length = get_long(1, 5)
index += 5
else:
raise BAD_DATA_FORMAT_ERR
return length, index


def save_double(number):
"""
:return: big endian encoded float. 255 represents -inf, 244 +inf, 253 NaN
Expand All @@ -106,8 +139,103 @@ def save_double(number):
return struct.pack('>B', len(string)) + string


def load_double(data, index):
length = struct.unpack('>B', data[index])[0]
index += 1
if length == 255:
result = float('-inf')
elif length == 254:
result = float('+inf')
elif length == 253:
result = float('nan')
else:
result = float(data[index:index + length])
index += length
return result, index


def get_rdb_version():
"""
:return: little endian encoded 2-byte RDB version
"""
return struct.pack('<BB', RDB_VERSION & 0xff, (RDB_VERSION >> 8) & 0xff)


def load_object(keyspace, key, payload):
data = payload[:-10] # ignore the RDB header (2 bytes) and the CRC64 checksum (8 bytes)
if not data:
raise BAD_DATA_FORMAT_ERR
obj_type = struct.unpack('<B', data[0])[0]
if obj_type == RDB_TYPE_STRING:
load_string_object(keyspace, key, data[1:])
elif obj_type == RDB_TYPE_SET:
load_set_object(keyspace, key, data[1:])
elif obj_type == RDB_TYPE_ZSET:
load_zset_object(keyspace, key, data[1:])
elif obj_type == RDB_TYPE_HASH:
load_hash_object(keyspace, key, data[1:])
else:
raise BAD_DATA_FORMAT_ERR


def load_string_object(keyspace, key, data):
index = 0
length, index = load_len(data, index)
obj = data[index:index + length]
keyspace.set(key, obj)


def load_set_object(keyspace, key, data):
index = 0
length, index = load_len(data, index)
for _ in xrange(length):
elem_length, index = load_len(data, index)
elem = data[index:index + elem_length]
index += elem_length
keyspace.sadd(key, elem)


def load_zset_object(keyspace, key, data):
index = 0
length, index = load_len(data, index)
for _ in xrange(length):
value_length, index = load_len(data, index)
value = data[index:index + value_length]
index += value_length
score, index = load_double(data, index)
keyspace.zadd(key, score, value)


def load_hash_object(keyspace, key, data):
index = 0
length, index = load_len(data, index)
for _ in xrange(length):
field_length, index = load_len(data, index)
field = data[index:index + field_length]
index += field_length
value_length, index = load_len(data, index)
value = data[index:index + value_length]
index += value_length
keyspace.hset(key, field, value)


def generate_payload(keyspace, key, key_type):
payload = (
object_type(key_type) +
object_value(keyspace, key, key_type) +
get_rdb_version()
)
checksum = crc64.checksum(payload)
return payload + checksum


def verify_payload(payload):
bad_payload = ValueError('DUMP payload version or checksum are wrong')
if len(payload) < 10:
raise bad_payload
data, footer = payload[:-10], payload[-10:]
rdb_version, crc = footer[:2], footer[2:]
if rdb_version > get_rdb_version():
raise bad_payload
if crc64.checksum(data + rdb_version) != crc:
raise bad_payload
2 changes: 1 addition & 1 deletion dredis/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def execute_cmd(keyspace, send_fn, cmd, *args):
try:
result = run_command(keyspace, cmd, args)
except (SyntaxError, CommandNotFound, ValueError, RedisScriptError) as exc:
except (SyntaxError, CommandNotFound, ValueError, RedisScriptError, KeyError) as exc:
transmit(send_fn, exc)
except Exception:
# no tests cover this part because it's meant for internal errors,
Expand Down
34 changes: 34 additions & 0 deletions tests/integration/test_keys.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import pytest
import redis

from tests.helpers import fresh_redis


def test_types():
r = fresh_redis()

r.set('emptystr', '')
r.set('mystr', 'test')
r.incr('myint')
r.sadd('myset', 'test')
r.zadd('myzset', 0, 'test')
r.hset('myhash1', 'field', 'value')
r.hsetnx('myhash2', 'field', 'value')

assert r.type('emptystr') == 'string'
assert r.type('mystr') == 'string'
assert r.type('myint') == 'string'
assert r.type('myset') == 'set'
Expand Down Expand Up @@ -77,3 +82,32 @@ def test_dump():

r.set('str', 'test')
assert r.dump('str') == b'\x00\x04test\x07\x00~\xa2zSd;e_'


def test_restore():
r = fresh_redis()

r.set('str1', 'test')
payload = r.dump('str1')
r.set('str1', 'test2')
r.restore('str1', 0, payload, replace=True)
r.restore('str2', 0, payload, replace=False)

assert r.get('str1') == 'test'
assert r.get('str2') == 'test'


def test_restore_with_valid_params():
r = fresh_redis()

with pytest.raises(redis.ResponseError) as exc:
r.execute_command('RESTORE', 'str1')
assert str(exc.value) == "wrong number of arguments for 'restore' command"

with pytest.raises(redis.ResponseError) as exc:
r.execute_command('RESTORE', 'str1', 'a')
assert str(exc.value) == "wrong number of arguments for 'restore' command"

with pytest.raises(redis.ResponseError) as exc:
r.execute_command('RESTORE', 'str1', '0', 'payload', 'repl')
assert str(exc.value) == "syntax error"
3 changes: 0 additions & 3 deletions tests/unit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from dredis.db import DB_MANAGER

DB_MANAGER.setup_dbs('', backend='memory', backend_options={})
Loading

0 comments on commit 4ac3102

Please sign in to comment.