Skip to content

Commit

Permalink
GH-40155: [Go][FlightRPC][FlightSQL] Implement Session Management (#4…
Browse files Browse the repository at this point in the history
…0284)

### Rationale for this change

Brings Go implementation in parity with recent session management additions to Java and C++: #34865

### What changes are included in this PR?

- Go Flight/FlightSQL implementations of session management RPC handlers
- Implementation of cookie-based session middleware
  - Implementation of stateful (id-lookup based) sessions/tokens
  - Implementation of stateless (fully encoded) sessions/tokens
- Fix minor C++ logic bug when closing sessions
- Update Java integration test server to return an empty session if `getSessionOptions` is called before `setSessionOptions`
- Refactor of `DoAction` handlers to consolidate the code that is essentially copied between them.
  - As part of this I found an issue with `CancelFlightInfo` where a copy of the message was being returned instead of a pointer as is typically the case with `proto.Message`'s. I updated the return type and any usage throughout the code base as part of the refactor.

### Are these changes tested?

Yes, both integration and unit tests are included.

A few tests were added in the Go integration suite beyond the existing coverage in the Java/C++ suites. These tests aim to demonstrate my understanding of session semantics in those scenarios, please let me know if you believe the details are not accurate.

Some of the new integration tests failed in the Java/C++ scenarios. I made very minor changes to those implementations to fix certain failures but there are still some remaining bugs (assuming these are testing the right semantics). Specifically:
- The integration test for reopening a previously closed session passes on Go/Java, but fails for C++ so it is commented out.
- This implementation prefers to set any cookies in the gRPC trailer which works fine for Go/C++, but not for Java. As a temporary workaround this implementation will _also_ set the cookie in the gRPC header if a new session was created. This is sufficient to maintain compatibility with Java stateful sessions where the session ID token can be known at the time of creation, but is not robust to other scenarios such as stateless sessions where in many cases the token cannot be known until after the RPC has completed.

### Are there any user-facing changes?
Yes, session management RPC as well as middleware implementations are included. Functionality is entirely additive

* GitHub Issue: #40155

Authored-by: joel <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
  • Loading branch information
