Skip to content

Commit

Permalink
feat: mget supports querying multiple keys from the cache (#2675)
Browse files Browse the repository at this point in the history
* multi get

---------

Co-authored-by: chejinge <[email protected]>
  • Loading branch information
chejinge and brother-jin authored Jun 4, 2024
1 parent 4b390c0 commit 1bb76ed
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 40 deletions.
9 changes: 7 additions & 2 deletions include/pika_kv.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,19 @@ class MgetCmd : public Cmd {
void Merge() override;
Cmd* Clone() override { return new MgetCmd(*this); }

private:
void DoInitial() override;
void MergeCachedAndDbResults();
void AssembleResponseFromCache();

private:
std::vector<std::string> keys_;
std::vector<std::string> cache_miss_keys_;
std::string value_;
std::unordered_map<std::string, std::string> cache_hit_values_;
std::vector<storage::ValueStatus> split_res_;
std::vector<storage::ValueStatus> db_value_status_array_;
std::vector<storage::ValueStatus> cache_value_status_array_;
int64_t ttl_ = -1;
void DoInitial() override;
rocksdb::Status s_;
};

Expand Down
7 changes: 3 additions & 4 deletions src/pika_bit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,11 @@ void BitCountCmd::ReadCache() {
int64_t count = 0;
int64_t start = static_cast<long>(start_offset_);
int64_t end = static_cast<long>(end_offset_);
rocksdb::Status s;
bool flag = true;
if (count_all_) {
s = db_->cache()->BitCount(key_, start, end, &count, 0);
} else {
s = db_->cache()->BitCount(key_, start, end, &count, 1);
flag = false;
}
rocksdb::Status s = db_->cache()->BitCount(key_, start, end, &count, flag);

if (s.ok()) {
res_.AppendInteger(count);
Expand Down
4 changes: 2 additions & 2 deletions src/pika_command.cc
Original file line number Diff line number Diff line change
Expand Up @@ -678,15 +678,15 @@ void InitCmdTable(CmdTable* cmd_table) {
cmd_table->insert(std::pair<std::string, std::unique_ptr<Cmd>>(kCmdNameBitSet, std::move(bitsetptr)));
////bitgetCmd
std::unique_ptr<Cmd> bitgetptr =
std::make_unique<BitGetCmd>(kCmdNameBitGet, 3, kCmdFlagsRead | kCmdFlagsBit | kCmdFlagsSlow | kCmdFlagsDoThroughDB | kCmdFlagsReadCache | kCmdFlagsUpdateCache);
std::make_unique<BitGetCmd>(kCmdNameBitGet, 3, kCmdFlagsRead | kCmdFlagsBit | kCmdFlagsSlow);
cmd_table->insert(std::pair<std::string, std::unique_ptr<Cmd>>(kCmdNameBitGet, std::move(bitgetptr)));
////bitcountCmd
std::unique_ptr<Cmd> bitcountptr =
std::make_unique<BitCountCmd>(kCmdNameBitCount, -2, kCmdFlagsRead | kCmdFlagsBit | kCmdFlagsSlow | kCmdFlagsDoThroughDB | kCmdFlagsReadCache | kCmdFlagsUpdateCache);
cmd_table->insert(std::pair<std::string, std::unique_ptr<Cmd>>(kCmdNameBitCount, std::move(bitcountptr)));
////bitposCmd
std::unique_ptr<Cmd> bitposptr =
std::make_unique<BitPosCmd>(kCmdNameBitPos, -3, kCmdFlagsRead | kCmdFlagsBit | kCmdFlagsSlow | kCmdFlagsDoThroughDB | kCmdFlagsReadCache | kCmdFlagsUpdateCache);
std::make_unique<BitPosCmd>(kCmdNameBitPos, -3, kCmdFlagsRead | kCmdFlagsBit | kCmdFlagsSlow);
cmd_table->insert(std::pair<std::string, std::unique_ptr<Cmd>>(kCmdNameBitPos, std::move(bitposptr)));
////bitopCmd
std::unique_ptr<Cmd> bitopptr =
Expand Down
104 changes: 76 additions & 28 deletions src/pika_kv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -509,26 +509,41 @@ void MgetCmd::DoInitial() {
keys_ = argv_;
keys_.erase(keys_.begin());
split_res_.resize(keys_.size());
cache_miss_keys_.clear();
}

// Builds the complete MGET reply purely from cached values, in the order the
// keys were requested. Only valid when every key was a cache hit (i.e. the
// caller verified cache_miss_keys_ is empty); a missing entry here indicates
// an internal inconsistency and aborts the reply with an error.
void MgetCmd::AssembleResponseFromCache() {
  res_.AppendArrayLenUint64(keys_.size());
  for (const auto& key : keys_) {
    auto it = cache_hit_values_.find(key);
    if (it != cache_hit_values_.end()) {
      // Use the Uint64 length variant, matching AppendArrayLenUint64 above,
      // to avoid narrowing the size_t returned by size().
      res_.AppendStringLenUint64(it->second.size());
      res_.AppendContent(it->second);
    } else {
      // Should be unreachable when the caller checked cache_miss_keys_ first.
      res_.SetRes(CmdRes::kErrOther, "Internal error during cache assembly");
      return;
    }
  }
}

void MgetCmd::Do() {
// Without using the cache and querying only the DB, we need to use keys_.
// This line will only be assigned when querying the DB directly.
if(cache_miss_keys_.size() == 0) {
cache_miss_keys_ = keys_;
}
db_value_status_array_.clear();
s_ = db_->storage()->MGet(keys_, &db_value_status_array_);
if (s_.ok()) {
res_.AppendArrayLenUint64(db_value_status_array_.size());
for (const auto& vs : db_value_status_array_) {
if (vs.status.ok()) {
res_.AppendStringLenUint64(vs.value.size());
res_.AppendContent(vs.value);
} else {
res_.AppendContent("$-1");
}
s_ = db_->storage()->MGet(cache_miss_keys_, &db_value_status_array_);
if (!s_.ok()) {
if (s_.IsInvalidArgument()) {
res_.SetRes(CmdRes::kMultiKey);
} else {
res_.SetRes(CmdRes::kErrOther, s_.ToString());
}
} else if (s_.IsInvalidArgument()) {
res_.SetRes(CmdRes::kMultiKey);
} else {
res_.SetRes(CmdRes::kErrOther, s_.ToString());
return;
}

MergeCachedAndDbResults();
}

void MgetCmd::Split(const HintKeys& hint_keys) {
Expand Down Expand Up @@ -560,34 +575,67 @@ void MgetCmd::Merge() {
}
}

// Fallback path taken when ReadCache() could not fully answer the request:
// discard any partial reply already written, then re-run the DB-backed Do().
void MgetCmd::DoThroughDB() {
res_.clear();
Do();
}

void MgetCmd::ReadCache() {
if (1 < keys_.size()) {
res_.SetRes(CmdRes::kCacheMiss);
return;
for (const auto key : keys_) {
std::string value;
auto s = db_->cache()->Get(const_cast<std::string&>(key), &value);
if (s.ok()) {
cache_hit_values_[key] = value;
} else {
cache_miss_keys_.push_back(key);
}
}
auto s = db_->cache()->Get(keys_[0], &value_);
if (s.ok()) {
res_.AppendArrayLen(1);
res_.AppendStringLen(value_.size());
res_.AppendContent(value_);
if (cache_miss_keys_.empty()) {
AssembleResponseFromCache();
} else {
res_.SetRes(CmdRes::kCacheMiss);
}
}

void MgetCmd::DoThroughDB() {
res_.clear();
Do();
// Backfills the cache with values fetched from the DB for the keys that
// missed the cache. db_value_status_array_[i] corresponds to
// cache_miss_keys_[i] (both produced by the same MGet call); only keys whose
// DB lookup succeeded are written back.
void MgetCmd::DoUpdateCache() {
  // Index loop over the non-const member avoids the per-iteration std::string
  // copy made by the original `for (const auto key : ...)` and removes the
  // need for const_cast when passing the key to WriteKVToCache.
  for (size_t i = 0; i < cache_miss_keys_.size() && i < db_value_status_array_.size(); ++i) {
    if (db_value_status_array_[i].status.ok()) {
      db_->cache()->WriteKVToCache(cache_miss_keys_[i], db_value_status_array_[i].value,
                                   db_value_status_array_[i].ttl);
    }
  }
}

void MgetCmd::DoUpdateCache() {
for (size_t i = 0; i < keys_.size(); i++) {
// Assembles the final MGET reply in the original request order, preferring
// cache hits and falling back to values fetched from the DB; keys absent from
// both sources get the RESP nil bulk string.
void MgetCmd::MergeCachedAndDbResults() {
  res_.AppendArrayLenUint64(keys_.size());

  // Map DB results back to their keys: db_value_status_array_[i] corresponds
  // to cache_miss_keys_[i]. Bound the index by both containers so a short or
  // failed DB fetch cannot cause an out-of-bounds read.
  std::unordered_map<std::string, std::string> db_results_map;
  for (size_t i = 0; i < cache_miss_keys_.size() && i < db_value_status_array_.size(); ++i) {
    if (db_value_status_array_[i].status.ok()) {
      db_results_map[cache_miss_keys_[i]] = db_value_status_array_[i].value;
    }
  }

  for (const auto& key : keys_) {
    auto cache_it = cache_hit_values_.find(key);
    if (cache_it != cache_hit_values_.end()) {
      // Uint64 length variants match AppendArrayLenUint64 and avoid
      // narrowing the size_t returned by size().
      res_.AppendStringLenUint64(cache_it->second.size());
      res_.AppendContent(cache_it->second);
    } else if (auto db_it = db_results_map.find(key); db_it != db_results_map.end()) {
      res_.AppendStringLenUint64(db_it->second.size());
      res_.AppendContent(db_it->second);
    } else {
      // Key missing in both cache and DB -> RESP nil bulk string.
      res_.AppendContent("$-1");
    }
  }
}


void KeysCmd::DoInitial() {
if (!CheckArg(argv_.size())) {
res_.SetRes(CmdRes::kWrongNum, kCmdNameKeys);
Expand Down
6 changes: 2 additions & 4 deletions src/pika_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1688,9 +1688,7 @@ void PikaServer::DoCacheBGTask(void* arg) {
}

db->cache()->SetCacheStatus(PIKA_CACHE_STATUS_OK);
if (pCacheTaskArg->reenable_cache) {
pCacheTaskArg->conf->UnsetCacheDisableFlag();
}
g_pika_conf->UnsetCacheDisableFlag();
}

void PikaServer::ResetCacheConfig(std::shared_ptr<DB> db) {
Expand All @@ -1710,7 +1708,7 @@ void PikaServer::ClearHitRatio(std::shared_ptr<DB> db) {

void PikaServer::OnCacheStartPosChanged(int zset_cache_start_direction, std::shared_ptr<DB> db) {
ResetCacheConfig(db);
ClearCacheDbAsync(db);
ClearCacheDbAsyncV2(db);
}

void PikaServer::ClearCacheDbAsyncV2(std::shared_ptr<DB> db) {
Expand Down
163 changes: 163 additions & 0 deletions tests/integration/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,167 @@ var _ = Describe("Cache test", func() {
Expect(mGet4.Err()).NotTo(HaveOccurred())
Expect(mGet4.Val()).To(Equal([]interface{}{nil, nil, nil, nil}))
})

// Verifies MGET over keys that are split between cache and DB: the
// single-key MGET/GET on key1 warms the cache for that key, while
// key2..key4 are expected to be served from the DB on the final MGET.
It("should mget for multi key in cache and db", func() {
multiset1 := client.Set(ctx, "key1", "a", 3000*time.Millisecond)
Expect(multiset1.Err()).NotTo(HaveOccurred())
Expect(multiset1.Val()).To(Equal("OK"))

multiset2 := client.Set(ctx, "key2", "b", 3000*time.Millisecond)
Expect(multiset2.Err()).NotTo(HaveOccurred())
Expect(multiset2.Val()).To(Equal("OK"))

multiset3 := client.Set(ctx, "key3", "c", 3000*time.Millisecond)
Expect(multiset3.Err()).NotTo(HaveOccurred())
Expect(multiset3.Val()).To(Equal("OK"))

multiset4 := client.Set(ctx, "key4", "d", 3000*time.Millisecond)
Expect(multiset4.Err()).NotTo(HaveOccurred())
Expect(multiset4.Val()).To(Equal("OK"))

// Reads on key1 to populate the cache before the mixed MGET below.
multikey1 := client.MGet(ctx, "key1")
Expect(multikey1.Err()).NotTo(HaveOccurred())
Expect(multikey1.Val()).To(Equal([]interface{}{"a"}))

MultiKey2 := client.Get(ctx, "key1")
Expect(MultiKey2.Err()).NotTo(HaveOccurred())
Expect(MultiKey2.Val()).To(Equal("a"))

// All four values must come back in request order regardless of source.
MultiMget := client.MGet(ctx, "key1", "key2", "key3", "key4")
Expect(MultiMget.Err()).NotTo(HaveOccurred())
Expect(MultiMget.Val()).To(Equal([]interface{}{"a", "b", "c", "d"}))
})

// NOTE(review): this spec body is byte-identical to "should mget for multi
// key in cache and db" above, and nothing here forces key2..key4 into the
// cache before the final MGET — presumably this was meant to warm all four
// keys first; confirm intent and deduplicate or adjust the setup.
It("should mget for multi key in cache", func() {
multiset1 := client.Set(ctx, "key1", "a", 3000*time.Millisecond)
Expect(multiset1.Err()).NotTo(HaveOccurred())
Expect(multiset1.Val()).To(Equal("OK"))

multiset2 := client.Set(ctx, "key2", "b", 3000*time.Millisecond)
Expect(multiset2.Err()).NotTo(HaveOccurred())
Expect(multiset2.Val()).To(Equal("OK"))

multiset3 := client.Set(ctx, "key3", "c", 3000*time.Millisecond)
Expect(multiset3.Err()).NotTo(HaveOccurred())
Expect(multiset3.Val()).To(Equal("OK"))

multiset4 := client.Set(ctx, "key4", "d", 3000*time.Millisecond)
Expect(multiset4.Err()).NotTo(HaveOccurred())
Expect(multiset4.Val()).To(Equal("OK"))

// Only key1 is read (and thus cached) before the multi-key MGET.
multikey1 := client.MGet(ctx, "key1")
Expect(multikey1.Err()).NotTo(HaveOccurred())
Expect(multikey1.Val()).To(Equal([]interface{}{"a"}))

MultiKey2 := client.Get(ctx, "key1")
Expect(MultiKey2.Err()).NotTo(HaveOccurred())
Expect(MultiKey2.Val()).To(Equal("a"))

MultiMget := client.MGet(ctx, "key1", "key2", "key3", "key4")
Expect(MultiMget.Err()).NotTo(HaveOccurred())
Expect(MultiMget.Val()).To(Equal([]interface{}{"a", "b", "c", "d"}))
})

// Verifies MGET after every key has been read individually (each single-key
// MGET/GET warms the cache for that key), then checks the combined MGET
// returns all values in request order.
It("should mget for multi key in db", func() {
multiset1 := client.Set(ctx, "key1", "a", 3000*time.Millisecond)
Expect(multiset1.Err()).NotTo(HaveOccurred())
Expect(multiset1.Val()).To(Equal("OK"))

multiset2 := client.Set(ctx, "key2", "b", 3000*time.Millisecond)
Expect(multiset2.Err()).NotTo(HaveOccurred())
Expect(multiset2.Val()).To(Equal("OK"))

multiset3 := client.Set(ctx, "key3", "c", 3000*time.Millisecond)
Expect(multiset3.Err()).NotTo(HaveOccurred())
Expect(multiset3.Val()).To(Equal("OK"))

multiset4 := client.Set(ctx, "key4", "d", 3000*time.Millisecond)
Expect(multiset4.Err()).NotTo(HaveOccurred())
Expect(multiset4.Val()).To(Equal("OK"))

// Touch each key individually before the combined MGET.
multikey1 := client.MGet(ctx, "key1")
Expect(multikey1.Err()).NotTo(HaveOccurred())
Expect(multikey1.Val()).To(Equal([]interface{}{"a"}))

MultiKey2 := client.Get(ctx, "key1")
Expect(MultiKey2.Err()).NotTo(HaveOccurred())
Expect(MultiKey2.Val()).To(Equal("a"))

multikey3 := client.MGet(ctx, "key2")
Expect(multikey3.Err()).NotTo(HaveOccurred())
Expect(multikey3.Val()).To(Equal([]interface{}{"b"}))

multikey4 := client.MGet(ctx, "key3")
Expect(multikey4.Err()).NotTo(HaveOccurred())
Expect(multikey4.Val()).To(Equal([]interface{}{"c"}))

multikey5 := client.MGet(ctx, "key4")
Expect(multikey5.Err()).NotTo(HaveOccurred())
Expect(multikey5.Val()).To(Equal([]interface{}{"d"}))

MultiMget := client.MGet(ctx, "key1", "key2", "key3", "key4")
Expect(MultiMget.Err()).NotTo(HaveOccurred())
Expect(MultiMget.Val()).To(Equal([]interface{}{"a", "b", "c", "d"}))
})

// Verifies MGET served straight from the DB: no single-key reads precede
// the combined MGET, so no key has been pulled into the cache.
// Renamed from "should mget for multi key in db", which duplicated the
// previous spec's description and made failure reports ambiguous.
It("should mget for multi key in db without cache warmup", func() {
multiset1 := client.Set(ctx, "key1", "a", 3000*time.Millisecond)
Expect(multiset1.Err()).NotTo(HaveOccurred())
Expect(multiset1.Val()).To(Equal("OK"))

multiset2 := client.Set(ctx, "key2", "b", 3000*time.Millisecond)
Expect(multiset2.Err()).NotTo(HaveOccurred())
Expect(multiset2.Val()).To(Equal("OK"))

multiset3 := client.Set(ctx, "key3", "c", 3000*time.Millisecond)
Expect(multiset3.Err()).NotTo(HaveOccurred())
Expect(multiset3.Val()).To(Equal("OK"))

multiset4 := client.Set(ctx, "key4", "d", 3000*time.Millisecond)
Expect(multiset4.Err()).NotTo(HaveOccurred())
Expect(multiset4.Val()).To(Equal("OK"))

// All values come from the DB path and must be in request order.
MultiMget := client.MGet(ctx, "key1", "key2", "key3", "key4")
Expect(MultiMget.Err()).NotTo(HaveOccurred())
Expect(MultiMget.Val()).To(Equal([]interface{}{"a", "b", "c", "d"}))
})

// MGET must return nil (not an error) for a key that was never set,
// while still returning values for the existing keys around it.
It("MGET against non existing key", func() {
	setKey1 := client.Set(ctx, "key1", "a", 3000*time.Millisecond)
	Expect(setKey1.Err()).NotTo(HaveOccurred())
	Expect(setKey1.Val()).To(Equal("OK"))

	setKey3 := client.Set(ctx, "key3", "c", 3000*time.Millisecond)
	Expect(setKey3.Err()).NotTo(HaveOccurred())
	Expect(setKey3.Val()).To(Equal("OK"))

	setKey4 := client.Set(ctx, "key4", "d", 3000*time.Millisecond)
	Expect(setKey4.Err()).NotTo(HaveOccurred())
	Expect(setKey4.Val()).To(Equal("OK"))

	// key2 was deliberately never set; its slot must come back as nil.
	mgetReply := client.MGet(ctx, "key1", "key2", "key3", "key4")
	Expect(mgetReply.Err()).NotTo(HaveOccurred())
	Expect(mgetReply.Val()).To(Equal([]interface{}{"a", nil, "c", "d"}))
})
// MGET must return nil for keys holding a non-string type (the set
// "myset{t}") and for missing keys ("baazz{t}"), not an error — matching
// Redis MGET semantics. The {t} hash-tag keeps all keys in one slot.
It("MGET against non-string key", func() {
SetMultiKey := client.Set(ctx, "foo{t}", "BAR", 3000*time.Millisecond)
Expect(SetMultiKey.Err()).NotTo(HaveOccurred())
Expect(SetMultiKey.Val()).To(Equal("OK"))

SetMultiKey1 := client.Set(ctx, "bar{t}", "FOO", 3000*time.Millisecond)
Expect(SetMultiKey1.Err()).NotTo(HaveOccurred())
Expect(SetMultiKey1.Val()).To(Equal("OK"))

// Build a set-typed key so the MGET below hits a wrong-type value.
SaddMultiKey := client.SAdd(ctx, "myset{t}", "ciao")
Expect(SaddMultiKey.Err()).NotTo(HaveOccurred())
Expect(SaddMultiKey.Val()).To(Equal(int64(1)))

SaddMultiKey1 := client.SAdd(ctx, "myset{t}", "bau")
Expect(SaddMultiKey1.Err()).NotTo(HaveOccurred())
Expect(SaddMultiKey1.Val()).To(Equal(int64(1)))

MultiMget := client.MGet(ctx, "foo{t}", "baazz{t}", "bar{t}", "myset{t}")
Expect(MultiMget.Err()).NotTo(HaveOccurred())
Expect(MultiMget.Val()).To(Equal([]interface{}{"BAR", nil, "FOO", nil}))
})
})

0 comments on commit 1bb76ed

Please sign in to comment.