diff --git a/jetstream/publish.go b/jetstream/publish.go
index 6588060a4..1215687ee 100644
--- a/jetstream/publish.go
+++ b/jetstream/publish.go
@@ -93,6 +93,8 @@ type (
 		stallCh      chan struct{}
 		doneCh       chan struct{}
 		rr           *rand.Rand
+		// connStatusCh delivers RECONNECTING and CLOSED connection status events.
+		connStatusCh chan nats.Status
 	}
 
 	pubAckResponse struct {
@@ -300,6 +302,11 @@ func (js *jetStream) newAsyncReply() (string, error) {
 		js.publisher.replySubject = sub
 		js.publisher.rr = rand.New(rand.NewSource(time.Now().UnixNano()))
 	}
+	// Start the status watcher once; it fails pending acks when the connection drops.
+	if js.publisher.connStatusCh == nil {
+		js.publisher.connStatusCh = js.conn.StatusChanged(nats.RECONNECTING, nats.CLOSED)
+		go js.resetPendingAcksOnReconnect()
+	}
 	var sb strings.Builder
 	sb.WriteString(js.publisher.replyPrefix)
 	rn := js.publisher.rr.Int63()
@@ -382,6 +389,46 @@ func (js *jetStream) handleAsyncReply(m *nats.Msg) {
 	js.publisher.Unlock()
 }
 
+// resetPendingAcksOnReconnect is the body of the goroutine started by
+// newAsyncReply. For each RECONNECTING status it receives, it marks every
+// registered PubAckFuture as failed with nats.ErrDisconnected and removes it
+// from the pending set. It returns once the status channel is closed or the
+// connection reports CLOSED.
+func (js *jetStream) resetPendingAcksOnReconnect() {
+	js.publisher.Lock()
+	connStatusCh := js.publisher.connStatusCh
+	js.publisher.Unlock()
+	for {
+		newStatus, ok := <-connStatusCh
+		if !ok || newStatus == nats.CLOSED {
+			return
+		}
+		js.publisher.Lock()
+		for id, paf := range js.publisher.acks {
+			paf.err = nats.ErrDisconnected
+			// Wake a caller that is already waiting on the future's error
+			// channel; never block here because the lock is held.
+			// NOTE(review): errCh is declared outside this diff - confirm.
+			if paf.errCh != nil {
+				select {
+				case paf.errCh <- paf.err:
+				default:
+				}
+			}
+			// Delete entries rather than nil-ing the map so that later
+			// registrations never write to a nil map.
+			delete(js.publisher.acks, id)
+		}
+		// No acks are pending any more, so release completion waiters.
+		// NOTE(review): assumes doneCh is the PublishAsyncComplete channel.
+		if js.publisher.doneCh != nil {
+			close(js.publisher.doneCh)
+			js.publisher.doneCh = nil
+		}
+		js.publisher.Unlock()
+	}
+}
+
 // registerPAF will register for a PubAckFuture.
 func (js *jetStream) registerPAF(id string, paf *pubAckFuture) (int, int) {
 	js.publisher.Lock()
diff --git a/jetstream/test/publish_test.go b/jetstream/test/publish_test.go
index a0683e288..c8712b5d3 100644
--- a/jetstream/test/publish_test.go
+++ b/jetstream/test/publish_test.go
@@ -1234,3 +1234,72 @@ func TestPublishMsgAsyncWithPendingMsgs(t *testing.T) {
 		}
 	})
 }
