Skip to content

Commit

Permalink
Catching up with TestKit (#520)
Browse files Browse the repository at this point in the history
1. Add support for routing table specific TestKit messages (retrieving routing
   tables and forcing updates)
2. Improve logging in the TestKit backend (all logs, driver and bolt, go both to
   stdout and TestKit) for easier debugging
3. Overhaul how/when the driver drops servers from the cached routing table:
   * Move logic of deactivating servers on failures into the pool. This
     functionality should not only be present when using transaction functions.
   * Drop writers on certain error codes: `Neo.ClientError.Cluster.NotALeader`
     and `Neo.ClientError.General.ForbiddenOnReadOnlyDatabase`
4. To simplify the code and avoid deadlocks or inconsistent driver states, both
   pool and router use blocking locks now. However, they will never perform IO
   while holding the lock, removing the need for lock acquisition timeouts.


* Reformat feature list to reflect TestKit's state

* Clean up test skips

* TestKit backend: improve logging

Log everything to stdout + TestKit socket.

* Drop servers from routing table on IO failure

* Remove writer from routing table on certain errors

* Refactor responsibilities and locking

The retry state is no longer responsible for invalidating servers on broken
connections. The pool and/or routing logic should take care of that. This is
necessary and logical since invalidation should happen regardless of which API
(session.Run, transaction.Run, ...) is used.

Pool and router no longer do any IO while holding locks. Hence, the locks can be
turned into blocking locks without risking blocking for too long.

* Add support for RT related TestKit messages

* Remove unused context parameters

---------

Signed-off-by: Rouven Bauer <[email protected]>
Co-authored-by: Florent Biville <[email protected]>
  • Loading branch information
robsdedude and fbiville authored Aug 18, 2023
1 parent ee7546f commit 39d4411
Show file tree
Hide file tree
Showing 39 changed files with 943 additions and 719 deletions.
26 changes: 10 additions & 16 deletions neo4j/directrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,32 @@ type directRouter struct {
address string
}

func (r *directRouter) InvalidateWriter(context.Context, string, string) error {
return nil
}
func (r *directRouter) InvalidateWriter(string, string) {}

func (r *directRouter) InvalidateReader(context.Context, string, string) error {
return nil
}
func (r *directRouter) InvalidateReader(string, string) {}

func (r *directRouter) InvalidateServer(string) {}

func (r *directRouter) GetOrUpdateReaders(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
return []string{r.address}, nil
}

func (r *directRouter) Readers(context.Context, string) ([]string, error) {
return []string{r.address}, nil
func (r *directRouter) Readers(string) []string {
return []string{r.address}
}

func (r *directRouter) GetOrUpdateWriters(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
return []string{r.address}, nil
}

func (r *directRouter) Writers(context.Context, string) ([]string, error) {
return []string{r.address}, nil
func (r *directRouter) Writers(string) []string {
return []string{r.address}
}

func (r *directRouter) GetNameOfDefaultDatabase(context.Context, []string, string, *db.ReAuthToken, log.BoltLogger) (string, error) {
return db.DefaultDatabase, nil
}

func (r *directRouter) Invalidate(context.Context, string) error {
return nil
}
func (r *directRouter) Invalidate(string) {}

func (r *directRouter) CleanUp(context.Context) error {
return nil
}
func (r *directRouter) CleanUp() {}
19 changes: 10 additions & 9 deletions neo4j/driver_with_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ func NewDriverWithContext(target string, auth auth.TokenManager, configurers ...
d.router = router.New(address, routersResolver, routingContext, d.pool, d.log, d.logId, &d.now)
}

d.pool.SetRouter(d.router)

d.log.Infof(log.Driver, d.logId, "Created { target: %s }", address)
return &d, nil
}
Expand Down Expand Up @@ -300,20 +302,21 @@ type sessionRouter interface {
// they should not be called when it is not needed (e.g. when a routing table is cached)
GetOrUpdateReaders(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
// Readers returns the list of servers that can serve reads on the requested database.
Readers(ctx context.Context, database string) ([]string, error)
Readers(database string) []string
// GetOrUpdateWriters returns the list of servers that can serve writes on the requested database.
// note: bookmarks are lazily supplied, see Readers documentation to learn why
GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
// Writers returns the list of servers that can serve writes on the requested database.
Writers(ctx context.Context, database string) ([]string, error)
Writers(database string) []string
// GetNameOfDefaultDatabase returns the name of the default database for the specified user.
// The correct database name is needed when requesting readers or writers.
// the bookmarks are eagerly provided since this method always fetches a new routing table
GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) (string, error)
Invalidate(ctx context.Context, database string) error
CleanUp(ctx context.Context) error
InvalidateWriter(ctx context.Context, name string, server string) error
InvalidateReader(ctx context.Context, name string, server string) error
Invalidate(db string)
CleanUp()
InvalidateWriter(db string, server string)
InvalidateReader(db string, server string)
InvalidateServer(server string)
}

type driverWithContext struct {
Expand Down Expand Up @@ -394,9 +397,7 @@ func (d *driverWithContext) Close(ctx context.Context) error {
defer d.mut.Unlock()
// Safeguard against closing more than once
if d.pool != nil {
if err := d.pool.Close(ctx); err != nil {
return err
}
d.pool.Close(ctx)
d.pool = nil
d.log.Infof(log.Driver, d.logId, "Closed")
}
Expand Down
58 changes: 49 additions & 9 deletions neo4j/driver_with_context_testkit.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//go:build internal_testkit

/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
Expand All @@ -10,18 +8,30 @@
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//go:build internal_testkit

package neo4j

import "time"
import (
"context"
"fmt"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/log"
"time"
)

type RoutingTable = idb.RoutingTable

func SetTimer(d DriverWithContext, timer func() time.Time) {
driver := d.(*driverWithContext)
Expand All @@ -32,3 +42,33 @@ func ResetTime(d DriverWithContext) {
driver := d.(*driverWithContext)
driver.now = time.Now
}

// ForceRoutingTableUpdate is a TestKit-only helper. It calls Invalidate on the
// driver's router for database and then requests readers followed by writers
// through GetOrUpdateReaders / GetOrUpdateWriters, handing the given bookmarks
// to the router on demand.
//
// d must be a *driverWithContext; any other implementation makes the type
// assertion panic. The router calls run with context.Background(), so no
// deadline or cancellation is applied here.
//
// The first error encountered is passed through errorutil.WrapError and
// returned; the writers lookup is skipped when the readers lookup fails.
func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []string, logger log.BoltLogger) error {
	impl := d.(*driverWithContext)
	impl.router.Invalidate(database)

	ctx := context.Background()
	token := &idb.ReAuthToken{
		Manager:     impl.auth,
		FromSession: false,
		ForceReAuth: false,
	}
	// The router takes bookmarks lazily, so wrap the fixed slice in a supplier.
	bookmarkSupplier := func(context.Context) ([]string, error) {
		return bookmarks, nil
	}

	if _, err := impl.router.GetOrUpdateReaders(ctx, bookmarkSupplier, database, token, logger); err != nil {
		return errorutil.WrapError(err)
	}
	_, err := impl.router.GetOrUpdateWriters(ctx, bookmarkSupplier, database, token, logger)
	return errorutil.WrapError(err)
}

// GetRoutingTable is a TestKit-only helper that returns the table obtained from
// the driver's router via GetTable for the given database.
//
// d must be a *driverWithContext; any other implementation makes the type
// assertion panic. The lookup only works when the driver's router is a
// *router.Router. For any other router implementation (e.g. the directRouter)
// an error is returned and the table is nil.
func GetRoutingTable(d DriverWithContext, database string) (*RoutingTable, error) {
	driver := d.(*driverWithContext)
	r, ok := driver.router.(*router.Router)
	if !ok {
		// Reaching this branch means the driver is NOT backed by a routing
		// router, so the message must point callers at routing drivers.
		return nil, fmt.Errorf("GetRoutingTable is only supported for routing drivers")
	}
	return r.GetTable(database), nil
}
26 changes: 16 additions & 10 deletions neo4j/internal/bolt/bolt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ type bolt3 struct {
auth map[string]any
authManager auth.TokenManager
resetAuth bool
onNeo4jError Neo4jErrorCallback
errorListener ConnectionErrorListener
now *func() time.Time
}

func NewBolt3(
serverName string,
conn net.Conn,
callback Neo4jErrorCallback,
errorListener ConnectionErrorListener,
timer *func() time.Time,
logger log.Logger,
boltLog log.BoltLogger,
Expand All @@ -120,16 +120,22 @@ func NewBolt3(
},
connReadTimeout: -1,
},
birthDate: now,
idleDate: now,
log: logger,
onNeo4jError: callback,
now: timer,
birthDate: now,
idleDate: now,
log: logger,
errorListener: errorListener,
now: timer,
}
b.out = &outgoing{
chunker: newChunker(),
packer: packstream.Packer{},
onErr: func(err error) {
onPackErr: func(err error) {
if b.err == nil {
b.err = err
}
b.state = bolt3_dead
},
onIoErr: func(ctx context.Context, err error) {
if b.err == nil {
b.err = err
}
Expand Down Expand Up @@ -181,7 +187,7 @@ func (b *bolt3) receiveSuccess(ctx context.Context) *success {
} else {
b.log.Error(log.Bolt3, b.logId, message)
}
if err := b.onNeo4jError(ctx, b, message); err != nil {
if err := b.errorListener.OnNeo4jError(ctx, b, message); err != nil {
b.err = errorutil.CombineErrors(message, b.err)
}
return nil
Expand Down Expand Up @@ -662,7 +668,7 @@ func (b *bolt3) receiveNext(ctx context.Context) (*db.Record, *db.Summary, error
} else {
b.log.Error(log.Bolt3, b.logId, message)
}
if err := b.onNeo4jError(ctx, b, message); err != nil {
if err := b.errorListener.OnNeo4jError(ctx, b, message); err != nil {
return nil, nil, errorutil.CombineErrors(message, err)
}
return nil, nil, message
Expand Down
4 changes: 2 additions & 2 deletions neo4j/internal/bolt/bolt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestBolt3(outer *testing.T) {
auth,
"007",
nil,
noopOnNeo4jError,
noopErrorListener{},
logger,
nil,
idb.NotificationConfig{},
Expand Down Expand Up @@ -168,7 +168,7 @@ func TestBolt3(outer *testing.T) {
auth,
"007",
nil,
noopOnNeo4jError,
noopErrorListener{},
logger,
nil,
idb.NotificationConfig{},
Expand Down
50 changes: 30 additions & 20 deletions neo4j/internal/bolt/bolt4.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,30 +111,30 @@ type bolt4 struct {
auth map[string]any
authManager auth.TokenManager
resetAuth bool
onNeo4jError Neo4jErrorCallback
errorListener ConnectionErrorListener
now *func() time.Time
}

func NewBolt4(
serverName string,
conn net.Conn,
callback Neo4jErrorCallback,
errorListener ConnectionErrorListener,
timer *func() time.Time,
logger log.Logger,
boltLog log.BoltLogger,
) *bolt4 {
now := (*timer)()
b := &bolt4{
state: bolt4_unauthorized,
conn: conn,
serverName: serverName,
birthDate: now,
idleDate: now,
log: logger,
streams: openstreams{},
lastQid: -1,
onNeo4jError: callback,
now: timer,
state: bolt4_unauthorized,
conn: conn,
serverName: serverName,
birthDate: now,
idleDate: now,
log: logger,
streams: openstreams{},
lastQid: -1,
errorListener: errorListener,
now: timer,
}
b.queue = newMessageQueue(
conn,
Expand All @@ -149,11 +149,12 @@ func NewBolt4(
&outgoing{
chunker: newChunker(),
packer: packstream.Packer{},
onErr: func(err error) { b.setError(err, true) },
onPackErr: func(err error) { b.setError(err, true) },
onIoErr: b.onIoError,
boltLogger: boltLog,
},
b.onNextMessage,
b.onNextMessageError,
b.onIoError,
)

return b
Expand Down Expand Up @@ -938,6 +939,10 @@ func (b *bolt4) SelectDatabase(database string) {
b.databaseName = database
}

// Database returns the database name currently stored on this connection,
// i.e. the value most recently assigned by SelectDatabase.
func (b *bolt4) Database() string {
return b.databaseName
}

func (b *bolt4) SetBoltLogger(boltLogger log.BoltLogger) {
b.queue.setBoltLogger(boltLogger)
}
Expand Down Expand Up @@ -1079,7 +1084,7 @@ func (b *bolt4) resetResponseHandler() responseHandler {
b.state = bolt4_ready
},
onFailure: func(ctx context.Context, failure *db.Neo4jError) {
_ = b.onNeo4jError(ctx, b, failure)
_ = b.errorListener.OnNeo4jError(ctx, b, failure)
b.state = bolt4_dead
},
}
Expand Down Expand Up @@ -1126,19 +1131,24 @@ func (b *bolt4) onNextMessage() {
b.idleDate = (*b.now)()
}

func (b *bolt4) onNextMessageError(err error) {
b.setError(err, true)
}

func (b *bolt4) onFailure(ctx context.Context, failure *db.Neo4jError) {
var err error
err = failure
if callbackErr := b.onNeo4jError(ctx, b, failure); callbackErr != nil {
if callbackErr := b.errorListener.OnNeo4jError(ctx, b, failure); callbackErr != nil {
err = errorutil.CombineErrors(failure, callbackErr)
}
b.setError(err, isFatalError(failure))
}

// onIoError reacts to an IO failure on this connection. It forwards the error
// to the configured errorListener unless the connection is already in the
// failed or dead state, and then always records the error via setError.
//
// The listener must be consulted before setError runs, because the guard
// depends on the state as it was when the IO error occurred.
func (b *bolt4) onIoError(ctx context.Context, err error) {
	// Don't call callback when connections break after sending RESET.
	// The server chooses to close the connection on some errors.
	alreadyBroken := b.state == bolt4_failed || b.state == bolt4_dead
	if !alreadyBroken {
		b.errorListener.OnIoError(ctx, b, err)
	}
	b.setError(err, true)
}

const readTimeoutHintName = "connection.recv_timeout_seconds"

func (b *bolt4) initializeReadTimeoutHint(hints map[string]any) {
Expand Down
4 changes: 2 additions & 2 deletions neo4j/internal/bolt/bolt4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func TestBolt4(outer *testing.T) {
auth,
"007",
nil,
noopOnNeo4jError,
noopErrorListener{},
logger,
nil,
idb.NotificationConfig{},
Expand Down Expand Up @@ -322,7 +322,7 @@ func TestBolt4(outer *testing.T) {
auth,
"007",
nil,
noopOnNeo4jError,
noopErrorListener{},
logger,
nil,
idb.NotificationConfig{},
Expand Down
Loading

0 comments on commit 39d4411

Please sign in to comment.