Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Tasks validation to C++ SDK #2543

Merged
merged 15 commits into from
Nov 6, 2023
Merged
2 changes: 1 addition & 1 deletion examples/quickstart-cpp/include/simple_client.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/***********************************************************************************************************
*
* @file libtorch_client.h
* @file simple_client.h
*
* @brief Define an example flower client, train and test method
*
Expand Down
10 changes: 5 additions & 5 deletions src/cc/flwr/include/grpc_rere.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
/*************************************************************************************************
*
* @file start.h
* @file grpc-rere.h
*
* @brief Create a gRPC channel to connect to the server and enable message
*communication
* @brief Provide functions for establishing gRPC request-response communication
*
* @author Lekang Jiang
* @author The Flower Authors
*
* @version 1.0
*
* @date 06/09/2021
* @date 06/11/2023
*
*************************************************************************************************/

#ifndef GRPC_RERE_H
#define GRPC_RERE_H
#pragma once
#include "message_handler.h"
#include "task_handler.h"
#include <grpcpp/grpcpp.h>

void create_node(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub);
Expand Down
24 changes: 24 additions & 0 deletions src/cc/flwr/include/task_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*************************************************************************************************
*
* @file task_handler.h
*
* @brief Handle incoming or outgoing tasks
*
* @author The Flower Authors
*
* @version 1.0
*
* @date 06/11/2023
*
*************************************************************************************************/

#pragma once
#include "client.h"
#include "serde.h"

bool validate_task_ins(const flwr::proto::TaskIns &task_ins,
const bool discard_reconnect_ins);
bool validate_task_res(const flwr::proto::TaskRes &task_res);
flwr::proto::TaskRes configure_task_res(const flwr::proto::TaskRes &task_res,
const flwr::proto::TaskIns &task_ins,
const flwr::proto::Node &node);
23 changes: 12 additions & 11 deletions src/cc/flwr/src/grpc_rere.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,14 @@ receive(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub) {

if (response.task_ins_list_size() > 0) {
flwr::proto::TaskIns task_ins = response.task_ins_list().at(0);
// TODO: Validate TaskIns

{
if (validate_task_ins(task_ins, true)) {
std::lock_guard<std::mutex> state_lock(state_mutex);
state[KEY_TASK_INS] = task_ins;
return task_ins;
}

return task_ins;
} else {
std::cerr << "TaskIns list is empty." << std::endl;
return std::nullopt;
}
std::cerr << "TaskIns list is empty." << std::endl;
return std::nullopt;
}

void send(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub,
Expand All @@ -136,7 +132,12 @@ void send(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub,
return;
}

// TODO: Validate TaskIns
if (!validate_task_res(task_res)) {
std::cerr << "TaskRes is invalid" << std::endl;
std::lock_guard<std::mutex> state_lock(state_mutex);
state[KEY_TASK_INS].reset();
return;
}

flwr::proto::TaskRes new_task_res =
configure_task_res(task_res, *task_ins, *node);
Expand All @@ -151,8 +152,8 @@ void send(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub,
if (!status.ok()) {
std::cerr << "PushTaskRes RPC failed with status: "
<< status.error_message() << std::endl;
return;
} else {
}
{
std::lock_guard<std::mutex> state_lock(state_mutex);
state[KEY_TASK_INS].reset();
}
Expand Down
31 changes: 0 additions & 31 deletions src/cc/flwr/src/message_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,34 +102,3 @@ handle_task(flwr_local::Client *client, const flwr::proto::TaskIns &task_ins) {
return std::make_tuple(task_res, std::get<1>(legacy_res),
std::get<2>(legacy_res));
}

flwr::proto::TaskRes
configure_task_res(const flwr::proto::TaskRes &task_res,
const flwr::proto::TaskIns &ref_task_ins,
const flwr::proto::Node &producer) {
flwr::proto::TaskRes result_task_res;

// Setting scalar fields
result_task_res.set_task_id(""); // This will be generated by the server
result_task_res.set_group_id(ref_task_ins.group_id());
result_task_res.set_workload_id(ref_task_ins.workload_id());

// Merge the task from the input task_res
*result_task_res.mutable_task() = task_res.task();

// Construct and set the producer and consumer for the task
std::unique_ptr<flwr::proto::Node> new_producer =
std::make_unique<flwr::proto::Node>(producer);
result_task_res.mutable_task()->set_allocated_producer(
new_producer.release());

std::unique_ptr<flwr::proto::Node> new_consumer =
std::make_unique<flwr::proto::Node>(ref_task_ins.task().producer());
result_task_res.mutable_task()->set_allocated_consumer(
new_consumer.release());

// Set ancestry in the task
result_task_res.mutable_task()->add_ancestry(ref_task_ins.task_id());

return result_task_res;
}
52 changes: 52 additions & 0 deletions src/cc/flwr/src/task_handler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "task_handler.h"

bool validate_task_ins(const flwr::proto::TaskIns &task_ins,
const bool discard_reconnect_ins) {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
return !(!task_ins.has_task() ||
(!task_ins.task().has_legacy_server_message() &&
!task_ins.task().has_sa()) ||
(discard_reconnect_ins &&
task_ins.task().legacy_server_message().has_reconnect_ins()));
#pragma GCC diagnostic pop
}

bool validate_task_res(const flwr::proto::TaskRes &task_res) {
// Retrieve initialized fields in TaskRes
return (task_res.task_id().empty() && task_res.group_id().empty() &&
task_res.workload_id() == 0 && !task_res.task().has_producer() &&
!task_res.task().has_producer() && !task_res.task().has_consumer() &&
task_res.task().ancestry_size() == 0);
}

flwr::proto::TaskRes
configure_task_res(const flwr::proto::TaskRes &task_res,
const flwr::proto::TaskIns &ref_task_ins,
const flwr::proto::Node &producer) {
flwr::proto::TaskRes result_task_res;

// Setting scalar fields
result_task_res.set_task_id(""); // This will be generated by the server
result_task_res.set_group_id(ref_task_ins.group_id());
result_task_res.set_workload_id(ref_task_ins.workload_id());

// Merge the task from the input task_res
*result_task_res.mutable_task() = task_res.task();

// Construct and set the producer and consumer for the task
std::unique_ptr<flwr::proto::Node> new_producer =
std::make_unique<flwr::proto::Node>(producer);
result_task_res.mutable_task()->set_allocated_producer(
new_producer.release());

std::unique_ptr<flwr::proto::Node> new_consumer =
std::make_unique<flwr::proto::Node>(ref_task_ins.task().producer());
result_task_res.mutable_task()->set_allocated_consumer(
new_consumer.release());

// Set ancestry in the task
result_task_res.mutable_task()->add_ancestry(ref_task_ins.task_id());

return result_task_res;
}