Skip to content

Commit

Permalink
make sshContext thread safe and fix the data race bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wxiaoguang committed Aug 12, 2023
1 parent cf1ec7e commit 02f9d57
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
18 changes: 16 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ type Context interface {
type sshContext struct {
context.Context
*sync.Mutex

values map[interface{}]interface{}
valuesMu sync.Mutex
}

func newContext(srv *Server) (*sshContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background())
ctx := &sshContext{innerCtx, &sync.Mutex{}}
ctx := &sshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
Expand All @@ -119,8 +122,19 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}

func (ctx *sshContext) Value(key interface{}) interface{} {
ctx.valuesMu.Lock()
defer ctx.valuesMu.Unlock()
if v, ok := ctx.values[key]; ok {
return v
}
return ctx.Context.Value(key)
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
ctx.valuesMu.Lock()
defer ctx.valuesMu.Unlock()
ctx.values[key] = value
}

func (ctx *sshContext) User() string {
Expand Down
40 changes: 39 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ssh

import "testing"
import (
"testing"
"time"
)

func TestSetPermissions(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -45,3 +48,38 @@ func TestSetValue(t *testing.T) {
t.Fatal(err)
}
}

func TestSetValueConcurrency(t *testing.T) {
ctx, cancel := newContext(nil)
defer cancel()

go func() {
for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue
_, _ = ctx.Deadline()
_ = ctx.Err()
_ = ctx.Value("foo")
select {
case <-ctx.Done():
break
default:
}
}
}()
ctx.SetValue("bar", -1) // a context value which never changes
now := time.Now()
var cnt int64
go func() {
for time.Since(now) < 100*time.Millisecond {
cnt++
ctx.SetValue("foo", cnt) // a context value which changes a lot
}
cancel()
}()
<-ctx.Done()
if ctx.Value("foo") != cnt {
t.Fatal("context.Value(foo) doesn't match latest SetValue")
}
if ctx.Value("bar") != -1 {
t.Fatal("context.Value(bar) doesn't match latest SetValue")
}
}

0 comments on commit 02f9d57

Please sign in to comment.