+
+func TestPublishAsyncResetPendingOnReconnect(t *testing.T) {
+	s := RunBasicJetStreamServer()
+	defer shutdownJSServerAndRemoveStorage(t, s)
+
+	nc, err := nats.Connect(s.ClientURL())
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+	// Register the close before any later t.Fatalf can skip it.
+	defer nc.Close()
+
+	js, err := jetstream.New(nc)
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	_, err = js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+
+	errs := make(chan error, 1)
+	done := make(chan struct{}, 1)
+	acks := make(chan jetstream.PubAckFuture, 100)
+	go func() {
+		for i := 0; i < 100; i++ {
+			ack, err := js.PublishAsync("FOO.A", []byte("hello"))
+			if err != nil {
+				errs <- err
+				return
+			}
+			acks <- ack
+		}
+		close(acks)
+		done <- struct{}{}
+	}()
+	select {
+	case <-done:
+	case err := <-errs:
+		t.Fatalf("Unexpected error during publish: %v", err)
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Did not receive completion signal")
+	}
+	s.Shutdown()
+	// The reset runs in a separate goroutine, so poll instead of relying on one fixed sleep.
+	deadline := time.Now().Add(2 * time.Second)
+	for js.PublishAsyncPending() != 0 {
+		if time.Now().After(deadline) {
+			t.Fatalf("Expected no pending messages after server shutdown; got: %d", js.PublishAsyncPending())
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	s = RunBasicJetStreamServer()
+	defer shutdownJSServerAndRemoveStorage(t, s)
+
+	for ack := range acks {
+		select {
+		case <-ack.Ok():
+		case err := <-ack.Err():
+			if !errors.Is(err, nats.ErrDisconnected) && !errors.Is(err, nats.ErrNoResponders) {
+				t.Fatalf("Expected error: %v or %v; got: %v", nats.ErrDisconnected, nats.ErrNoResponders, err)
+			}
+		case <-time.After(5 * time.Second):
+			t.Fatalf("Did not receive completion signal")
+		}
+	}
+}
diff --git a/js.go b/js.go
index 3559f5ab3..79a92828a 100644
--- a/js.go
+++ b/js.go
@@ -227,13 +227,14 @@ type js struct {
 	opts *jsOpts
 
 	// For async publish context.
-	mu   sync.RWMutex
-	rpre string
-	rsub *Subscription
-	pafs map[string]*pubAckFuture
-	stc  chan struct{}
-	dch  chan struct{}
-	rr   *rand.Rand
+	mu           sync.RWMutex
+	rpre         string
+	rsub         *Subscription
+	pafs         map[string]*pubAckFuture
+	stc          chan struct{}
+	dch          chan struct{}
+	rr           *rand.Rand
+	connStatusCh chan Status
 }
 
 type jsOpts struct {
@@ -666,6 +667,11 @@ func (js *js) newAsyncReply() string {
 		js.rsub = sub
 		js.rr = rand.New(rand.NewSource(time.Now().UnixNano()))
 	}
+	// Start the status watcher once; it fails pending acks when the connection drops.
+	if js.connStatusCh == nil {
+		js.connStatusCh = js.nc.StatusChanged(RECONNECTING, CLOSED)
+		go js.resetPendingAcksOnReconnect()
+	}
 	var sb strings.Builder
 	sb.WriteString(js.rpre)
 	rn := js.rr.Int63()
@@ -679,12 +685,57 @@ func (js *js) newAsyncReply() string {
 	return sb.String()
 }
 
+// resetPendingAcksOnReconnect is the body of the goroutine started by
+// newAsyncReply. For each RECONNECTING status it receives, it marks every
+// registered PubAckFuture as failed with ErrDisconnected and removes it from
+// the pending set. It returns once the status channel is closed or the
+// connection reports CLOSED.
+func (js *js) resetPendingAcksOnReconnect() {
+	js.mu.Lock()
+	connStatusCh := js.connStatusCh
+	js.mu.Unlock()
+	for {
+		newStatus, ok := <-connStatusCh
+		if !ok || newStatus == CLOSED {
+			return
+		}
+		js.mu.Lock()
+		for id, paf := range js.pafs {
+			paf.err = ErrDisconnected
+			// Wake a caller that is already waiting on the future's error
+			// channel; never block here because the lock is held.
+			// NOTE(review): errCh is declared outside this diff - confirm.
+			if paf.errCh != nil {
+				select {
+				case paf.errCh <- paf.err:
+				default:
+				}
+			}
+			// Delete entries rather than nil-ing the map so that later
+			// registrations never write to a nil map.
+			delete(js.pafs, id)
+		}
+		// No acks are pending any more, so release completion waiters.
+		// NOTE(review): assumes dch is the PublishAsyncComplete channel.
+		if js.dch != nil {
+			close(js.dch)
+			js.dch = nil
+		}
+		js.mu.Unlock()
+	}
+}
+
 func (js *js) cleanupReplySub() {
 	js.mu.Lock()
 	if js.rsub != nil {
 		js.rsub.Unsubscribe()
 		js.rsub = nil
 	}
+	// Closing the channel makes resetPendingAcksOnReconnect return.
+	if js.connStatusCh != nil {
+		close(js.connStatusCh)
+		js.connStatusCh = nil
+	}
 	js.mu.Unlock()
 }
 
diff --git a/nats.go b/nats.go
index 94f551cb2..fc71ce175 100644
--- a/nats.go
+++ b/nats.go
@@ -5471,7 +5471,7 @@ func (nc *Conn) StatusChanged(statuses ...Status) chan Status {
 	if len(statuses) == 0 {
 		statuses = []Status{CONNECTED, RECONNECTING, DISCONNECTED, CLOSED}
 	}
-	ch := make(chan Status)
+	ch := make(chan Status, 10)
 	for _, s := range statuses {
 		nc.registerStatusChangeListener(s, ch)
 	}
diff --git a/test/js_test.go b/test/js_test.go
index da89222f6..849de3cbb 100644
--- a/test/js_test.go
+++ b/test/js_test.go
@@ -7376,6 +7376,65 @@ func TestJetStreamPublishAsync(t *testing.T) {
 	}
 }
 
+func TestPublishAsyncResetPendingOnReconnect(t *testing.T) {
+	s := RunBasicJetStreamServer()
+	defer shutdownJSServerAndRemoveStorage(t, s)
+
+	nc, js := jsClient(t, s)
+	defer nc.Close()
+
+	// Create the stream that the async publishes below target.
+	if _, err := js.AddStream(&nats.StreamConfig{Name: "TEST", Subjects: []string{"FOO"}}); err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+
+	errs := make(chan error, 1)
+	done := make(chan struct{}, 1)
+	acks := make(chan nats.PubAckFuture, 100)
+	go func() {
+		for i := 0; i < 100; i++ {
+			ack, err := js.PublishAsync("FOO", []byte("hello"))
+			if err != nil {
+				errs <- err
+				return
+			}
+			acks <- ack
+		}
+		close(acks)
+		done <- struct{}{}
+	}()
+	select {
+	case <-done:
+	case err := <-errs:
+		t.Fatalf("Unexpected error during publish: %v", err)
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Did not receive completion signal")
+	}
+	s.Shutdown()
+	// The reset runs in a separate goroutine, so poll instead of relying on one fixed sleep.
+	deadline := time.Now().Add(2 * time.Second)
+	for js.PublishAsyncPending() != 0 {
+		if time.Now().After(deadline) {
+			t.Fatalf("Expected no pending messages after server shutdown; got: %d", js.PublishAsyncPending())
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	s = RunBasicJetStreamServer()
+	defer shutdownJSServerAndRemoveStorage(t, s)
+
+	for ack := range acks {
+		select {
+		case <-ack.Ok():
+		case err := <-ack.Err():
+			if !errors.Is(err, nats.ErrDisconnected) && !errors.Is(err, nats.ErrNoResponders) {
+				t.Fatalf("Expected error: %v or %v; got: %v", nats.ErrDisconnected, nats.ErrNoResponders, err)
+			}
+		case <-time.After(5 * time.Second):
+			t.Fatalf("Did not receive completion signal")
+		}
+	}
+}
+
 func TestJetStreamPublishAsyncPerf(t *testing.T) {
 	// Comment out below to run this benchmark.
 	t.SkipNow()