joellubi authored Mar 1, 2024
1 parent 30e6d72 commit 81c9d30
Show file tree
Hide file tree
Showing 18 changed files with 3,106 additions and 730 deletions.
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/sql/server_session_middleware.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ServerSessionMiddlewareImpl : public ServerSessionMiddleware {

Status CloseSession() override {
const std::lock_guard<std::shared_mutex> l(mutex_);
if (static_cast<bool>(session_)) {
if (!static_cast<bool>(session_)) {
return Status::Invalid("Nonexistent session cannot be closed.");
}
ARROW_RETURN_NOT_OK(factory_->CloseSession(session_id_));
Expand Down
2 changes: 1 addition & 1 deletion dev/archery/archery/integration/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
Scenario(
"session_options",
description="Ensure Flight SQL Sessions work as expected.",
skip_testers={"JS", "C#", "Rust", "Go"}
skip_testers={"JS", "C#", "Rust"}
),
Scenario(
"poll_flight_info",
Expand Down
90 changes: 60 additions & 30 deletions go/arrow/flight/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ type Client interface {
// in order to use the Handshake endpoints of the service.
Authenticate(context.Context, ...grpc.CallOption) error
AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error)
CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (CancelFlightInfoResult, error)
CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (*CancelFlightInfoResult, error)
Close() error
RenewFlightEndpoint(ctx context.Context, request *RenewFlightEndpointRequest, opts ...grpc.CallOption) (*FlightEndpoint, error)
SetSessionOptions(ctx context.Context, request *SetSessionOptionsRequest, opts ...grpc.CallOption) (*SetSessionOptionsResult, error)
GetSessionOptions(ctx context.Context, request *GetSessionOptionsRequest, opts ...grpc.CallOption) (*GetSessionOptionsResult, error)
CloseSession(ctx context.Context, request *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResult, error)
// join the interface from the FlightServiceClient instead of re-defining all
// the endpoints here.
FlightServiceClient
Expand Down Expand Up @@ -364,26 +367,14 @@ func ReadUntilEOF(stream FlightService_DoActionClient) error {
}
}

func (c *client) CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (result CancelFlightInfoResult, err error) {
var action flight.Action
action.Type = CancelFlightInfoActionType
action.Body, err = proto.Marshal(request)
if err != nil {
return
}
stream, err := c.DoAction(ctx, &action, opts...)
if err != nil {
return
}
res, err := stream.Recv()
func (c *client) CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (*CancelFlightInfoResult, error) {
var result CancelFlightInfoResult
err := handleAction(ctx, c, CancelFlightInfoActionType, request, &result, opts...)
if err != nil {
return
}
if err = proto.Unmarshal(res.Body, &result); err != nil {
return
return nil, err
}
err = ReadUntilEOF(stream)
return

return &result, err
}

func (c *client) Close() error {
Expand All @@ -395,29 +386,68 @@ func (c *client) Close() error {
}

func (c *client) RenewFlightEndpoint(ctx context.Context, request *RenewFlightEndpointRequest, opts ...grpc.CallOption) (*FlightEndpoint, error) {
var err error
var action flight.Action
action.Type = RenewFlightEndpointActionType
action.Body, err = proto.Marshal(request)
var result FlightEndpoint
err := handleAction(ctx, c, RenewFlightEndpointActionType, request, &result, opts...)
if err != nil {
return nil, err
}
stream, err := c.DoAction(ctx, &action, opts...)

return &result, err
}

func (c *client) SetSessionOptions(ctx context.Context, request *SetSessionOptionsRequest, opts ...grpc.CallOption) (*SetSessionOptionsResult, error) {
var result SetSessionOptionsResult
err := handleAction(ctx, c, SetSessionOptionsActionType, request, &result, opts...)
if err != nil {
return nil, err
}
res, err := stream.Recv()

return &result, err
}

func (c *client) GetSessionOptions(ctx context.Context, request *GetSessionOptionsRequest, opts ...grpc.CallOption) (*GetSessionOptionsResult, error) {
var result GetSessionOptionsResult
err := handleAction(ctx, c, GetSessionOptionsActionType, request, &result, opts...)
if err != nil {
return nil, err
}
var renewedEndpoint FlightEndpoint
err = proto.Unmarshal(res.Body, &renewedEndpoint)

return &result, err
}

func (c *client) CloseSession(ctx context.Context, request *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResult, error) {
var result CloseSessionResult
err := handleAction(ctx, c, CloseSessionActionType, request, &result, opts...)
if err != nil {
return nil, err
}
err = ReadUntilEOF(stream)

return &result, err
}

func handleAction[T, U proto.Message](ctx context.Context, client FlightServiceClient, name string, request T, response U, opts ...grpc.CallOption) error {
var (
action flight.Action
err error
)

action.Type = name
action.Body, err = proto.Marshal(request)
if err != nil {
return nil, err
return err
}
return &renewedEndpoint, nil
stream, err := client.DoAction(ctx, &action, opts...)
if err != nil {
return err
}
res, err := stream.Recv()
if err != nil {
return err
}
err = proto.Unmarshal(res.Body, response)
if err != nil {
return err
}

return ReadUntilEOF(stream)
}
14 changes: 13 additions & 1 deletion go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,14 +584,26 @@ func (c *Client) CancelQuery(ctx context.Context, info *flight.FlightInfo, opts
return
}

func (c *Client) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
func (c *Client) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (*flight.CancelFlightInfoResult, error) {
return c.Client.CancelFlightInfo(ctx, request, opts...)
}

func (c *Client) RenewFlightEndpoint(ctx context.Context, request *flight.RenewFlightEndpointRequest, opts ...grpc.CallOption) (*flight.FlightEndpoint, error) {
return c.Client.RenewFlightEndpoint(ctx, request, opts...)
}

func (c *Client) SetSessionOptions(ctx context.Context, request *flight.SetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.SetSessionOptionsResult, error) {
return c.Client.SetSessionOptions(ctx, request, opts...)
}

func (c *Client) GetSessionOptions(ctx context.Context, request *flight.GetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.GetSessionOptionsResult, error) {
return c.Client.GetSessionOptions(ctx, request, opts...)
}

func (c *Client) CloseSession(ctx context.Context, request *flight.CloseSessionRequest, opts ...grpc.CallOption) (*flight.CloseSessionResult, error) {
return c.Client.CloseSession(ctx, request, opts...)
}

func (c *Client) BeginTransaction(ctx context.Context, opts ...grpc.CallOption) (*Txn, error) {
request := &pb.ActionBeginTransactionRequest{}
action, err := packAction(BeginTransactionActionType, request)
Expand Down
25 changes: 20 additions & 5 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,31 @@ func (m *FlightServiceClientMock) AuthenticateBasicToken(_ context.Context, user
return args.Get(0).(context.Context), args.Error(1)
}

func (m *FlightServiceClientMock) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
func (m *FlightServiceClientMock) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (*flight.CancelFlightInfoResult, error) {
args := m.Called(request, opts)
return args.Get(0).(flight.CancelFlightInfoResult), args.Error(1)
return args.Get(0).(*flight.CancelFlightInfoResult), args.Error(1)
}

func (m *FlightServiceClientMock) RenewFlightEndpoint(ctx context.Context, request *flight.RenewFlightEndpointRequest, opts ...grpc.CallOption) (*flight.FlightEndpoint, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.FlightEndpoint), args.Error(1)
}

func (m *FlightServiceClientMock) SetSessionOptions(ctx context.Context, request *flight.SetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.SetSessionOptionsResult, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.SetSessionOptionsResult), args.Error(1)
}

func (m *FlightServiceClientMock) GetSessionOptions(ctx context.Context, request *flight.GetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.GetSessionOptionsResult, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.GetSessionOptionsResult), args.Error(1)
}

func (m *FlightServiceClientMock) CloseSession(ctx context.Context, request *flight.CloseSessionRequest, opts ...grpc.CallOption) (*flight.CloseSessionResult, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.CloseSessionResult), args.Error(1)
}

func (m *FlightServiceClientMock) Close() error {
return m.Called().Error(0)
}
Expand Down Expand Up @@ -639,10 +654,10 @@ func (s *FlightSqlClientSuite) TestCancelFlightInfo() {
mockedCancelResult := flight.CancelFlightInfoResult{
Status: flight.CancelStatusCancelled,
}
s.mockClient.On("CancelFlightInfo", &request, s.callOpts).Return(mockedCancelResult, nil)
s.mockClient.On("CancelFlightInfo", &request, s.callOpts).Return(&mockedCancelResult, nil)
cancelResult, err := s.sqlClient.CancelFlightInfo(context.TODO(), &request, s.callOpts...)
s.NoError(err)
s.Equal(mockedCancelResult, cancelResult)
s.Equal(&mockedCancelResult, cancelResult)
}

func (s *FlightSqlClientSuite) TestRenewFlightEndpoint() {
Expand Down Expand Up @@ -671,7 +686,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementLoadFromResult() {
result := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(query),
}

parameterSchemaResult := arrow.NewSchema([]arrow.Field{{Name: "p_id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.ParameterSchema = flight.SerializeSchema(parameterSchemaResult, memory.DefaultAllocator)
datasetSchemaResult := arrow.NewSchema([]arrow.Field{{Name: "ds_id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
Expand Down
81 changes: 81 additions & 0 deletions go/arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,18 @@ func (BaseServer) EndSavepoint(context.Context, ActionEndSavepointRequest) error
return status.Error(codes.Unimplemented, "EndSavepoint not implemented")
}

func (BaseServer) SetSessionOptions(context.Context, *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) {
return nil, status.Error(codes.Unimplemented, "SetSessionOptions not implemented")
}

func (BaseServer) GetSessionOptions(context.Context, *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) {
return nil, status.Error(codes.Unimplemented, "GetSessionOptions not implemented")
}

func (BaseServer) CloseSession(context.Context, *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) {
return nil, status.Error(codes.Unimplemented, "CloseSession not implemented")
}

// Server is the required interface for a FlightSQL server. It is implemented by
// BaseServer which must be embedded in any implementation. The default
// implementation by BaseServer for each of these (except GetSqlInfo)
Expand Down Expand Up @@ -676,6 +688,12 @@ type Server interface {
PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error)
// PollFlightInfoPreparedStatement handles polling for query execution.
PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
// SetSessionOptions sets option(s) for the current server session.
SetSessionOptions(context.Context, *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error)
// GetSessionOptions gets option(s) for the current server session.
GetSessionOptions(context.Context, *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error)
// CloseSession closes/invalidates the current server session.
CloseSession(context.Context, *flight.CloseSessionRequest) (*flight.CloseSessionResult, error)

mustEmbedBaseServer()
}
Expand Down Expand Up @@ -1262,6 +1280,69 @@ func (f *flightSqlServer) DoAction(cmd *flight.Action, stream flight.FlightServi
}

return stream.Send(&pb.Result{})
case flight.SetSessionOptionsActionType:
var (
request flight.SetSessionOptionsRequest
err error
)

if err = proto.Unmarshal(cmd.Body, &request); err != nil {
return status.Errorf(codes.InvalidArgument, "unable to unmarshal SetSessionOptionsRequest: %s", err.Error())
}

response, err := f.srv.SetSessionOptions(stream.Context(), &request)
if err != nil {
return err
}

out := &pb.Result{}
out.Body, err = proto.Marshal(response)
if err != nil {
return err
}
return stream.Send(out)
case flight.GetSessionOptionsActionType:
var (
request flight.GetSessionOptionsRequest
err error
)

if err = proto.Unmarshal(cmd.Body, &request); err != nil {
return status.Errorf(codes.InvalidArgument, "unable to unmarshal GetSessionOptionsRequest: %s", err.Error())
}

response, err := f.srv.GetSessionOptions(stream.Context(), &request)
if err != nil {
return err
}

out := &pb.Result{}
out.Body, err = proto.Marshal(response)
if err != nil {
return err
}
return stream.Send(out)
case flight.CloseSessionActionType:
var (
request flight.CloseSessionRequest
err error
)

if err = proto.Unmarshal(cmd.Body, &request); err != nil {
return status.Errorf(codes.InvalidArgument, "unable to unmarshal CloseSessionRequest: %s", err.Error())
}

response, err := f.srv.CloseSession(stream.Context(), &request)
if err != nil {
return err
}

out := &pb.Result{}
out.Body, err = proto.Marshal(response)
if err != nil {
return err
}
return stream.Send(out)
default:
return status.Error(codes.InvalidArgument, "the defined request is invalid.")
}
Expand Down
Loading

0 comments on commit 81c9d30

Please sign in to comment.