diff --git a/app_test.go b/app_test.go index d99a58afd..7beae0157 100644 --- a/app_test.go +++ b/app_test.go @@ -1281,6 +1281,26 @@ func TestAppStart(t *testing.T) { err := app.Start(context.Background()).Error() assert.Contains(t, err, "OnStart hook added by go.uber.org/fx_test.TestAppStart.func10.1 failed: goroutine exited without returning") }) + + t.Run("StartTwiceWithHooksErrors", func(t *testing.T) { + t.Parallel() + + app := fxtest.New(t, + Invoke(func(lc Lifecycle) { + lc.Append(Hook{ + OnStart: func(ctx context.Context) error { return nil }, + OnStop: func(ctx context.Context) error { return nil }, + }) + }), + ) + assert.NoError(t, app.Start(context.Background())) + err := app.Start(context.Background()) + if assert.Error(t, err) { + assert.ErrorContains(t, err, "attempted to start lifecycle when in state: started") + } + app.Stop(context.Background()) + assert.NoError(t, app.Start(context.Background())) + }) } func TestAppStop(t *testing.T) { diff --git a/internal/lifecycle/lifecycle.go b/internal/lifecycle/lifecycle.go index 17fe24ac8..ce037ac29 100644 --- a/internal/lifecycle/lifecycle.go +++ b/internal/lifecycle/lifecycle.go @@ -123,10 +123,38 @@ type Hook struct { callerFrame fxreflect.Frame } +type appState int + +const ( + stopped appState = iota + starting + incompleteStart + started + stopping +) + +func (as appState) String() string { + switch as { + case stopped: + return "stopped" + case starting: + return "starting" + case incompleteStart: + return "incompleteStart" + case started: + return "started" + case stopping: + return "stopping" + default: + return "invalidState" + } +} + // Lifecycle coordinates application lifecycle hooks. type Lifecycle struct { clock fxclock.Clock logger fxevent.Logger + state appState hooks []Hook numStarted int startRecords HookRecords @@ -157,9 +185,23 @@ func (l *Lifecycle) Start(ctx context.Context) error { } l.mu.Lock() + if l.state != stopped { + defer l.mu.Unlock() + return fmt.Errorf("attempted to start lifecycle when in state: %v", l.state) + } + l.numStarted = 0 + l.state = starting + l.startRecords = make(HookRecords, 0, len(l.hooks)) l.mu.Unlock() + var returnState appState = incompleteStart + defer func() { + l.mu.Lock() + l.state = returnState + l.mu.Unlock() + }() + for _, hook := range l.hooks { // if ctx has cancelled, bail out of the loop. if err := ctx.Err(); err != nil { @@ -187,6 +229,7 @@ func (l *Lifecycle) Start(ctx context.Context) error { l.numStarted++ } + returnState = started return nil } @@ -221,6 +264,20 @@ func (l *Lifecycle) Stop(ctx context.Context) error { return errors.New("called OnStop with nil context") } + l.mu.Lock() + if l.state != started && l.state != incompleteStart { + defer l.mu.Unlock() + return fmt.Errorf("attempted to stop lifecycle when in state: %v", l.state) + } + l.state = stopping + l.mu.Unlock() + + defer func() { + l.mu.Lock() + l.state = stopped + l.mu.Unlock() + }() + l.mu.Lock() l.stopRecords = make(HookRecords, 0, l.numStarted) l.mu.Unlock() diff --git a/internal/lifecycle/lifecycle_test.go b/internal/lifecycle/lifecycle_test.go index 2fdc15f95..5eda55242 100644 --- a/internal/lifecycle/lifecycle_test.go +++ b/internal/lifecycle/lifecycle_test.go @@ -71,6 +71,7 @@ func TestLifecycleStart(t *testing.T) { assert.NoError(t, l.Start(context.Background())) assert.Equal(t, 2, count) }) + t.Run("ErrHaltsChainAndRollsBack", func(t *testing.T) { t.Parallel() @@ -143,6 +144,18 @@ func TestLifecycleStart(t *testing.T) { // stop hooks. require.NoError(t, l.Stop(ctx)) }) + + t.Run("StartWhileStartedErrors", func(t *testing.T) { + t.Parallel() + + l := New(testLogger(t), fxclock.System) + assert.NoError(t, l.Start(context.Background())) + err := l.Start(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "attempted to start lifecycle when in state: started") + assert.NoError(t, l.Stop(context.Background())) + assert.NoError(t, l.Start(context.Background())) + }) } func TestLifecycleStop(t *testing.T) { @@ -152,6 +165,7 @@ func TestLifecycleStop(t *testing.T) { t.Parallel() l := New(testLogger(t), fxclock.System) + l.Start(context.Background()) assert.Nil(t, l.Stop(context.Background()), "no lifecycle hooks should have resulted in stop returning nil") }) @@ -317,6 +331,16 @@ func TestLifecycleStop(t *testing.T) { assert.Contains(t, err.Error(), "called OnStop with nil context") }) + + t.Run("StopWhileStoppedErrors", func(t *testing.T) { + t.Parallel() + + l := New(testLogger(t), fxclock.System) + err := l.Stop(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "attempted to stop lifecycle when in state: stopped") + }) + } func TestHookRecordsFormat(t *testing.T) {