Skip to content

Commit

Permalink
feat(go/adbc/driver/flightsql): support session options
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Mar 7, 2024
1 parent a3579ff commit 24026d1
Show file tree
Hide file tree
Showing 5 changed files with 563 additions and 25 deletions.
2 changes: 1 addition & 1 deletion c/validation/adbc_validation_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void ConnectionTest::TestMetadataCurrentCatalog() {
ASSERT_THAT(
AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG,
buffer, &buffer_size, &error),
IsStatus(ADBC_STATUS_NOT_FOUND));
IsStatus(ADBC_STATUS_NOT_FOUND, &error));
}
}

Expand Down
217 changes: 217 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package flightsql_test

import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
Expand All @@ -41,6 +42,7 @@ import (
"github.com/apache/arrow/go/v16/arrow/flight"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref"
flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -134,6 +136,10 @@ func TestMultiTable(t *testing.T) {
suite.Run(t, &MultiTableTests{})
}

func TestSessionOptions(t *testing.T) {
suite.Run(t, &SessionOptionTests{})
}

// ---- AuthN Tests --------------------

type AuthnTestServer struct {
Expand Down Expand Up @@ -1654,3 +1660,214 @@ func (suite *MultiTableTests) TestGetTableSchema() {
expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
suite.Equal(expectedSchema, actualSchema)
}

// ---- Session Option Tests --------------------

type SessionOptionTestServer struct {
flightsql.BaseServer
options map[string]interface{}
}

func (server *SessionOptionTestServer) GetSessionOptions(ctx context.Context, req *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) {
options := make(map[string]*flight.SessionOptionValue)
for k, v := range server.options {
switch s := v.(type) {
case bool:
options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_BoolValue{BoolValue: s}}
case float64:
options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_DoubleValue{DoubleValue: s}}
case int64:
options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_Int64Value{Int64Value: s}}
case string:
options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_StringValue{StringValue: s}}
case []string:
options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_StringListValue_{StringListValue: &flightproto.SessionOptionValue_StringListValue{Values: s}}}
case nil:
options[k] = &flight.SessionOptionValue{}
default:
panic("not implemented")
}
}
return &flight.GetSessionOptionsResult{
SessionOptions: options,
}, nil
}

func (server *SessionOptionTestServer) SetSessionOptions(ctx context.Context, req *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) {
errors := map[string]*flightproto.SetSessionOptionsResult_Error{}
for k, v := range req.SessionOptions {
switch k {
case "bad name":
errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_INVALID_NAME}
continue
case "bad value":
errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_INVALID_VALUE}
continue
case "error":
errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_ERROR}
continue
}
switch s := v.GetOptionValue().(type) {
case *flightproto.SessionOptionValue_BoolValue:
server.options[k] = s.BoolValue
case *flightproto.SessionOptionValue_DoubleValue:
server.options[k] = s.DoubleValue
case *flightproto.SessionOptionValue_Int64Value:
server.options[k] = s.Int64Value
case *flightproto.SessionOptionValue_StringValue:
server.options[k] = s.StringValue
case *flightproto.SessionOptionValue_StringListValue_:
server.options[k] = s.StringListValue.Values
case nil:
delete(server.options, k)
default:
return nil, status.Error(codes.InvalidArgument, "invalid option type")
}
}
return &flight.SetSessionOptionsResult{Errors: errors}, nil
}

func (server *SessionOptionTestServer) CloseSession(ctx context.Context, req *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) {
return &flight.CloseSessionResult{
Status: flight.CloseSessionResultClosed,
}, nil
}

type SessionOptionTests struct {
ServerBasedTests
}

func (suite *SessionOptionTests) SetupSuite() {
suite.DoSetupSuite(&SessionOptionTestServer{
options: map[string]interface{}{
"string": "expected",
"bool": true,
"float64": float64(1.5),
"int64": int64(20),
"catalog": "main",
"schema": "session",
"stringlist": []string{"a", "b", "c"},
"nilopt": nil,
},
}, nil, map[string]string{})
}

