diff --git a/jetstream/jetstream.go b/jetstream/jetstream.go index a26b12e26..7ae19639e 100644 --- a/jetstream/jetstream.go +++ b/jetstream/jetstream.go @@ -473,6 +473,10 @@ func (js *jetStream) OrderedConsumer(ctx context.Context, stream string, cfg Ord if cfg.OptStartSeq != 0 { oc.cursor.streamSeq = cfg.OptStartSeq - 1 } + err := oc.reset() + if err != nil { + return nil, err + } return oc, nil } diff --git a/jetstream/ordered.go b/jetstream/ordered.go index a3864860b..363b02404 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -67,7 +67,6 @@ var errOrderedSequenceMismatch = errors.New("sequence mismatch") // Consume can be used to continuously receive messages and handle them with the provided callback function func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (ConsumeContext, error) { if c.consumerType == consumerTypeNotSet || c.consumerType == consumerTypeConsume && c.currentConsumer == nil { - c.consumerType = consumerTypeConsume err := c.reset() if err != nil { return nil, err @@ -78,6 +77,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt if c.consumerType == consumerTypeFetch { return nil, ErrOrderConsumerUsedAsFetch } + c.consumerType = consumerTypeConsume consumeOpts, err := parseConsumeOpts(opts...) if err != nil { return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) @@ -156,7 +156,6 @@ func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err err // Messages returns [MessagesContext], allowing continuously iterating over messages on a stream. func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error) { if c.consumerType == consumerTypeNotSet || c.consumerType == consumerTypeConsume && c.currentConsumer == nil { - c.consumerType = consumerTypeConsume err := c.reset() if err != nil { return nil, err @@ -167,6 +166,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er if c.consumerType == consumerTypeFetch { return nil, ErrOrderConsumerUsedAsFetch } + c.consumerType = consumerTypeConsume consumeOpts, err := parseMessagesOpts(opts...) if err != nil { return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) @@ -236,12 +236,15 @@ func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, erro if c.consumerType == consumerTypeConsume { return nil, ErrOrderConsumerUsedAsConsume } + c.currentConsumer.Lock() if c.runningFetch != nil { if !c.runningFetch.done { + c.currentConsumer.Unlock() return nil, ErrOrderedConsumerConcurrentRequests } c.cursor.streamSeq = c.runningFetch.sseq } + c.currentConsumer.Unlock() c.consumerType = consumerTypeFetch err := c.reset() if err != nil { diff --git a/jetstream/pull.go b/jetstream/pull.go index 16ed9ca0e..b4fb3c56f 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -638,7 +638,9 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { defer close(res.msgs) for { if receivedMsgs == req.Batch || (req.MaxBytes != 0 && receivedBytes == req.MaxBytes) { + p.Lock() res.done = true + p.Unlock() return } select { diff --git a/jetstream/stream.go b/jetstream/stream.go index 4139e6b2a..36bd94278 100644 --- a/jetstream/stream.go +++ b/jetstream/stream.go @@ -208,6 +208,10 @@ func (s *stream) OrderedConsumer(ctx context.Context, cfg OrderedConsumerConfig) if cfg.OptStartSeq != 0 { oc.cursor.streamSeq = cfg.OptStartSeq - 1 } + err := oc.reset() + if err != nil { + return nil, err + } return oc, nil } diff --git a/jetstream/test/helper_test.go b/jetstream/test/helper_test.go index c8c4d18f3..ce7909dbe 100644 --- a/jetstream/test/helper_test.go +++ b/jetstream/test/helper_test.go @@ -322,3 +322,19 @@ func restartBasicJSServer(t *testing.T, s *server.Server) *server.Server { s.WaitForShutdown() return RunServerWithOptions(opts) } + +func checkFor(t *testing.T, totalWait, sleepDur time.Duration, f func() error) { + t.Helper() + timeout := time.Now().Add(totalWait) + var err error + for time.Now().Before(timeout) { + err = f() + if err == nil { + return + } + time.Sleep(sleepDur) + } + if err != nil { + t.Fatal(err.Error()) + } +} diff --git a/jetstream/test/jetstream_test.go b/jetstream/test/jetstream_test.go index d22c5ba87..bf9a63cf4 100644 --- a/jetstream/test/jetstream_test.go +++ b/jetstream/test/jetstream_test.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "os" + "reflect" "testing" "time" @@ -262,6 +263,153 @@ func TestCreateStream(t *testing.T) { } } +func TestCreateStreamMirrorCrossDomains(t *testing.T) { + test := []struct { + name string + streamConfig *jetstream.StreamConfig + }{ + { + name: "create stream mirror cross domains", + streamConfig: &jetstream.StreamConfig{ + Name: "MIRROR", + Mirror: &jetstream.StreamSource{ + Name: "TEST", + Domain: "HUB", + }, + }, + }, + { + name: "create stream with source cross domains", + streamConfig: &jetstream.StreamConfig{ + Name: "MIRROR", + Sources: []*jetstream.StreamSource{ + { + Name: "TEST", + Domain: "HUB", + }, + }, + }, + }, + } + + for _, test := range test { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(` + server_name: HUB + listen: 127.0.0.1:-1 + jetstream: { domain: HUB } + leafnodes { listen: 127.0.0.1:7422 } + }`)) + defer os.Remove(conf) + srv, _ := RunServerWithConfig(conf) + defer shutdownJSServerAndRemoveStorage(t, srv) + + lconf := createConfFile(t, []byte(` + server_name: LEAF + listen: 127.0.0.1:-1 + jetstream: { domain:LEAF } + leafnodes { + remotes = [ { url: "leaf://127.0.0.1" } ] + } +}`)) + defer os.Remove(lconf) + ln, _ := RunServerWithConfig(lconf) + defer shutdownJSServerAndRemoveStorage(t, ln) + + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + _, err = js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if _, err := js.Publish(ctx, "foo", []byte("msg1")); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if _, err := js.Publish(ctx, "foo", []byte("msg2")); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + lnc, err := nats.Connect(ln.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer lnc.Close() + ljs, err := jetstream.New(lnc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ccfg := *test.streamConfig + _, err = ljs.CreateStream(ctx, ccfg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(test.streamConfig, &ccfg) { + t.Fatalf("Did not expect config to be altered: %+v vs %+v", test.streamConfig, ccfg) + } + + // Make sure we sync. + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + lStream, err := ljs.Stream(ctx, "MIRROR") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if lStream.CachedInfo().State.Msgs == 2 { + return nil + } + return fmt.Errorf("Did not get synced messages: %d", lStream.CachedInfo().State.Msgs) + }) + if _, err := ljs.Publish(ctx, "foo", []byte("msg3")); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + lStream, err := ljs.Stream(ctx, "MIRROR") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if lStream.CachedInfo().State.Msgs != 3 { + t.Fatalf("Expected 3 msgs in stream; got: %d", lStream.CachedInfo().State.Msgs) + } + + rjs, err := jetstream.NewWithDomain(lnc, "HUB") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + _, err = rjs.Stream(ctx, "TEST") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if _, err := rjs.Publish(ctx, "foo", []byte("msg4")); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + rStream, err := rjs.Stream(ctx, "TEST") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if rStream.CachedInfo().State.Msgs != 4 { + t.Fatalf("Expected 3 msgs in stream; got: %d", rStream.CachedInfo().State.Msgs) + } + }) + } +} + func TestUpdateStream(t *testing.T) { tests := []struct { name string diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index 0333b4ae6..bb3584350 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -402,6 +402,281 @@ func TestOrderedConsumerFetch(t *testing.T) { t.Fatalf("Expected error: %s; got: %s", jetstream.ErrOrderConsumerUsedAsConsume, err) } }) + + t.Run("concurrent fetch requests", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + res, err := c.Fetch(1, jetstream.FetchMaxWait(100*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + _, err = c.Fetch(1) + if !errors.Is(err, jetstream.ErrOrderedConsumerConcurrentRequests) { + t.Fatalf("Expected error: %s; got: %s", jetstream.ErrOrderedConsumerConcurrentRequests, err) + } + for msg := range res.Messages() { + msg.Ack() + } + }) +} + +func TestOrderedConsumerFetchBytes(t *testing.T) { + testSubject := "FOO.123" + testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} + publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + for _, msg := range testMsgs { + if err := nc.Publish(testSubject, []byte(msg)); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + } + t.Run("base usage, delete consumer", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]jetstream.Msg, 0) + + publishTestMsgs(t, nc) + res, err := c.FetchBytes(500, jetstream.FetchMaxWait(100*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + for msg := range res.Messages() { + msgs = append(msgs, msg) + } + if res.Error() != nil { + t.Fatalf("Unexpected error: %s", err) + } + name := c.CachedInfo().Name + if err := s.DeleteConsumer(ctx, name); err != nil { + t.Fatal(err) + } + publishTestMsgs(t, nc) + res, err = c.Fetch(500, jetstream.FetchMaxWait(100*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + for msg := range res.Messages() { + msgs = append(msgs, msg) + } + if res.Error() != nil { + t.Fatalf("Unexpected error: %s", err) + } + if len(msgs) != 2*len(testMsgs) { + t.Fatalf("Expected %d messages; got: %d", 2*len(testMsgs), len(msgs)) + } + }) + + t.Run("consumer used as consume", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cc, err := c.Consume(func(msg jetstream.Msg) {}) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + cc.Stop() + + _, err = c.FetchBytes(500) + if !errors.Is(err, jetstream.ErrOrderConsumerUsedAsConsume) { + t.Fatalf("Expected error: %s; got: %s", jetstream.ErrOrderConsumerUsedAsConsume, err) + } + }) + + t.Run("concurrent fetch requests", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + res, err := c.FetchBytes(500, jetstream.FetchMaxWait(100*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + _, err = c.FetchBytes(500) + if !errors.Is(err, jetstream.ErrOrderedConsumerConcurrentRequests) { + t.Fatalf("Expected error: %s; got: %s", jetstream.ErrOrderedConsumerConcurrentRequests, err) + } + for msg := range res.Messages() { + msg.Ack() + } + }) +} + +func TestOrderedConsumerNext(t *testing.T) { + testSubject := "FOO.123" + testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} + publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + for _, msg := range testMsgs { + if err := nc.Publish(testSubject, []byte(msg)); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + } + t.Run("base usage, delete consumer", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + msg, err := c.Next() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + msg.Ack() + + name := c.CachedInfo().Name + if err := s.DeleteConsumer(ctx, name); err != nil { + t.Fatal(err) + } + msg, err = c.Next() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + msg.Ack() + }) + + t.Run("consumer used as consume", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cc, err := c.Consume(func(msg jetstream.Msg) {}) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + cc.Stop() + + _, err = c.Next() + if !errors.Is(err, jetstream.ErrOrderConsumerUsedAsConsume) { + t.Fatalf("Expected error: %s; got: %s", jetstream.ErrOrderConsumerUsedAsConsume, err) + } + }) } func TestOrderedConsumerFetchNoWait(t *testing.T) { @@ -531,18 +806,10 @@ func TestOrderedConsumerInfo(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + c, err := js.OrderedConsumer(ctx, "foo", jetstream.OrderedConsumerConfig{}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = c.Info(ctx) - if !errors.Is(err, jetstream.ErrOrderedConsumerNotCreated) { - t.Fatalf("Expected error: %v; got: %v", jetstream.ErrOrderedConsumerNotCreated, err) - } - info := c.CachedInfo() - if info != nil { - t.Fatalf("Cached info should be nil if consumer is not yet created") - } cc, err := c.Consume(func(msg jetstream.Msg) {}) if err != nil { @@ -550,7 +817,7 @@ func TestOrderedConsumerInfo(t *testing.T) { } defer cc.Stop() - info, err = c.Info(ctx) + info, err := c.Info(ctx) if err != nil { t.Fatalf("Unexpected error: %s", err) }