Skip to content

Commit

Permalink
update test to use channel
Browse files Browse the repository at this point in the history
  • Loading branch information
peteski22 authored and pires committed Oct 8, 2024
1 parent 627f8b3 commit bac82fd
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net"
"net/http"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -1280,39 +1281,58 @@ func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
t.Fatalf("error creating listener: %v", err)
}

times := 0
var connectionCounter atomic.Int32

newLn := &Listener{
Listener: l,
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
// Return the invalid upstream error on the first call, the listener
// should remain open and accepting.
times := connectionCounter.Load()
if times == 0 {
times++
connectionCounter.Store(times + 1)
return REJECT, ErrInvalidUpstream
}

return REJECT, ErrNoProxyProtocol
},
}

// Kick off the listener and capture any error.
var listenerErr error
// Kick off the listener and return any error via the chanel.
errCh := make(chan error)
defer close(errCh)
go func(t *testing.T) {
_, listenerErr = newLn.Accept()
_, err := newLn.Accept()
errCh <- err
}(t)

// Make two calls to trigger the listener's accept, the first should experience
// the ErrInvalidUpstream and keep the listener open, the second should experience
// a different error which will cause the listener to close.
_, _ = http.Get("http://localhost:8080")
if listenerErr != nil {
t.Fatalf("invalid upstream shouldn't return an error: %v", listenerErr)
// Wait a few seconds to ensure we didn't get anything back on our channel.
select {
case err := <-errCh:
if err != nil {
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
}
case <-time.After(2 * time.Second):
// No error returned (as expected, we're still listening though)
}

_, _ = http.Get("http://localhost:8080")
if listenerErr == nil {
t.Fatalf("errors other than invalid upstream should error")
// Wait a few seconds before we fail the test as we should have received an
// error that was not invalid upstream.
select {
case err := <-errCh:
if err == nil {
t.Fatalf("errors other than invalid upstream should error")
}
if !errors.Is(ErrNoProxyProtocol, err) {
t.Fatalf("unexpected error type: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("timed out waiting for listener")
}
}

Expand Down

0 comments on commit bac82fd

Please sign in to comment.