func (suite *SessionOptionTests) TestGetAllOptions() {
val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(driver.OptionSessionOptions)
suite.NoError(err)

options := make(map[string]interface{})
suite.NoError(json.Unmarshal([]byte(val), &options))
// XXX: because Go decodes ints to strings by default. Should we use
// an alternate representation? What happens to int64max?
suite.Equal(float64(20), options["int64"])
suite.Equal("expected", options["string"])
// Bit of a hack, but lets servers send "this option exists, but is
// not set" by returning a nil/unset value
suite.Nil(options["nilopt"])
}

func (suite *SessionOptionTests) TestGetAllOptionsByte() {
val, err := suite.cnxn.(adbc.GetSetOptions).GetOptionBytes(driver.OptionSessionOptions)
suite.NoError(err)

options := make(map[string]interface{})
// XXX: maybe we can return the underlying proto repr here?
suite.NoError(json.Unmarshal(val, &options))
suite.Equal(float64(20), options["int64"])
suite.Equal("expected", options["string"])
}

func (suite *SessionOptionTests) TestGetSetCatalog() {
val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
suite.NoError(err)
suite.Equal("main", val)

suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "postgres"))
val, err = suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
suite.NoError(err)
suite.Equal("postgres", val)
}

func (suite *SessionOptionTests) TestGetSetSchema() {
val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
suite.NoError(err)
suite.Equal("session", val)

suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema, "public"))
val, err = suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
suite.NoError(err)
suite.Equal("public", val)
}

func (suite *SessionOptionTests) TestGetSetBool() {
o := suite.cnxn.(adbc.GetSetOptions)
val, err := o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool")
suite.NoError(err)
suite.Equal("true", val)

suite.NoError(o.SetOption(driver.OptionBoolSessionOptionPrefix+"bool", "false"))
val, err = o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool")
suite.NoError(err)
suite.Equal("false", val)
}

func (suite *SessionOptionTests) TestGetSetFloat64() {
o := suite.cnxn.(adbc.GetSetOptions)
val, err := o.GetOptionDouble(driver.OptionSessionOptionPrefix + "float64")
suite.NoError(err)
suite.Equal(1.5, val)

suite.NoError(o.SetOptionDouble(driver.OptionSessionOptionPrefix+"float64", -42.0))
val, err = o.GetOptionDouble(driver.OptionSessionOptionPrefix + "float64")
suite.NoError(err)
suite.Equal(-42.0, val)
}

func (suite *SessionOptionTests) TestGetSetInt64() {
o := suite.cnxn.(adbc.GetSetOptions)
val, err := o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64")
suite.NoError(err)
suite.Equal(int64(20), val)

suite.NoError(o.SetOptionInt(driver.OptionSessionOptionPrefix+"int64", 128))
val, err = o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64")
suite.NoError(err)
suite.Equal(int64(128), val)
}

func (suite *SessionOptionTests) TestGetSetString() {
o := suite.cnxn.(adbc.GetSetOptions)
_, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
suite.ErrorContains(err, "unknown session option 'unknown'")

suite.NoError(o.SetOption(driver.OptionSessionOptionPrefix+"unknown", "42"))
val, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
suite.NoError(err)
suite.Equal("42", val)

suite.NoError(o.SetOption(driver.OptionEraseSessionOptionPrefix+"unknown", ""))
_, err = o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
suite.ErrorContains(err, "unknown session option 'unknown'")

suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad name", ""), "Could not set option(s) 'bad name' (invalid name)")
suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad value", ""), "Could not set option(s) 'bad value' (invalid value)")
suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"error", ""), "Could not set option(s) 'error' (error setting option)")
}

func (suite *SessionOptionTests) TestGetSetStringList() {
o := suite.cnxn.(adbc.GetSetOptions)
val, err := o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist")
suite.NoError(err)
suite.Equal(`["a","b","c"]`, val)

suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist", `["foo", "bar"]`))
val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist")
suite.NoError(err)
suite.Equal(`["foo","bar"]`, val)

suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist", `[]`))
val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist")
suite.NoError(err)
suite.Equal(`[]`, val)
}
Loading

0 comments on commit 24026d1

Please sign in to comment.